From db486d5d9a762532f5c7bd45920b01ab63cd08bd Mon Sep 17 00:00:00 2001 From: sanine Date: Mon, 12 Jun 2023 09:24:56 -0500 Subject: clean up network_compute and add input validation --- src/mind/topology.js | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) (limited to 'src/mind/topology.js') diff --git a/src/mind/topology.js b/src/mind/topology.js index 128b7e9..7e58718 100644 --- a/src/mind/topology.js +++ b/src/mind/topology.js @@ -35,6 +35,10 @@ function is_input(n, index) { function is_output(n, index) { return index >= (n.adjacency.length - n.output_count); } +// check if index is a hidden neuron +function is_hidden(n, index) { + return (!is_input(n, index)) && (!is_output(n, index)); +} // returns a new network with an edge between the given nodes @@ -76,8 +80,6 @@ function incident_edges(n, adj) { .map((edge, index) => (edge < 0) || (edge === 2) ? index : null) .filter(index => index !== null); - console.log(incident); - return incident; } @@ -119,8 +121,6 @@ function get_value(n, index, input, prev, cache) { const values = sources .map(x => x === index ? prev[x - n.input_count] : get_value(n, x, input, prev, cache)); - console.log(n, index, sources, values); - const sum = values .reduce((acc, x, i) => acc + (weight[i] * x), 0); @@ -138,27 +138,29 @@ function get_value(n, index, input, prev, cache) { // compute a network's output and new hidden state // given the input and previous hidden state function network_compute(n, input, state) { + // validate input + if (input.length !== n.input_count) { + throw new Error("incorrect number of input elements"); + } + const hidden_count = n.adjacency.length - n.input_count - n.output_count; + if (state.length !== hidden_count) { + throw new Error("incorrect number of state elements"); + } + // !!! impure caching !!! const value_cache = {}; - const hidden = n.adjacency - .map((x, i) => - ( - (!(is_input(n, i))) && - (!is_output(n, i))) ? i : null) - .filter(i => i !== null); - - const outputs = n.adjacency - .map((x, i) => is_output(n, i) ? i : null) - .filter(i => i !== null); - - const output = Object.freeze( - outputs.map(x => get_value(n, x, input, state, value_cache)) + const result = Object.freeze(n.adjacency + .map((x, i) => is_output(n, i) ? i : null) // output index or null + .filter(i => i !== null) // remove nulls + .map(x => get_value(n, x, input, state, value_cache)) // map to computed value ); - const newstate = Object.freeze( - hidden.map(x => get_value(n, x, input, state, value_cache)) + const newstate = Object.freeze(n.adjacency + .map((x, i) => is_hidden(n, i) ? i : null) // hidden index or null + .filter(i => i !== null) // remove nulls + .map(x => get_value(n, x, input, state, value_cache)) // map to computed value (using cache) ); - return Object.freeze([output, newstate]); + return Object.freeze([result, newstate]); } -- cgit v1.2.1