diff options
author | sanine <sanine.not@pm.me> | 2023-06-12 09:24:56 -0500 |
---|---|---|
committer | sanine <sanine.not@pm.me> | 2023-06-12 09:24:56 -0500 |
commit | db486d5d9a762532f5c7bd45920b01ab63cd08bd (patch) | |
tree | 3ec0d1d1715d92ee9fb09a4d54768d2517b65c3d | |
parent | 251f39da74c8d5707eaeef8d5e63ce442720b01f (diff) |
clean up network_compute and add input validation
-rw-r--r-- | src/mind/topology.js | 42 | ||||
-rw-r--r-- | src/mind/topology.test.js | 20 |
2 files changed, 42 insertions, 20 deletions
diff --git a/src/mind/topology.js b/src/mind/topology.js index 128b7e9..7e58718 100644 --- a/src/mind/topology.js +++ b/src/mind/topology.js @@ -35,6 +35,10 @@ function is_input(n, index) { function is_output(n, index) { return index >= (n.adjacency.length - n.output_count); } +// check if index is a hidden neuron +function is_hidden(n, index) { + return (!is_input(n, index)) && (!is_output(n, index)); +} // returns a new network with an edge between the given nodes @@ -76,8 +80,6 @@ function incident_edges(n, adj) { .map((edge, index) => (edge < 0) || (edge === 2) ? index : null) .filter(index => index !== null); - console.log(incident); - return incident; } @@ -119,8 +121,6 @@ function get_value(n, index, input, prev, cache) { const values = sources .map(x => x === index ? prev[x - n.input_count] : get_value(n, x, input, prev, cache)); - console.log(n, index, sources, values); - const sum = values .reduce((acc, x, i) => acc + (weight[i] * x), 0); @@ -138,27 +138,29 @@ function get_value(n, index, input, prev, cache) { // compute a network's output and new hidden state // given the input and previous hidden state function network_compute(n, input, state) { + // validate input + if (input.length !== n.input_count) { + throw new Error("incorrect number of input elements"); + } + const hidden_count = n.adjacency.length - n.input_count - n.output_count; + if (state.length !== hidden_count) { + throw new Error("incorrect number of state elements"); + } + // !!! impure caching !!! const value_cache = {}; - const hidden = n.adjacency - .map((x, i) => - ( - (!(is_input(n, i))) && - (!is_output(n, i))) ? i : null) - .filter(i => i !== null); - - const outputs = n.adjacency - .map((x, i) => is_output(n, i) ? i : null) - .filter(i => i !== null); - - const output = Object.freeze( - outputs.map(x => get_value(n, x, input, state, value_cache)) + const result = Object.freeze(n.adjacency + .map((x, i) => is_output(n, i) ? i : null) // output index or null + .filter(i => i !== null) // remove nulls + .map(x => get_value(n, x, input, state, value_cache)) // map to computed value ); - const newstate = Object.freeze( - hidden.map(x => get_value(n, x, input, state, value_cache)) + const newstate = Object.freeze(n.adjacency + .map((x, i) => is_hidden(n, i) ? i : null) // hidden index or null + .filter(i => i !== null) // remove nulls + .map(x => get_value(n, x, input, state, value_cache)) // map to computed value (using cache) ); - return Object.freeze([output, newstate]); + return Object.freeze([result, newstate]); } diff --git a/src/mind/topology.test.js b/src/mind/topology.test.js index fbe1862..b272040 100644 --- a/src/mind/topology.test.js +++ b/src/mind/topology.test.js @@ -210,3 +210,23 @@ test('memory and input', () => { [ Math.tanh( 2-1 ) ], ]); }); + + +test('input and state must be the correct size', () => { + const n = network(2, 1, 1) + .connect(0, 2, 1) + .connect(1, 2, 1) + .connect(2, 3, 1); + + // wrong input size + expect(() => n.compute([], [4])).toThrow(); + expect(() => n.compute([1], [4])).toThrow(); + expect(() => n.compute([1, 1, 1], [4])).toThrow(); + + // wrong state size + expect(() => n.compute([1, 1], [])).toThrow(); + expect(() => n.compute([1, 1], [4, 4])).toThrow(); + + // prove correct sizes work + n.compute([1, 1], [4]); +}); |