summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsanine <sanine.not@pm.me>2023-06-12 09:24:56 -0500
committersanine <sanine.not@pm.me>2023-06-12 09:24:56 -0500
commitdb486d5d9a762532f5c7bd45920b01ab63cd08bd (patch)
tree3ec0d1d1715d92ee9fb09a4d54768d2517b65c3d
parent251f39da74c8d5707eaeef8d5e63ce442720b01f (diff)
clean up network_compute and add input validation
-rw-r--r--src/mind/topology.js42
-rw-r--r--src/mind/topology.test.js20
2 files changed, 42 insertions, 20 deletions
diff --git a/src/mind/topology.js b/src/mind/topology.js
index 128b7e9..7e58718 100644
--- a/src/mind/topology.js
+++ b/src/mind/topology.js
@@ -35,6 +35,10 @@ function is_input(n, index) {
function is_output(n, index) {
return index >= (n.adjacency.length - n.output_count);
}
+// check if index is a hidden neuron
+function is_hidden(n, index) {
+ return (!is_input(n, index)) && (!is_output(n, index));
+}
// returns a new network with an edge between the given nodes
@@ -76,8 +80,6 @@ function incident_edges(n, adj) {
.map((edge, index) => (edge < 0) || (edge === 2) ? index : null)
.filter(index => index !== null);
- console.log(incident);
-
return incident;
}
@@ -119,8 +121,6 @@ function get_value(n, index, input, prev, cache) {
const values = sources
.map(x => x === index ? prev[x - n.input_count] : get_value(n, x, input, prev, cache));
- console.log(n, index, sources, values);
-
const sum = values
.reduce((acc, x, i) => acc + (weight[i] * x), 0);
@@ -138,27 +138,29 @@ function get_value(n, index, input, prev, cache) {
// compute a network's output and new hidden state
// given the input and previous hidden state
function network_compute(n, input, state) {
+ // validate input
+ if (input.length !== n.input_count) {
+ throw new Error("incorrect number of input elements");
+ }
+ const hidden_count = n.adjacency.length - n.input_count - n.output_count;
+ if (state.length !== hidden_count) {
+ throw new Error("incorrect number of state elements");
+ }
+
// !!! impure caching !!!
const value_cache = {};
- const hidden = n.adjacency
- .map((x, i) =>
- (
- (!(is_input(n, i))) &&
- (!is_output(n, i))) ? i : null)
- .filter(i => i !== null);
-
- const outputs = n.adjacency
- .map((x, i) => is_output(n, i) ? i : null)
- .filter(i => i !== null);
-
- const output = Object.freeze(
- outputs.map(x => get_value(n, x, input, state, value_cache))
+ const result = Object.freeze(n.adjacency
+ .map((x, i) => is_output(n, i) ? i : null) // output index or null
+ .filter(i => i !== null) // remove nulls
+ .map(x => get_value(n, x, input, state, value_cache)) // map to computed value
);
- const newstate = Object.freeze(
- hidden.map(x => get_value(n, x, input, state, value_cache))
+ const newstate = Object.freeze(n.adjacency
+ .map((x, i) => is_hidden(n, i) ? i : null) // hidden index or null
+ .filter(i => i !== null) // remove nulls
+ .map(x => get_value(n, x, input, state, value_cache)) // map to computed value (using cache)
);
- return Object.freeze([output, newstate]);
+ return Object.freeze([result, newstate]);
}
diff --git a/src/mind/topology.test.js b/src/mind/topology.test.js
index fbe1862..b272040 100644
--- a/src/mind/topology.test.js
+++ b/src/mind/topology.test.js
@@ -210,3 +210,23 @@ test('memory and input', () => {
[ Math.tanh( 2-1 ) ],
]);
});
+
+
+test('input and state must be the correct size', () => {
+ const n = network(2, 1, 1)
+ .connect(0, 2, 1)
+ .connect(1, 2, 1)
+ .connect(2, 3, 1);
+
+ // wrong input size
+ expect(() => n.compute([], [4])).toThrow();
+ expect(() => n.compute([1], [4])).toThrow();
+ expect(() => n.compute([1, 1, 1], [4])).toThrow();
+
+ // wrong state size
+ expect(() => n.compute([1, 1], [])).toThrow();
+ expect(() => n.compute([1, 1], [4, 4])).toThrow();
+
+ // prove correct sizes work
+ n.compute([1, 1], [4]);
+});