diff options
author | sanine <sanine.not@pm.me> | 2023-06-12 09:50:44 -0500 |
---|---|---|
committer | sanine <sanine.not@pm.me> | 2023-06-12 09:50:44 -0500 |
commit | c079d06fdeeaa5e7a7b007e5bbc35a2d19bbc316 (patch) | |
tree | d1d6dd7ed1945af89ec5f27b6f9ee614d87147f4 /src | |
parent | db486d5d9a762532f5c7bd45920b01ab63cd08bd (diff) |
add some better comments
Diffstat (limited to 'src')
-rw-r--r-- | src/mind/topology.js | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/src/mind/topology.js b/src/mind/topology.js index 7e58718..5d4d52c 100644 --- a/src/mind/topology.js +++ b/src/mind/topology.js @@ -106,27 +106,34 @@ function edge_ends(n, edge) { // recursively get the value of a node from the input nodes, // optionally caching the computed values function get_value(n, index, input, prev, cache) { + // check if value is cached if (cache !== undefined && cache[index]) { return cache[index]; } + // check if value is input if (is_input(n, index)) { return input[index]; } - const adj = n.adjacency[index]; - const incident = incident_edges(n, adj); - const weight = incident.map(x => n.weight[x]); - const sources = incident + + const adj = n.adjacency[index]; // get adjacency list + const incident = incident_edges(n, adj); // get incident edges + const weight = incident.map(x => n.weight[x]); // edge weights + const sources = incident // get ancestor nodes .map(x => edge_ends(n, x).source); - const values = sources - .map(x => x === index ? prev[x - n.input_count] : get_value(n, x, input, prev, cache)); + const values = sources // get the value of each ancestor + .map(x => x === index // if the ancestor is this node + ? prev[x - n.input_count] // then the value is the previous value + : get_value(n, x, input, prev, cache)); // else recurse - const sum = values + const sum = values // compute the weighted sum of the values .reduce((acc, x, i) => acc + (weight[i] * x), 0); + // compute result const value = Math.tanh(sum); // !!! impure caching !!! + // cache result if (cache !== undefined) { cache[index] = value; } @@ -142,6 +149,7 @@ function network_compute(n, input, state) { if (input.length !== n.input_count) { throw new Error("incorrect number of input elements"); } + // validate state const hidden_count = n.adjacency.length - n.input_count - n.output_count; if (state.length !== hidden_count) { throw new Error("incorrect number of state elements"); |