summaryrefslogtreecommitdiff
path: root/src/mind
diff options
context:
space:
mode:
Diffstat (limited to 'src/mind')
-rw-r--r--src/mind/topology.js26
-rw-r--r--src/mind/topology.test.js25
2 files changed, 39 insertions, 12 deletions
diff --git a/src/mind/topology.js b/src/mind/topology.js
index 19eb399..320b499 100644
--- a/src/mind/topology.js
+++ b/src/mind/topology.js
@@ -91,7 +91,10 @@ function edge_ends(n, edge) {
}
-function get_value(n, index, input) {
+function get_value(n, index, input, cache) {
+ if (cache !== undefined && cache[index]) {
+ return cache[index];
+ }
if (is_input(n, index)) {
return input[index];
}
@@ -102,13 +105,26 @@ function get_value(n, index, input) {
.map(x => edge_ends(n, x).source);
const sum = sources
- .reduce((acc, x, i) => acc + (weight[i] * get_value(n, x, input)), 0);
+ .reduce((acc, x, i) =>
+ acc + (weight[i] * get_value(n, x, input, cache)),
+ 0
+ );
+
+ const value = Math.tanh(sum);
+
+ // !!! impure caching !!!
+ if (cache !== undefined) {
+ cache[index] = value;
+ }
- return Math.tanh(sum);
+ return value;
}
function network_compute(n, input, state) {
+ // !!! impure caching !!!
+ const value_cache = {};
+
const hidden = n.adjacency
.map((x, i) =>
(
@@ -121,11 +137,11 @@ function network_compute(n, input, state) {
.filter(i => i !== null);
const output = Object.freeze(
- outputs.map(x => get_value(n, x, input))
+ outputs.map(x => get_value(n, x, input, value_cache))
);
const newstate = Object.freeze(
- hidden.map(x => get_value(n, x, input))
+ hidden.map(x => get_value(n, x, input, value_cache))
);
return Object.freeze([output, newstate]);
diff --git a/src/mind/topology.test.js b/src/mind/topology.test.js
index 7612c3d..5867763 100644
--- a/src/mind/topology.test.js
+++ b/src/mind/topology.test.js
@@ -169,10 +169,21 @@ test('hidden neurons', () => {
});
-//test('arbitrary hidden neurons', () => {
-// const n = network(1, 2, 1)
-// .connect(0, 1, 1)
-// .connect(1, 2, -1)
-// .connect(2, 3, 2)
-// .connect(3, 4, -2);
-//});
+test('arbitrary hidden neurons', () => {
+ const n = network(1, 2, 1)
+ .connect(0, 1, 1)
+ .connect(1, 2, -1)
+ .connect(2, 3, 2)
+
+ expect(n.compute([1], [0, 0])).toEqual([
+ [ Math.tanh (
+ 2*Math.tanh(
+ -1*Math.tanh( 1 )
+ )
+ ) ],
+ [
+ Math.tanh( -Math.tanh(1) ),
+ Math.tanh(1),
+ ],
+ ]);
+});