'use strict';

import { create } from '../util';

// Default bound on edge-weight magnitude.
// NOTE(review): neither this constant nor the `weight_max` argument of
// `network` is enforced anywhere in this module.
const DEFAULT_WEIGHT_MAX = 4;

// Adjacency encoding used throughout this module: every node owns one row of
// `n.adjacency`, and every edge appends one column to every row, holding
//    1  when the node is the edge's source,
//   -1  when the node is the edge's sink,
//    2  when the edge is a self-loop on the node,
//    0  otherwise.
// `n.weight[e]` is the weight of edge (column) `e`.
// Node indices are laid out as [inputs..., hidden..., outputs...].

// prototype shared by all network objects
const network_proto = {
  connect(source, sink, weight) {
    return network_connect(this, source, sink, weight);
  },
  compute(inputs, state) {
    return network_compute(this, inputs, state);
  },
};

// throws unless `value` is a non-negative integer (used for node counts)
function check_count(name, value) {
  if (!Number.isInteger(value) || value < 0) {
    throw new RangeError(`${name} must be a non-negative integer`);
  }
}

/**
 * Create a new network with no edges.
 *
 * @param {number} input_count - number of input nodes (the lowest indices)
 * @param {number} internal_count - number of hidden nodes
 * @param {number} output_count - number of output nodes (the highest indices)
 * @param {number} [weight_max=DEFAULT_WEIGHT_MAX] - accepted for
 *   compatibility; currently unused by this module
 * @returns {object} a network whose prototype provides `connect` and `compute`
 * @throws {RangeError} if any count is not a non-negative integer
 */
export function network(input_count, internal_count, output_count, weight_max = DEFAULT_WEIGHT_MAX) {
  check_count('input_count', input_count);
  check_count('internal_count', internal_count);
  check_count('output_count', output_count);
  const count = input_count + internal_count + output_count;
  return create({
    input_count,
    output_count,
    // one independent empty row per node; `fill([])` would share one array
    adjacency: Array.from({ length: count }, () => []),
    weight: [],
  }, network_proto);
}

// check if index is an input node
function is_input(n, index) {
  return index < n.input_count;
}

// check if index is an output node
function is_output(n, index) {
  return index >= (n.adjacency.length - n.output_count);
}

// check if index is a hidden (neither input nor output) node
function is_hidden(n, index) {
  return !is_input(n, index) && !is_output(n, index);
}

// throws unless `index` names an existing node of `n`
function check_node(n, index, role) {
  if (!Number.isInteger(index) || index < 0 || index >= n.adjacency.length) {
    throw new RangeError(`${role} index out of range: ${index}`);
  }
}

// Returns a new network equal to `n` plus one edge from `source` to `sink`
// with the given weight; `n` itself is not modified.
// Throws RangeError for a non-existent node index, TypeError for a weight that
// is not a finite number, and Error when an input would become a sink or an
// output would become a source.
function network_connect(n, source, sink, weight) {
  // An out-of-range source would yield a column whose only non-zero entry is
  // the sink's -1, which edge_ends would then misreport as a self-loop.
  check_node(n, source, 'source');
  check_node(n, sink, 'sink');
  // a missing or non-numeric weight would silently turn outputs into NaN
  if (typeof weight !== 'number' || !Number.isFinite(weight)) {
    throw new TypeError('weight must be a finite number');
  }
  if (is_input(n, sink)) {
    // inputs cannot be sinks
    throw new Error("attempt to use input as sink");
  }
  if (is_output(n, source)) {
    // outputs cannot be sources
    throw new Error("attempt to use output as source");
  }
  return create({
    ...n,
    adjacency: n.adjacency.map((row, i) => {
      if (i === source && i === sink) {
        return [...row, 2]; // self-loop
      }
      if (i === source) {
        return [...row, 1];
      }
      if (i === sink) {
        return [...row, -1];
      }
      return [...row, 0];
    }),
    weight: [...n.weight, weight],
  }, network_proto);
}

// Returns the indices of the edges entering the node whose adjacency row is
// `adj`: edges where it is the sink (-1) or that are self-loops on it (2).
// `n` is unused; kept for a uniform helper signature.
function incident_edges(n, adj) {
  const incident = [];
  adj.forEach((code, edge) => {
    if (code < 0 || code === 2) {
      incident.push(edge);
    }
  });
  return incident;
}

// Returns { source, sink } for edge (column) `edge`; for a self-loop both are
// the same node. Throws if the column does not have one or two non-zero cells.
function edge_ends(n, edge) {
  const ends = [];
  n.adjacency.forEach((row, index) => {
    if (row[edge] !== 0) {
      ends.push(index);
    }
  });
  if (ends.length === 1) {
    return { source: ends[0], sink: ends[0] };
  }
  if (ends.length === 2) {
    // the source cell holds 1 and the sink cell holds -1
    const [a, b] = ends;
    return n.adjacency[a][edge] > 0
      ? { source: a, sink: b }
      : { source: b, sink: a };
  }
  throw new Error("something bad happened with the ends");
}

// Recursively computes the value of node `index`: tanh of the weighted sum of
// the values of the sources of its incoming edges. Input nodes read `input`;
// a self-loop reads the node's previous value from `prev`, which is indexed
// from the first hidden node. When `cache` (a Map) is given, computed values
// are memoized in it -- a deliberate, local side effect.
// `visiting` tracks the nodes on the current recursion path so a cycle between
// distinct nodes raises an error instead of overflowing the stack.
function get_value(n, index, input, prev, cache, visiting = new Set()) {
  // `has` rather than truthiness, so cached zeros are reused too
  if (cache !== undefined && cache.has(index)) {
    return cache.get(index);
  }
  if (is_input(n, index)) {
    return input[index];
  }
  if (visiting.has(index)) {
    throw new Error("cycle between distinct nodes is not supported");
  }
  visiting.add(index);

  let sum = 0;
  for (const edge of incident_edges(n, n.adjacency[index])) {
    const { source } = edge_ends(n, edge);
    const source_value = source === index
      ? prev[index - n.input_count] // self-loop: use the previous value
      : get_value(n, source, input, prev, cache, visiting);
    sum += n.weight[edge] * source_value;
  }

  visiting.delete(index);
  const value = Math.tanh(sum);
  if (cache !== undefined) {
    cache.set(index, value);
  }
  return value;
}

// Computes a network's outputs and new hidden state from `input` (one value
// per input node) and `state` (the previous value of each hidden node).
// Returns a frozen [outputs, newstate] pair of frozen arrays.
// Throws if `input` or `state` has the wrong length, or (via get_value) if the
// network contains a cycle between distinct nodes.
function network_compute(n, input, state) {
  if (input.length !== n.input_count) {
    throw new Error("incorrect number of input elements");
  }
  const hidden_count = n.adjacency.length - n.input_count - n.output_count;
  if (state.length !== hidden_count) {
    throw new Error("incorrect number of state elements");
  }

  // Shared by both passes, so hidden values computed while evaluating the
  // outputs are reused for the new state.
  const cache = new Map();
  const value_of = (index) => get_value(n, index, input, state, cache);
  const indices_where = (pred) => n.adjacency
    .map((_, i) => i)
    .filter((i) => pred(n, i));

  const result = Object.freeze(indices_where(is_output).map(value_of));
  const newstate = Object.freeze(indices_where(is_hidden).map(value_of));
  return Object.freeze([result, newstate]);
}