diff options
author | sanine <sanine.not@pm.me> | 2023-06-11 22:50:42 -0500 |
---|---|---|
committer | sanine <sanine.not@pm.me> | 2023-06-11 22:50:42 -0500 |
commit | 980a5350b5a4845db2bd5d6feb9f463a3c1a3aa6 (patch) | |
tree | 409c93483388b8cede754bc69fc62804a271c045 /src/mind | |
parent | 3b0b005b952b1092404fdd5ae1732ec9561794af (diff) |
add hidden neuron state
Diffstat (limited to 'src/mind')
-rw-r--r-- | src/mind/topology.js | 24 | ||||
-rw-r--r-- | src/mind/topology.test.js | 39 |
2 files changed, 59 insertions, 4 deletions
diff --git a/src/mind/topology.js b/src/mind/topology.js index 1cd52d3..19eb399 100644 --- a/src/mind/topology.js +++ b/src/mind/topology.js @@ -72,6 +72,8 @@ function incident_edges(n, adj) { } +// get the indices of the ends of an edge +// in the case of self-loops, both values are the same function edge_ends(n, edge) { const ends = n.adjacency .map((adj, index) => adj[edge] !== 0 ? index : null) @@ -79,7 +81,13 @@ function edge_ends(n, edge) { ends.sort((a, b) => n.adjacency[a][edge] < n.adjacency[b][edge] ? -1 : 1); - return ends; + if (ends.length === 1) { + return { source: ends[0], sink: ends[0] }; + } else if (ends.length === 2) { + return { source: ends[1], sink: ends[0] }; + } else { + throw new Error("something bad happened with the ends"); + } } @@ -91,8 +99,7 @@ function get_value(n, index, input) { const incident = incident_edges(n, adj); const weight = incident.map(x => n.weight[x]); const sources = incident - .map(x => edge_ends(n, x)) - .map(x => x.length === 2 ? x[1] : x[0]); + .map(x => edge_ends(n, x).source); const sum = sources .reduce((acc, x, i) => acc + (weight[i] * get_value(n, x, input)), 0); @@ -102,6 +109,13 @@ function get_value(n, index, input) { function network_compute(n, input, state) { + const hidden = n.adjacency + .map((x, i) => + ( + (!(is_input(n, i))) && + (!is_output(n, i))) ? i : null) + .filter(i => i !== null); + const outputs = n.adjacency .map((x, i) => is_output(n, i) ? i : null) .filter(i => i !== null); @@ -110,7 +124,9 @@ function network_compute(n, input, state) { outputs.map(x => get_value(n, x, input)) ); - const newstate = Object.freeze([]); + const newstate = Object.freeze( + hidden.map(x => get_value(n, x, input)) + ); return Object.freeze([output, newstate]); } diff --git a/src/mind/topology.test.js b/src/mind/topology.test.js index e1c5f87..7612c3d 100644 --- a/src/mind/topology.test.js +++ b/src/mind/topology.test.js @@ -137,3 +137,42 @@ test('multiple input network', () => { [], ]); }); + + +test('multiple outputs', () => { + const n = network(4, 0, 2) + .connect(0, 4, -1) + .connect(1, 4, 1) + .connect(2, 5, -1) + .connect(3, 5, 1); + + expect(n.compute([1,2,3,5], [])).toEqual([ + [ Math.tanh(2-1), Math.tanh(5-3) ], + [], + ]); +}); + + +test('hidden neurons', () => { + const n = network(4, 2, 1) + .connect(0, 4, -1) + .connect(1, 4, 1) + .connect(2, 5, -1) + .connect(3, 5, 1) + .connect(4, 6, -1) + .connect(5, 6, 1); + + expect(n.compute([1,2,3,5], [ 0, 0 ])).toEqual([ + [ Math.tanh( Math.tanh(5-3) - Math.tanh(2-1) ) ], + [ Math.tanh(2-1), Math.tanh(5-3) ], + ]); +}); + + +//test('arbitrary hidden neurons', () => { +// const n = network(1, 2, 1) +// .connect(0, 1, 1) +// .connect(1, 2, -1) +// .connect(2, 3, 2) +// .connect(3, 4, -2); +//}); |