summaryrefslogtreecommitdiff
path: root/src/mind
diff options
context:
space:
mode:
Diffstat (limited to 'src/mind')
-rw-r--r--src/mind/topology.js22
-rw-r--r--src/mind/topology.test.js23
2 files changed, 36 insertions, 9 deletions
diff --git a/src/mind/topology.js b/src/mind/topology.js
index 320b499..56bc498 100644
--- a/src/mind/topology.js
+++ b/src/mind/topology.js
@@ -65,8 +65,10 @@ function network_connect(n, source, sink, weight) {
function incident_edges(n, adj) {
const incident = adj
- .map((edge, index) => edge < 0 ? index : null)
+ .map((edge, index) => (edge < 0) || (edge === 2) ? index : null)
.filter(index => index !== null);
+
+ console.log(incident);
return incident;
}
@@ -91,7 +93,7 @@ function edge_ends(n, edge) {
}
-function get_value(n, index, input, cache) {
+function get_value(n, index, input, prev, cache) {
if (cache !== undefined && cache[index]) {
return cache[index];
}
@@ -103,12 +105,14 @@ function get_value(n, index, input, cache) {
const weight = incident.map(x => n.weight[x]);
const sources = incident
.map(x => edge_ends(n, x).source);
+
+ const values = sources
+ .map(x => x === index ? prev[x - n.input_count] : get_value(n, x, input, prev, cache));
+
+ console.log(n, index, sources, values);
- const sum = sources
- .reduce((acc, x, i) =>
- acc + (weight[i] * get_value(n, x, input, cache)),
- 0
- );
+ const sum = values
+ .reduce((acc, x, i) => acc + (weight[i] * x), 0);
const value = Math.tanh(sum);
@@ -137,11 +141,11 @@ function network_compute(n, input, state) {
.filter(i => i !== null);
const output = Object.freeze(
- outputs.map(x => get_value(n, x, input, value_cache))
+ outputs.map(x => get_value(n, x, input, state, value_cache))
);
const newstate = Object.freeze(
- hidden.map(x => get_value(n, x, input, value_cache))
+ hidden.map(x => get_value(n, x, input, state, value_cache))
);
return Object.freeze([output, newstate]);
diff --git a/src/mind/topology.test.js b/src/mind/topology.test.js
index 5867763..fbe1862 100644
--- a/src/mind/topology.test.js
+++ b/src/mind/topology.test.js
@@ -187,3 +187,26 @@ test('arbitrary hidden neurons', () => {
],
]);
});
+
+
+test('memory', () => {
+ const n = network(0, 1, 1).connect(0, 0, -0.5).connect(0, 1, 2);
+
+ expect(n.compute([], [1])).toEqual([
+ [ Math.tanh( 2 * Math.tanh( -0.5 * 1 ) ) ],
+ [ Math.tanh( -0.5 * 1) ],
+ ]);
+});
+
+
+test('memory and input', () => {
+ const n = network(1, 1, 1)
+ .connect(0, 1, 1)
+ .connect(1, 1, 1)
+ .connect(1, 2, 1);
+
+ expect(n.compute([2], [-1])).toEqual([
+ [ Math.tanh( Math.tanh( 2-1 ) ) ],
+ [ Math.tanh( 2-1 ) ],
+ ]);
+});