summaryrefslogtreecommitdiff
path: root/util/generate-binding.lua
blob: 23aa0579b49993c043e1208cda45ce5332271061 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
-- Module table. From here on every global assignment in this chunk (each
-- non-local `function Foo(...)` below) is stored in `b` rather than in the
-- real global table; reads that miss in `b` fall through to _G via
-- __index, so standard library names (string, table, ipairs, ...) resolve.
local b = setmetatable({}, {__index=_G})
-- Lua 5.1 / LuaJIT switch the chunk's environment with setfenv; Lua 5.2+
-- removed setfenv and resolve globals through the chunk's _ENV upvalue
-- instead, so assign that when setfenv is unavailable.
if setfenv then
	setfenv(1, b)
else
	_ENV = b
end


--- Extract the function name from a C prototype.
-- The name is the first identifier that is directly followed (optionally
-- after whitespace) by "(", e.g. "char *strdup(const char *s)" -> "strdup".
-- @tparam string signature C prototype
-- @treturn string|nil the identifier, or nil if no "name(...)" is present
function ExtractFunctionName(signature)
	local pattern = "([%w_][%w_]*)%s-%(.*%)"
	local name = string.match(signature, pattern)
	return name
end


--- Extract the C return type from a function prototype.
-- Examples: "int foo(int a)" -> "int", "char* strdup(...)" -> "char*",
-- "char *strdup(...)" -> "char *". Surrounding whitespace is stripped.
-- @tparam string signature C prototype
-- @treturn string|nil the return type, or nil when none precedes the name
function ExtractFunctionType(signature)
	-- "<type> <name>(...)": whitespace separates the type from the name
	local ftype = string.match(signature, "(.+)%s%s-[%w_]+%s-%(.*%)")
	if ftype == nil then
		-- "<type> *<name>(...)": the pointer star is glued to the name, so
		-- there is no whitespace directly before the identifier
		ftype = string.match(signature, "(.-%*)%s-[%w_]+%s-%(.*%)")
	end
	if ftype == nil then
		return nil
	end
	-- the greedy capture above keeps extra separator spaces ("int  foo()"
	-- yields "int "), which would defeat GetLuaType's "$"-anchored checks
	return (string.match(ftype, "^%s*(.-)%s*$"))
end


-- Return s with all leading and trailing whitespace removed (a single
-- string; the substitution counts from string.gsub are discarded).
local function trimWhitespace(s)
	return (string.gsub((string.gsub(s, "^%s*", "")), "%s+$", ""))
end


--- Parse the parameter list of a C function prototype.
-- Each parameter becomes a { type = ..., name = ... } table, in declaration
-- order. Pointer stars are folded into the type with normalised spacing,
-- e.g. "char **argv" -> { type = "char **", name = "argv" }.
-- An empty list "()", the C idiom "(void)" and a signature without any
-- parenthesised list all yield an empty table.
-- @tparam string signature C prototype, e.g. "int foo(int a, char *b)"
-- @treturn table sequence of { type = string, name = string }
-- @raise error if a parameter has no separate name (e.g. "int foo(int)")
function ExtractFunctionArgs(signature)
	local args = {}
	local argStr = string.match(signature, "%((.*)%)")
	-- "(void)" declares an empty parameter list in C; it is not a parameter
	-- named "void" and would otherwise fail the type/name split below
	if argStr == nil or string.match(argStr, "^%s*void%s*$") then
		return args
	end
	for rawArg in string.gmatch(argStr, "([^,][^,]*),?") do
		-- handle pointers (e.g. void *q): padding every '*' with spaces
		-- guarantees whitespace between the type and the name
		local arg = string.gsub(rawArg, "%*", " * ")

		local ctype = string.match(arg, "(.+)%s%s-[%w_]+")
		local name = string.match(arg, "([%w_]+)%s-,?$")
		if ctype == nil or name == nil then
			error(string.format(
				"cannot parse parameter '%s' in signature '%s'",
				trimWhitespace(rawArg), signature), 2)
		end

		-- collapse the padding added above ...
		ctype = string.gsub(ctype, "%s%s+", " ")
		-- ... and re-join "* *" into "**"; repeat because gsub matches do not
		-- overlap ("* * *" would otherwise become "** *")
		local joined
		repeat
			ctype, joined = string.gsub(ctype, "%* %*", "**")
		until joined == 0
		ctype = trimWhitespace(ctype)

		args[#args + 1] = { type = ctype, name = trimWhitespace(name) }
	end
	return args
end


--- Count the '*' characters in a C type, i.e. its level of indirection.
-- @tparam string ctype C type, e.g. "char **"
-- @treturn number 0 for plain values, 1 for "T *", 2 for "T **", ...
function GetPointerLevel(ctype)
	-- string.gsub's second result is the number of substitutions made
	local _, stars = string.gsub(ctype, "%*", "")
	return stars
end


--- Classify a C type by the Lua C API accessor family that handles it.
-- The result is spliced into "luaL_check%s" / "lua_push%s" by the code
-- generators, so it is one of "number", "integer", "string" or "void", or
-- "unknown" for types that cannot be bound automatically.
-- @tparam string ctype C type, e.g. "unsigned int" or "const char *"
-- @treturn string
function GetLuaType(ctype)
	local level = GetPointerLevel(ctype)

	-- double (triple, etc) pointers
	if level > 1 then
		return "unknown"
	end

	-- regular pointers: only C strings are supported. "char" must appear as
	-- a whole word so that e.g. "wchar_t *" or "char16_t *" is not handed to
	-- luaL_checkstring; padding with spaces lets the word test also work at
	-- either end of the type string
	if level == 1 then
		if string.match(" " .. ctype .. " ", "[^%w_]char[^%w_]") then
			return "string"
		end
		return "unknown"
	end

	-- ordinary variables, classified by their trailing keyword
	-- numbers
	if string.match(ctype, "float$") or string.match(ctype, "double$") then
		return "number"
	end
	-- integers ("signed$" also covers a bare "unsigned")
	if string.match(ctype, "char$") or string.match(ctype, "short$")
		or string.match(ctype, "int$") or string.match(ctype, "long$")
		or string.match(ctype, "signed$") then
		return "integer"
	end
	-- void
	if string.match(ctype, "void$") then
		return "void"
	end
	-- unknown
	return "unknown"
end


--- Render the C declaration that fetches argument #index from the Lua stack.
-- Types GetLuaType maps get a luaL_check* call; any other type gets a
-- "/* get: <type> */" placeholder for the user to complete by hand.
-- @tparam table arg { type = C type, name = C identifier }
-- @tparam number index stack index passed to luaL_check*
-- @treturn string e.g. "int x = luaL_checkinteger(L, 1);"
function PullArg(arg, index)
	local luaType = GetLuaType(arg.type)
	local getter
	if luaType ~= "unknown" then
		getter = string.format("luaL_check%s(L, %d);", luaType, index)
	else
		getter = string.format("/* get: %s */", arg.type)
	end
	return string.format("%s %s = %s", arg.type, arg.name, getter)
end


--- Render the C statements that call `fname` and return its result to Lua.
-- void functions return 0 results; mapped types are pushed with
-- lua_push<type> and return 1; other types get "/* push result */" and
-- "/* count */" placeholders for the user to fill in.
-- @tparam string ftype C return type of the wrapped function
-- @tparam string fname C name of the wrapped function
-- @tparam table args sequence of { name = ... } parameters, in call order
-- @treturn string statements separated by "\n\t"
function Call(ftype, fname, args)
	local names = {}
	for position, param in ipairs(args) do
		names[position] = param.name
	end
	local callArgs = "(" .. table.concat(names, ", ") .. ")"

	local luaType = GetLuaType(ftype)
	if luaType == "void" then
		return string.format("%s%s;\n\treturn 0;", fname, callArgs)
	end
	if luaType == "unknown" then
		return string.format(
			"%s bind_result = %s%s;\n\t/* push result */\n\treturn /* count */;",
			ftype, fname, callArgs
		)
	end
	return string.format(
		"%s bind_result = %s%s;\n\tlua_push%s(L, bind_result);\n\treturn 1;",
		ftype, fname, callArgs, luaType
	)
end


--- Generate the C source of a lua_CFunction wrapper for a C prototype.
-- The wrapper is named "<fname>_bind". It emits one PullArg line per
-- parameter (stack index = parameter position), then the Call statements.
-- @tparam string signature C prototype, e.g. "int add(int a, int b)"
-- @treturn string C source of the wrapper function
-- @raise error if the return type or function name cannot be parsed
function bind(signature)
	local ftype = ExtractFunctionType(signature)
	local fname = ExtractFunctionName(signature)
	-- Fail here with a readable message instead of emitting a bogus wrapper
	-- (on Lua 5.2+ string.format renders nil as "nil", giving "nil_bind")
	-- or crashing later inside GetLuaType on a nil return type.
	if ftype == nil or fname == nil then
		error(string.format("bind: cannot parse C signature '%s'", signature), 2)
	end
	local args  = ExtractFunctionArgs(signature)

	local result = string.format("int %s_bind(lua_State *L)\n{\n", fname)
	for index, arg in ipairs(args) do
		result = result .. "\t" .. PullArg(arg, index) .. "\n"
	end

	result = result .. "\t" .. Call(ftype, fname, args) .. "\n}"
	return result
end


-- Export the module table. Because the chunk's environment was switched to
-- `b` at the top of the file, every non-local function defined above
-- (ExtractFunctionName, ExtractFunctionType, ExtractFunctionArgs,
-- GetPointerLevel, GetLuaType, PullArg, Call, bind) is a field of `b`;
-- trimWhitespace is a local and is not exported. Lookups that miss in `b`
-- fall through to _G via its __index metamethod.
return b