summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitattributes1
-rw-r--r--example.lua80
-rw-r--r--ga.lua93
-rw-r--r--grammar.lua75
4 files changed, 249 insertions, 0 deletions
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..176a458
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+* text=auto
diff --git a/example.lua b/example.lua
new file mode 100644
index 0000000..fd61f06
--- /dev/null
+++ b/example.lua
@@ -0,0 +1,80 @@
+local Grammar = require 'grammar'
+local GA = require 'ga'
+
+
+local production = {}
+production["<start>"]={{'<expr>'}}
+production["<expr>"]={ {'<var>'}, {'(', '<expr>', '<op>', '<expr>', ')'} }
+production["<op>"]={ {'+'}, {'-'}, {'*'}, {'/'} }
+production["<var>"]={ {'x'} }
+
+
+local grammar = Grammar.createGrammar(production)
+
+
+math.randomseed(0)
+
+
+function randomGenome()
+ local len = math.ceil(math.random() * 256)
+ local genome = {}
+ for i=1,len do
+ table.insert(genome, math.floor(256 * math.random()))
+ end
+ return genome
+end
+
+
+function randomPop()
+ local pop = {}
+ for n=1,1000 do
+ table.insert(pop, randomGenome())
+ end
+ return pop
+end
+
+function evaluate(genome)
+ local s = 'return function(x) return ' .. grammar:expand(genome) .. ' end'
+ local f = assert(loadstring(s))()
+ local target = function(x) return (x*x*x*x) + (x*x*x) + (x*x) + (x) end
+ local cost = 0
+ for n=0,100 do
+ x = (n-50)/50
+ ok, y = pcall(f, x)
+ if (not ok) or (y ~= y) then
+ cost = cost + 100
+ else
+ cost = cost + math.abs(y - target(x))
+ end
+ end
+ return -((0.0001*#s) + cost)
+end
+
+local ga = GA.createGA(randomPop(), {
+ evaluate=evaluate,
+ crossover=function(a, b)
+ local crossoverPoint = math.ceil(#a * math.random())
+ local new = {}
+ for i=1,#b do
+ table.insert(new, (i>=crossoverPoint and b[i]) or a[i])
+ end
+ return new
+ end,
+ mutate=function(a)
+ local new = {}
+ for i,v in ipairs(a) do new[i] = v end
+ local mutationPoint = math.ceil(#a * math.random())
+ new[mutationPoint] = math.floor(256 * math.random())
+ return new
+ end,
+})
+
+
+for n=1,100 do
+ print(ga:step())
+ print(grammar:expand(ga.population[ga.bestMember]))
+end
+
+-- for k,v in ipairs(ga.population) do
+-- print(k, '(', v[1], v[2], v[3], ')', evaluate(v))
+-- end
diff --git a/ga.lua b/ga.lua
new file mode 100644
index 0000000..8b5b0e2
--- /dev/null
+++ b/ga.lua
@@ -0,0 +1,93 @@
+local module = {}
+setmetatable(module, {__index=_G})
+setfenv(1, module)
+
+
+local GA = {}
+local metaGA = { __index=GA }
+
+
+function module.createGA(population, operators)
+ local self = {
+ population=population,
+ operators=operators,
+ }
+
+ local totalWeight =
+ (operators.crossWeight or 1) +
+ (operators.mutationWeight or 1) +
+ (operators.reproductionWeight or 1)
+
+ self.operators.crossDensity = (operators.crossWeight or 1) / totalWeight
+ self.operators.mutationDensity = self.operators.crossDensity + ((operators.mutationWeight or 1) / totalWeight)
+ self.operators.reproductionDensity = 1
+
+ setmetatable(self, metaGA)
+ return self
+end
+
+
+function GA.evaluate(self)
+ self.fitnesses = {}
+ for i, p in ipairs(self.population) do
+ self.fitnesses[i] = self.operators.evaluate(p)
+ if i == 1 or self.fitnesses[i] > self.maxFitness then
+ self.maxFitness = self.fitnesses[i]
+ self.bestMember = i
+ end
+ end
+end
+
+
+function GA.tournamentPick(self, k)
+ local tournamentPop = {}
+ while #tournamentPop < k do
+ table.insert(tournamentPop, math.ceil(#self.population * math.random()))
+ end
+ table.sort(tournamentPop, function(a, b) return self.fitnesses[a] > self.fitnesses[b] end)
+ return tournamentPop[1]
+end
+
+
+function GA.stochasticPick(self)
+ local idx = math.ceil(#self.population * math.random())
+ local f = self.fitnesses[idx] / self.maxFitness
+ if math.random() < f then
+ return idx
+ else
+ return self:stochasticPick()
+ end
+end
+
+
+function GA.createNewPopulation(self, k)
+ k = k or 128
+ local newPop = {}
+ while #newPop < #self.population do
+ local r = math.random()
+ if r < self.operators.crossDensity then
+ local a = self:tournamentPick(k)
+ local b = self:tournamentPick(k)
+ table.insert(newPop, self.operators.crossover(self.population[a], self.population[b]))
+ elseif r < self.operators.mutationDensity then
+ local a = self:tournamentPick(k)
+ table.insert(newPop, self.operators.mutate(self.population[a]))
+ elseif r < self.operators.reproductionDensity then
+ local a = self:tournamentPick(k)
+ table.insert(newPop, self.population[a])
+ end
+ end
+ return newPop
+end
+
+
+function GA.step(self)
+ self.fitnesses = nil
+ self.maxFitness = nil
+ self:evaluate()
+ self.population = self:createNewPopulation()
+ return self.maxFitness
+end
+
+
+return module
diff --git a/grammar.lua b/grammar.lua
new file mode 100644
index 0000000..27e75f7
--- /dev/null
+++ b/grammar.lua
@@ -0,0 +1,75 @@
+local module = {}
+setmetatable(module, {__index=_G})
+setfenv(1, module)
+
+
+local Grammar = {}
+
+function module.createGrammar(nonterminals, functions, terminals)
+ local self = {
+ nonterminals = nonterminals,
+ functions = functions,
+ terminals = terminals,
+ }
+ setmetatable(self, {__index=Grammar})
+ return self
+end
+
+
+function Grammar.expandSymbol(self, sequence, index, expansion)
+ table.remove(sequence, index)
+ for i=1,#expansion do
+ table.insert(sequence, index+i-1, expansion[i])
+ end
+end
+
+
+function Grammar.expandSequence(self, sequence, n)
+ local i = 1
+ while i <= #sequence do
+ local choices = self.nonterminals[sequence[i]]
+ if choices then
+ if #choices == 1 then
+ self:expandSymbol(sequence, i, choices[1])
+ return true, 0
+ else
+ local choice = choices[1 + (n % #choices)]
+ self:expandSymbol(sequence, i, choice)
+ return true, 1
+ end
+ else
+ i = i+1
+ end
+ end
+ return false, 0
+end
+
+
+function Grammar.expand(self, genome, maxLength)
+ maxLoops = maxLoops or 2
+ local sequence = { '<start>' }
+ local geneIndex = 0
+ local count = 0
+ local continue = true
+ while continue do
+ local gene
+ if (count > maxLoops * #genome) then
+ gene = 0
+ else
+ gene = genome[1 + (geneIndex % #genome)]
+ end
+ local increment
+ continue, increment = self:expandSequence(sequence, gene)
+ if not continue then break end
+ geneIndex = geneIndex+increment
+ count = count + 1
+ end
+ local result = ''
+ for _, s in ipairs(sequence) do
+ result = result .. s .. ' '
+ end
+ return result
+end
+
+
+return module