diff options
author | sanine <sanine.not@pm.me> | 2023-06-11 23:21:46 -0500 |
---|---|---|
committer | sanine <sanine.not@pm.me> | 2023-06-11 23:21:46 -0500 |
commit | 7e92bd5b292b99c5f5a3f1b05d2870be32732d92 (patch) | |
tree | 0e5906fa2b7d81f90ea50a6de178c4fdfef5d6e0 | |
parent | b3b2ebddba2dad9f9213ac80cb95033ad48eb7e2 (diff) |
add memory
-rw-r--r-- | src/mind/topology.js | 22 | ||||
-rw-r--r-- | src/mind/topology.test.js | 23 |
2 files changed, 36 insertions, 9 deletions
diff --git a/src/mind/topology.js b/src/mind/topology.js index 320b499..56bc498 100644 --- a/src/mind/topology.js +++ b/src/mind/topology.js @@ -65,8 +65,10 @@ function network_connect(n, source, sink, weight) { function incident_edges(n, adj) { const incident = adj - .map((edge, index) => edge < 0 ? index : null) + .map((edge, index) => (edge < 0) || (edge === 2) ? index : null) .filter(index => index !== null); + + console.log(incident); return incident; } @@ -91,7 +93,7 @@ function edge_ends(n, edge) { } -function get_value(n, index, input, cache) { +function get_value(n, index, input, prev, cache) { if (cache !== undefined && cache[index]) { return cache[index]; } @@ -103,12 +105,14 @@ function get_value(n, index, input, cache) { const weight = incident.map(x => n.weight[x]); const sources = incident .map(x => edge_ends(n, x).source); + + const values = sources + .map(x => x === index ? prev[x - n.input_count] : get_value(n, x, input, prev, cache)); + + console.log(n, index, sources, values); - const sum = sources - .reduce((acc, x, i) => - acc + (weight[i] * get_value(n, x, input, cache)), - 0 - ); + const sum = values + .reduce((acc, x, i) => acc + (weight[i] * x), 0); const value = Math.tanh(sum); @@ -137,11 +141,11 @@ function network_compute(n, input, state) { .filter(i => i !== null); const output = Object.freeze( - outputs.map(x => get_value(n, x, input, value_cache)) + outputs.map(x => get_value(n, x, input, state, value_cache)) ); const newstate = Object.freeze( - hidden.map(x => get_value(n, x, input, value_cache)) + hidden.map(x => get_value(n, x, input, state, value_cache)) ); return Object.freeze([output, newstate]); diff --git a/src/mind/topology.test.js b/src/mind/topology.test.js index 5867763..fbe1862 100644 --- a/src/mind/topology.test.js +++ b/src/mind/topology.test.js @@ -187,3 +187,26 @@ test('arbitrary hidden neurons', () => { ], ]); }); + + +test('memory', () => { + const n = network(0, 1, 1).connect(0, 0, -0.5).connect(0, 1, 2); + + expect(n.compute([], [1])).toEqual([ + [ Math.tanh( 2 * Math.tanh( -0.5 * 1 ) ) ], + [ Math.tanh( -0.5 * 1) ], + ]); +}); + + +test('memory and input', () => { + const n = network(1, 1, 1) + .connect(0, 1, 1) + .connect(1, 1, 1) + .connect(1, 2, 1); + + expect(n.compute([2], [-1])).toEqual([ + [ Math.tanh( Math.tanh( 2-1 ) ) ], + [ Math.tanh( 2-1 ) ], + ]); +}); |