From a2ae7aae8357c8c2684d85fd58b0c5a0563ebab9 Mon Sep 17 00:00:00 2001 From: sanine-a Date: Tue, 9 May 2023 11:31:17 -0500 Subject: refactor: split ecs systems into multiple files --- honey/ecs/collision.lua | 57 ++++++++++ honey/ecs/ecs.lua | 274 ++++++++++++++++++++++++++++++++++++++++++++++ honey/ecs/ecs.test.lua | 280 ++++++++++++++++++++++++++++++++++++++++++++++++ honey/ecs/node.lua | 47 ++++++++ honey/ecs/physics.lua | 158 +++++++++++++++++++++++++++ honey/ecs/render.lua | 98 +++++++++++++++++ honey/ecs/script.lua | 44 ++++++++ 7 files changed, 958 insertions(+) create mode 100644 honey/ecs/collision.lua create mode 100644 honey/ecs/ecs.lua create mode 100644 honey/ecs/ecs.test.lua create mode 100644 honey/ecs/node.lua create mode 100644 honey/ecs/physics.lua create mode 100644 honey/ecs/render.lua create mode 100644 honey/ecs/script.lua (limited to 'honey/ecs') diff --git a/honey/ecs/collision.lua b/honey/ecs/collision.lua new file mode 100644 index 0000000..722c256 --- /dev/null +++ b/honey/ecs/collision.lua @@ -0,0 +1,57 @@ +local glm = require 'honey.glm' +local Vec3 = glm.Vec3 +local ode = honey.ode + +local module = {} +setmetatable(module, {__index=_G}) +setfenv(1, module) + + +--===== collision space =====-- + + +local function createGeom(self, id, collision) + local geom + if collision.class == "sphere" then + geom = ode.CreateSphere(self.space, collision.radius) + elseif collision.class == "capsule" then + geom = ode.CreateCapsule(self.space, collision.radius, collision.length) + elseif collision.class == "plane" then + local node = self.db:getComponent(id, "node") + local m = node.matrix + local normal = node.matrix:mulv3(Vec3{0,1,0}):normalize() + local position = Vec3{m[1][4], m[2][4], m[3][4]} + print(position) + local d = normal:dot(position) + print(normal, d) + geom = ode.CreatePlane(self.space, normal[1], normal[2], normal[3], d) + end + collision._geom = geom + collision._gc = honey.util.gc_canary(function() + print("release geom for id"..id) + ode.GeomDestroy(geom) + end) +end + +system = function(params) + local db = params.db + local space = params.space + return { + db=db, + space=space, + priority=0, + update = function(self, dt) + local query = self.db:queryComponent("collision") + for id, collision in pairs(query) do + if not collision._geom then + createGeom(self, id, collision) + print(id, collision._geom) + end + end + end + } +end + + + +return module diff --git a/honey/ecs/ecs.lua b/honey/ecs/ecs.lua new file mode 100644 index 0000000..b0409e4 --- /dev/null +++ b/honey/ecs/ecs.lua @@ -0,0 +1,274 @@ +math.randomseed(os.time()) + +local glm = require 'honey.glm' + +local module = {} +setmetatable(module, {__index=_G}) +setfenv(1, module) + + +--===== EntityDb =====-- + + +-- EntityDb is a database of entities and their associated components +-- it should be quite efficient to query for all entities with a given component, and reasonably +-- efficient to query for all components of a given entity + + +EntityDb = {} +EntityDb.__index = EntityDb + + +function EntityDb.new(_) + local self = { + entities = {}, + components = {}, + } + setmetatable(self, EntityDb) + return self +end +setmetatable(EntityDb, {__call=EntityDb.new}) + + +local function serialize(tbl) + local tostr = function(x, value) + if type(x) == "table" then + if x.__tostring then + return tostring(x) + else + return serialize(x) + end + elseif type(x) == "string" then + if value then + return string.format("\"%s\"", x) + else + return x + end + else + return tostring(x) + end + end + local str = "{" + for key, value in pairs(tbl) do + if type(key) == "string" and string.match(key, "^_") then + -- ignore keys starting with an underscore + else + str = str .. string.format("%s=%s,", tostr(key), tostr(value, true)) + end + end + str = string.sub(str, 1, -2) .. "}" + return str +end + + +-- save current database to file +function EntityDb.save(self, filename) + local file, err = io.open(filename, "w") + if not file then error(err) end + + for entity in pairs(self.entities) do + local components = self:queryEntity(entity) + file:write(string.format("Entity(\"%s\", %s)\n", entity, serialize(components))) + end + + file:close() +end + + +-- load database from file +function EntityDb.load(self, filename) + print(collectgarbage("count")) + self.entities = {} + self.components = {} + collectgarbage() + print(collectgarbage("count")) + local env = { + Entity = function(id, components) + self:createEntity(id) + self:addComponents(id, components) + end, + Vec3 = glm.Vec3, + Mat4 = glm.Mat4, + } + local f, err = loadfile(filename) + if not f then error(err) end + setfenv(f, env) + f() +end + + +-- check if a given entity id is legitimate +function EntityDb.checkIsValid(self, id) + if not self.entities[id] then + error(string.format("invalid entity id: %s", tostring(id))) + end +end + + +local random = math.random +local function uuid() + local template ='xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx' + return string.gsub(template, '[xy]', function (c) + local v = (c == 'x') and random(0, 0xf) or random(8, 0xb) + return string.format('%x', v) + end) +end + + +-- create a new entity +function EntityDb.createEntity(self, id) + local id = id or uuid() -- allow inserting entities at preset ids for loading + self.entities[id] = true + return id +end + + +-- add a component to an entity +function EntityDb.addComponent(self, id, name, value) + self:checkIsValid(id) + + -- create the relevant component table if it doesn't exist + if not self.components[name] then + self.components[name] = { count=0, data={} } + end + + local component = self.components[name] + component.data[id] = value + component.count = component.count + 1 +end + + +-- add multiple components at once, for convenience +function EntityDb.addComponents(self, id, components) + for name, value in pairs(components) do + self:addComponent(id, name, value) + end +end + + +-- create an entity with components +function EntityDb.createEntityWithComponents(self, components) + local id = self:createEntity() + self:addComponents(id, components) + return id +end + + +-- get all entities with a given component +function EntityDb.queryComponent(self, name) + local component = self.components[name] + if component then + return component.data + else + return {} + end +end + + +-- get all components associated with an entity +function EntityDb.queryEntity(self, id) + self:checkIsValid(id) + local query = {} + for name, component in pairs(self.components) do + query[name] = component.data[id] + end + return query +end + + +-- get a specific component from an entity +function EntityDb.getComponent(self, id, name) + self:checkIsValid(id) + return self.components[name].data[id] +end + + +-- remove a component from an entity +function EntityDb.removeComponent(self, id, name) + self:checkIsValid(id) + local component = self.components[name] + if component.data[id] ~= nil then + component.data[id] = nil + component.count = component.count - 1 + if component.count == 0 then + self.components[name] = nil + end + end +end + + +-- remove an entity from the db +function EntityDb.deleteEntity(self, id) + self:checkIsValid(id) + for name in pairs(self.components) do + self:removeComponent(id, name) + end + self.entities[id] = nil +end + + +--===== SystemDb =====-- + +SystemDb = {} +SystemDb.__index = SystemDb + + +function SystemDb.new(_, entityDb) + local self = { + systems = {}, + sorted = {}, + entityDb = entityDb, + } + setmetatable(self, SystemDb) + return self +end +setmetatable(SystemDb, {__call=SystemDb.new}) + + +local function systemId() + local template = "xx:xx:xx" + return string.gsub(template, "x", function(c) + return string.format("%x", random(0, 0xf)) + end) +end + + +function SystemDb.addSystem(self, systemFunc, params) + local system + if type(systemFunc) == "table" then + system = systemFunc + else + local params = params or {} + params.db = self.entityDb + system = systemFunc(params) + end + + local id = systemId() + self.systems[id] = system + table.insert(self.sorted, id) + self:sort() + return id +end + + +function SystemDb.sort(self) + table.sort(self.sorted, function(a, b) return (self.systems[a].priority or 100) < (self.systems[b].priority or 100) end) +end + + +function SystemDb.update(self, dt) + for _, system in ipairs(self.sorted) do + self.systems[system]:update(dt) + end +end + + +function SystemDb.removeSystem(self, id) + self.systems[id] = nil + for i=#self.sorted,1,-1 do + if self.sorted[i] == id then table.remove(self.sorted, i) end + end +end + + +return module diff --git a/honey/ecs/ecs.test.lua b/honey/ecs/ecs.test.lua new file mode 100644 index 0000000..74e27dc --- /dev/null +++ b/honey/ecs/ecs.test.lua @@ -0,0 +1,280 @@ +local testCount = 0 +local failCount = 0 + +local function test(msg, f) + testCount = testCount + 1 + local success, error = xpcall(f, debug.traceback) + if success then + print(msg .. "\t\t[OK]") + else + failCount = failCount + 1 + print(msg .. "\t\t[FAIL]") + print(error) + end +end + + +local ecs = require 'ecs' + + +--===== Component tests =====-- + +local Component = ecs.Component + +test("factories work as expected", function() + local factory = Component.newFactory("health", { percent=100 }) + local comp1 = factory() + assert(comp1.__type == "health", "bad component type for comp1") + assert(comp1.percent == 100, "bat percent for comp1") + + local comp2 = factory{ percent=50 } + assert(comp2.__type == "health", "bad component type for comp2") + assert(comp2.percent == 50, "bad percent for comp2") + + local success = pcall(function() + comp2.dne = 5 + end) + assert(not success, "incorrectly succeeded in setting comp2.dne") + + local success = pcall(function() + local comp3 = factory{ percent = 44, something = 2 } + end) + assert(not success, "incorrectly succeeded in creating comp3") +end) + + +test("components serialize as expected", function() + local position = Component.newFactory("position", { x=0, y=0, z=0 }) + local comp = position{x=10, y=15, z=10} + local str = tostring(comp) + local tbl = (loadstring("return " .. str))() + assert(tbl.__type == "position", "bad type") + assert(tbl.x == 10, "bad x") + assert(tbl.y == 15, "bad y") + assert(tbl.z == 10, "bad z") +end) + + +test("components serialize successfully with subcomponents", function() + local position = Component.newFactory("position", { x=0, y=0, z=0 }) + local player = Component.newFactory("player", { name="", position={} }) + + local p = player{ name="hannah", position=position{x=10, y=9, z=8} } + local tbl = (loadstring("return " .. tostring(p)))() + assert(tbl.__type == "player") + assert(tbl.name == "hannah") + assert(tbl.position.__type == "position") + assert(tbl.position.x == 10) + assert(tbl.position.y == 9) + assert(tbl.position.z == 8) +end) + + +--===== EntityDb tests =====-- + +local EntityDb = ecs.EntityDb + + +test("EntityDb.createEntity() always returns a new id", function() + local db = EntityDb() + + local ids = {} + for i=1,100 do + local id = db:createEntity() + assert(ids[id] == nil, "id was already returned!") + ids[id] = true + end +end) + + +test("EntityDb.queryComponent() gets all entities with a given component", function() + local db = EntityDb() + + local ids = {} + for i=1,100 do + local id = db:createEntity() + if i%2==0 then + ids[id] = 5*i + db:addComponent(id, "number", 5*i) + end + end + + local query = db:queryComponent("number") + local count = 0 + for id, number in pairs(query) do + count = count + 1 + assert(number == ids[id]) + end + assert(count == 50) +end) + + +test("EntityDb.queryEntity() gets all components associated with an entity", function() + local db = EntityDb() + + local entity + for i=1,100 do + local id = db:createEntity() + if i%2 == 0 then db:addComponent(id, "number", 2) end + if i%3 == 0 then db:addComponent(id, "string", "hello") end + if i%5 == 0 then db:addComponent(id, "number2", 4) end + if i%7 == 0 then db:addComponent(id, "string2", "world") end + if i == 30 then entity=id end + end + + local query = db:queryEntity(entity) + assert(query.number == 2) + assert(query.string == "hello") + assert(query.number2 == 4) + assert(query.string2 == nil) +end) + + +test("EntityDb.removeComponent() removes components correctly", function() + local db = EntityDb() + + local id = db:createEntity() + db:addComponent(id, "number", 2) + db:addComponent(id, "string", "hello") + db:addComponent(id, "number2", 4) + db:addComponent(id, "string2", "world") + + local query = db:queryEntity(id) + assert(query.number == 2) + assert(query.string == "hello") + assert(query.number2 == 4) + assert(query.string2 == "world") + + db:removeComponent(id, "string2") + query = db:queryEntity(id) + assert(query.number == 2) + assert(query.string == "hello") + assert(query.number2 == 4) + assert(query.string2 == nil) + + db:removeComponent(id, "number2") + query = db:queryEntity(id) + assert(query.number == 2) + assert(query.string == "hello") + assert(query.number2 == nil) + assert(query.string2 == nil) + + db:removeComponent(id, "string") + query = db:queryEntity(id) + assert(query.number == 2) + assert(query.string == nil) + assert(query.number2 == nil) + assert(query.string2 == nil) + + db:removeComponent(id, "number") + query = db:queryEntity(id) + assert(query.number == nil) + assert(query.string == nil) + assert(query.number2 == nil) + assert(query.string2 == nil) +end) + + +test("EntityDb.removeComponent() deletes component table when empty", function() + local db = EntityDb() + local id1 = db:createEntity() + local id2 = db:createEntity() + db:addComponent(id1, "number", 2) + db:addComponent(id2, "number", 3) + + assert(db.components.number ~= nil) + db:removeComponent(id1, "number") + assert(db.components.number ~= nil) + db:removeComponent(id2, "number") + assert(db.components.number == nil) +end) + + +test("EntityDb.removeComponent() does nothing if the component is not present", function() + local db = EntityDb() + local id1 = db:createEntity() + local id2 = db:createEntity() + db:addComponent(id1, "number", 2) + db:addComponent(id2, "number", 3) + + assert(db.components.number ~= nil) + db:removeComponent(id1, "number") + assert(db.components.number ~= nil) + db:removeComponent(id1, "number") + assert(db.components.number ~= nil) + + db:removeComponent(id2, "number") + assert(db.components.number == nil) + +end) + + +test("EntityDb.deleteEntity() correctly removes an entity", function() + local db = EntityDb() + local id1 = db:createEntity() + local id2 = db:createEntity() + db:addComponent(id1, "number", 2) + db:addComponent(id2, "number", 3) + + local query = db:queryComponent("number") + assert(query[id1] and query[id2]) + db:deleteEntity(id1) + query = db:queryComponent("number") + assert(query[id1] == nil) + assert(query[id2] ~= nil) + + assert(false == pcall(function() + local query = db:queryEntity(id1) + end)) +end) + + +--===== SystemDb tests =====-- + +local SystemDb = ecs.SystemDb + +test("addSystem() correctly sorts systems", function() + local sdb = SystemDb(nil) + local str = "" + sdb:addSystem(function () return { + update=function(self, dt) str = str .. "c" end, + priority = 3, + } end) + sdb:addSystem{ + update=function(self, dt) str = "a" end, + priority = 1, + } + sdb:addSystem{ + update=function(self, dt) str = str .. "b" end, + priority = 2, + } + sdb:update(0) + assert(str == "abc") +end) + + +test("removeSystem() correctly handles things", function() + local sdb = SystemDb(nil) + local str = "" + sdb:addSystem(function () return { + update=function(self, dt) str = str .. "c" end, + priority = 3, + } end) + sdb:addSystem{ + update=function(self, dt) str = "a" end, + priority = 1, + } + local id = sdb:addSystem{ + update=function(self, dt) str = str .. "b" end, + priority = 2, + } + sdb:update(0) + assert(str == "abc") + + sdb:removeSystem(id) + sdb:update(1) + assert(str == "ac") +end) + + +print(string.format("ran %d tests, %d failed", testCount, failCount)) diff --git a/honey/ecs/node.lua b/honey/ecs/node.lua new file mode 100644 index 0000000..39f1898 --- /dev/null +++ b/honey/ecs/node.lua @@ -0,0 +1,47 @@ +local module = {} +setmetatable(module, {__index=_G}) +setfenv(1, module) + +--===== transform cascading =====-- + +system = function(params) + return { + db = params.db, + + priority = 2, + update = function(self, dt) + local nodes = self.db:queryComponent("node") + + -- prepare nodes + for id, node in pairs(nodes) do + node._visited = false + end + + -- helper function + local function recursiveTransform(node) + if node._visited then + return node._matrix + end + + if not node.parent then + node._matrix = node.matrix + else + local parentTransform = self.db:getComponent(node.parent, "node") + local parentMatrix = recursiveTransform(parentTransform) + node._matrix = parentMatrix * node.matrix + end + node._visited = true + return node._matrix + end + + -- compute nodes + for id, node in pairs(nodes) do + recursiveTransform(node) + end + end, + } +end + + + +return module diff --git a/honey/ecs/physics.lua b/honey/ecs/physics.lua new file mode 100644 index 0000000..eac3846 --- /dev/null +++ b/honey/ecs/physics.lua @@ -0,0 +1,158 @@ +local glm = require 'honey.glm' +local Vec3 = glm.Vec3 +local Mat4 = glm.Mat4 +local Quaternion = glm.Quaternion +local ode = honey.ode + +local module = {} +setmetatable(module, {__index=_G}) +setfenv(1, module) + +--===== physics =====-- + + +system = function(params) + local interval = params.interval or 0.016 + local groupSize = params.groupSize or 20 + local refs = {} + return { + db=params.db, + space=params.space, + world=params.world, + contactGroup=ode.JointGroupCreate(groupSize), + time=interval, + + priority=1, + update=function(self, dt) + for i, ref in ipairs(refs) do + print(i, ref.tbl, ref.physics) + end + local query = self.db:queryComponent("physics") + + for id, physics in pairs(query) do + if not physics._body then + print("add physics body for "..id) + local body = ode.BodyCreate(self.world) + physics._gc = honey.util.gc_canary(function() + print("releasing physics body for " .. id) + ode.BodyDestroy(body) + body = nil + end) + + local collision = self.db:getComponent(id, "collision") + if collision then + print(id, collision.class) + ode.GeomSetBody(collision._geom, body) + end + + local mass = ode.MassCreate() + local class = physics.mass.class + if not class then + -- configure mass manually + elseif class == "sphere" then + ode.MassSetSphere( + mass, + physics.mass.density, + physics.mass.radius + ) + elseif class == "capsule" then + ode.MassSetCapsule( + mass, + physics.mass.density, + physics.mass.direction, + physics.mass.radius, + physics.mass.length + ) + end + ode.BodySetMass(body, mass) + local m = self.db:getComponent(id, "node").matrix + ode.BodySetPosition( + body, + m[1][4], m[2][4], m[3][4] + ) + ode.BodySetRotation( + body, + m[1][1], m[1][2], m[1][3], + m[2][1], m[2][2], m[2][3], + m[3][1], m[3][2], m[3][3] + ) + local vel = physics.velocity or Vec3{0,0,0} + ode.BodySetLinearVel( + body, vel[1], vel[2], vel[3] + ) + physics.velocity = vel + + local avel = physics.angularVelocity or Vec3{0,0,0} + ode.BodySetAngularVel( + body, avel[1], avel[2], avel[3] + ) + physics.angularVelocity = avel + + if physics.maxAngularSpeed then + ode.BodySetMaxAngularSpeed(body, physics.maxAngularSpeed) + end + + physics._body = body + end + end + + self.time = self.time + dt + -- only run the update every [interval] seconds + if self.time > interval then + self.time = self.time - interval + + -- check for near collisions between geoms + ode.SpaceCollide(self.space, function(a, b) + -- check for actual collisions + local collisions = ode.Collide(a, b, 1) + if #collisions > 0 then + -- set up the joint params + local contact = ode.CreateContact{ surface={ + mode = ode.ContactBounce + ode.ContactSoftCFM, + mu = ode.Infinity, + bounce = 0.90, + bounce_vel = 0.1, + soft_cfm = 0.001, + }} + ode.ContactSetGeom(contact, collisions[1]) + -- create the joint + local joint = ode.JointCreateContact( + self.world, + self.contactGroup, + contact + ) + -- attach the two bodies + local bodyA = ode.GeomGetBody(a) + local bodyB = ode.GeomGetBody(b) + ode.JointAttach(joint, bodyA, bodyB) + end + end) + -- update the world + ode.WorldQuickStep(self.world, interval) + -- remove all contact joints + ode.JointGroupEmpty(self.contactGroup) + + -- update entity nodes + for id, physics in pairs(query) do + local x,y,z = ode.BodyGetPosition(physics._body) + local d,a,b,c = ode.BodyGetQuaternion(physics._body) + local node = self.db:getComponent(id, "node") + local q = Quaternion{a,b,c,d} + node.matrix + :identity() + :translate(Vec3{x,y,z}) + :mul(Quaternion{a,b,c,d}:toMat4()) + + local vel = physics.velocity + vel[1], vel[2], vel[3] = ode.BodyGetLinearVel(physics._body) + local avel = physics.angularVelocity + avel[1], avel[2], avel[3] = ode.BodyGetAngularVel(physics._body) + end + end + end, + } +end + + + +return module diff --git a/honey/ecs/render.lua b/honey/ecs/render.lua new file mode 100644 index 0000000..4217422 --- /dev/null +++ b/honey/ecs/render.lua @@ -0,0 +1,98 @@ +local glm = require 'honey.glm' +local Vec3 = glm.Vec3 +local Mat4 = glm.Mat4 + +local gl = honey.gl +local glfw = honey.glfw + +local module = {} +setmetatable(module, {__index=_G}) +setfenv(1, module) + +--===== rendering =====-- + +function draw(model, view, projection, textures, shader, mesh) + shader:use() + + -- bind textures + local texOffset = 0 + for name, texTbl in pairs(textures or {}) do + local texture = honey.image.loadImage(texTbl.filename, texTbl.params) + gl.BindTexture(gl.TEXTURE_2D + texOffset, texture.texture) + shader:setInt(name, texOffset) + texOffset = texOffset + 1 + end + + -- configure default uniforms + shader:configure{ + float={ + time=glfw.GetTime(), + }, + matrix={ + view=view, + projection=projection, + model=model, + }, + } + + -- draw mesh + mesh:drawElements() + + -- unbind textures + for i=0,texOffset-1 do + gl.BindTexture(gl.TEXTURE_2D + i, 0) + end +end + +system = function(params) + return { + db = params.db, + priority = params.priority or 99, + update = function(self, dt) + for id, camera in pairs(self.db:queryComponent("camera")) do + local projection = camera.projection + local cameraTransform = self.db:getComponent(id, "node") + local view = Mat4() + if cameraTransform then + honey.glm.mat4_inv(cameraTransform._matrix.data, view.data) + else + view:identity() + end + + local entities = self.db:queryComponent("renderMesh") + for entity, tbl in pairs(entities) do + -- get model + local node = self.db:getComponent(entity, "node") + local model = + (node and node._matrix) or + Mat4():identity() + -- get shader + local shader = honey.shader.loadShader( + tbl.shader.vertex, tbl.shader.fragment + ) + -- get mesh + local mesh = honey.mesh.loadCached( + tbl.mesh.filename, tbl.mesh.index + ) + draw(model, view, projection, tbl.textures, shader, mesh) + end + + entities = self.db:queryComponent("renderQuad") + local quadmesh = honey.mesh.loadCached("builtin.quad", 1) + for entity, tbl in pairs(entities) do + -- get model + local model = Mat4():identity() + -- get shader + local shader = honey.shader.loadShader( + tbl.shader.vertex, tbl.shader.fragment + ) + draw(model, view, projection, tbl.textures, shader, quadmesh) + end + end + end, + } +end + + + +return module diff --git a/honey/ecs/script.lua b/honey/ecs/script.lua new file mode 100644 index 0000000..9ae7d72 --- /dev/null +++ b/honey/ecs/script.lua @@ -0,0 +1,44 @@ +local module = {} +setmetatable(module, {__index=_G}) +setfenv(1, module) + + +-- helper function for retrieving script functions +getFunction = function(script) + local f = require(script.script) + if script.func then + return f[script.func] + else + return f + end +end + + +--===== dispatch messages to handlers =====-- + +dispatch = function(entities, msg, data) + local query = entities:queryComponent(msg) + for id, handler in pairs(query) do + local f = getFunction(handler) + f(entities, id, data) + end +end + +--===== script system =====-- + +system = function(params) + return { + db=params.db, + update=function(self, dt) + local entities = self.db:queryComponent("script") + for id, script in pairs(entities) do + local f = getFunction(script) + f(self.db, id, dt) + end + end + } +end + + + +return module -- cgit v1.2.1