summaryrefslogtreecommitdiff
path: root/src/mind/topology.js
diff options
context:
space:
mode:
authorsanine <sanine.not@pm.me>2023-06-12 09:50:44 -0500
committersanine <sanine.not@pm.me>2023-06-12 09:50:44 -0500
commitc079d06fdeeaa5e7a7b007e5bbc35a2d19bbc316 (patch)
treed1d6dd7ed1945af89ec5f27b6f9ee614d87147f4 /src/mind/topology.js
parentdb486d5d9a762532f5c7bd45920b01ab63cd08bd (diff)
add some better comments
Diffstat (limited to 'src/mind/topology.js')
-rw-r--r--src/mind/topology.js22
1 files changed, 15 insertions, 7 deletions
diff --git a/src/mind/topology.js b/src/mind/topology.js
index 7e58718..5d4d52c 100644
--- a/src/mind/topology.js
+++ b/src/mind/topology.js
@@ -106,27 +106,34 @@ function edge_ends(n, edge) {
// recursively get the value of a node from the input nodes,
// optionally caching the computed values
function get_value(n, index, input, prev, cache) {
+ // check if value is cached
if (cache !== undefined && cache[index]) {
return cache[index];
}
+ // check if value is input
if (is_input(n, index)) {
return input[index];
}
- const adj = n.adjacency[index];
- const incident = incident_edges(n, adj);
- const weight = incident.map(x => n.weight[x]);
- const sources = incident
+
+ const adj = n.adjacency[index]; // get adjacency list
+ const incident = incident_edges(n, adj); // get incident edges
+ const weight = incident.map(x => n.weight[x]); // edge weights
+ const sources = incident // get ancestor nodes
.map(x => edge_ends(n, x).source);
- const values = sources
- .map(x => x === index ? prev[x - n.input_count] : get_value(n, x, input, prev, cache));
+ const values = sources // get the value of each ancestor
+ .map(x => x === index // if the ancestor is this node
+ ? prev[x - n.input_count] // then the value is the previous value
+ : get_value(n, x, input, prev, cache)); // else recurse
- const sum = values
+ const sum = values // compute the weighted sum of the values
.reduce((acc, x, i) => acc + (weight[i] * x), 0);
+ // compute result
const value = Math.tanh(sum);
// !!! impure caching !!!
+ // cache result
if (cache !== undefined) {
cache[index] = value;
}
@@ -142,6 +149,7 @@ function network_compute(n, input, state) {
if (input.length !== n.input_count) {
throw new Error("incorrect number of input elements");
}
+ // validate state
const hidden_count = n.adjacency.length - n.input_count - n.output_count;
if (state.length !== hidden_count) {
throw new Error("incorrect number of state elements");