'use strict';

import { create } from '../util.js';


// Presumably the intended default for network()'s `weight_max` parameter
// (an upper bound on edge weight magnitude?).
// NOTE(review): nothing in this module visibly enforces a weight bound — confirm intent.
const DEFAULT_WEIGHT_MAX = 4;


// shared prototype for network objects: each method forwards to the free
// function of the same purpose, passing the receiver in as the network
const network_proto = {
	connect(source, sink, weight) {
		const self = this;
		return network_connect(self, source, sink, weight);
	},
	compute(inputs, state) {
		const self = this;
		return network_compute(self, inputs, state);
	},
};


// create a new network with no edges.
// nodes are laid out by index as [inputs..., internal (hidden)..., outputs...];
// every node owns one adjacency row, and network_connect appends one column
// per edge to every row.
//   input_count    - number of input nodes
//   internal_count - number of hidden nodes
//   output_count   - number of output nodes
//   weight_max     - accepted for API compatibility; NOTE(review): not used
//                    by this function — confirm whether weights should be bounded
// throws RangeError if any count is not a non-negative integer.
export function network(input_count, internal_count, output_count, weight_max = DEFAULT_WEIGHT_MAX) {
	for (const [name, value] of Object.entries({ input_count, internal_count, output_count })) {
		if (!Number.isInteger(value) || value < 0) {
			throw new RangeError(`${name} must be a non-negative integer`);
		}
	}

	const count = input_count + internal_count + output_count;
	// one independent empty row per node; `fill([])` would make every row
	// alias the same array, so an in-place edit of one row would hit all rows
	const adjacency = Array.from({ length: count }, () => []);

	return create({
		input_count,
		output_count,
		adjacency,
		weight: [],
	}, network_proto);
}


// true when `index` names an input node; inputs occupy the lowest indices,
// [0, input_count)
function is_input({ input_count }, index) {
	return input_count > index;
}
// true when `index` names an output node; outputs occupy the last
// output_count indices of the adjacency table
function is_output(n, index) {
	const first_output = n.adjacency.length - n.output_count;
	return first_output <= index;
}
// true for nodes that are neither inputs nor outputs, i.e. the hidden
// (internal) nodes sitting between the two groups
function is_hidden(n, index) {
	const on_boundary = is_input(n, index) || is_output(n, index);
	return !on_boundary;
}


// returns a new network with an additional edge from `source` to `sink`
// carrying `weight`; `n` itself is not modified.
// an edge is stored as one new column appended to every adjacency row:
// 1 in the source's row, -1 in the sink's row, 2 in the row of a self-loop
// (source === sink), and 0 in every other row. `weight` is appended to the
// new network's weight list.
// throws RangeError if either index is not an integer node index, and
// Error if the sink is an input node or the source is an output node.
export function network_connect(n, source, sink, weight) {
	const node_count = n.adjacency.length;
	const is_node = (index) => Number.isInteger(index) && index >= 0 && index < node_count;

	// an out-of-range index would yield a column with only one non-zero
	// entry, which is not a well-formed edge
	if (!is_node(source)) {
		throw new RangeError(`source index ${source} is out of range`);
	}
	if (!is_node(sink)) {
		throw new RangeError(`sink index ${sink} is out of range`);
	}
	if (is_input(n, sink)) {
		// inputs cannot be sinks
		throw new Error("attempt to use input as sink");
	}
	if (is_output(n, source)) {
		// outputs cannot be sources
		throw new Error("attempt to use output as source");
	}

	// entry written into row `i` of the new edge column
	const mark = (i) => {
		if (i === source && i === sink) {
			return 2; // self-loop
		}
		if (i === source) {
			return 1;
		}
		if (i === sink) {
			return -1;
		}
		return 0;
	};

	return create({
		...n,
		adjacency: n.adjacency.map((row, i) => [...row, mark(i)]),
		weight: [...n.weight, weight],
	}, network_proto);
}


// collect the indices of the edges that END at the node owning adjacency
// row `adj`: entries of -1 (ordinary incoming edge) or 2 (self-loop).
// `n` is unused but kept for call-site compatibility.
function incident_edges(n, adj) {
	const incoming = [];
	for (const [edge, mark] of adj.entries()) {
		if (mark === 2 || mark < 0) {
			incoming.push(edge);
		}
	}
	return incoming;
}


// locate the endpoints of edge column `edge`.
// a well-formed column holds exactly one 1 (the source row) and one -1 (the
// sink row), or a single 2 for a self-loop, in which case source and sink
// are the same node; every other entry is 0.
// returns { source, sink } as node indices.
// throws if the column does not have that shape, including when `edge` is
// not a valid column index (entries read back as undefined).
function edge_ends(n, edge) {
	let source = null;
	let sink = null;

	for (const [index, row] of n.adjacency.entries()) {
		const mark = row[edge];
		if (mark === 0) {
			continue; // this node is not an endpoint of the edge
		}
		if (mark === 2 && source === null && sink === null) {
			source = index;
			sink = index;
		} else if (mark === 1 && source === null) {
			source = index;
		} else if (mark === -1 && sink === null) {
			sink = index;
		} else {
			// duplicate endpoint, self-loop mixed with another end, or an
			// unexpected value
			throw new Error("something bad happened with the ends");
		}
	}

	if (source === null || sink === null) {
		throw new Error("something bad happened with the ends");
	}
	return { source, sink };
}


// recursively compute the value of node `index`.
//   input - values of the input nodes, returned directly for input indices
//   prev  - previous hidden-node values, indexed from the first hidden node
//           (node index minus n.input_count); read only for self-loops
//   cache - optional object mapping node index -> computed value; when
//           given it is consulted first and filled in as values are computed
// a non-input node's value is tanh of the weighted sum of the values at the
// source end of each incoming edge.
// NOTE(review): only a self-loop reads `prev`; a cycle through two or more
// distinct nodes recurses until the stack overflows — confirm such edges are
// never created.
function get_value(n, index, input, prev, cache) {
	// membership test, not truthiness: a cached 0 (e.g. tanh(0) for a node
	// with no incoming edges) must count as a hit
	if (cache !== undefined && index in cache) {
		return cache[index];
	}
	if (is_input(n, index)) {
		return input[index];
	}

	// weighted sum over the edges that end at this node
	const sum = incident_edges(n, n.adjacency[index]).reduce((acc, edge) => {
		const { source } = edge_ends(n, edge);
		const source_value = source === index
			? prev[source - n.input_count]               // self-loop: previous value
			: get_value(n, source, input, prev, cache);  // otherwise recurse
		return acc + (n.weight[edge] * source_value);
	}, 0);

	const value = Math.tanh(sum);

	// !!! impure caching !!! — writes into the caller-supplied cache object
	if (cache !== undefined) {
		cache[index] = value;
	}
	return value;
}


// evaluate the network once.
//   input - one value per input node; its length must equal n.input_count
//   state - previous hidden-node values in node order; its length must equal
//           the number of hidden nodes
// returns a frozen pair [outputs, next_state]: frozen arrays holding the
// output-node values and the freshly computed hidden-node values, each in
// ascending node-index order.
// throws Error if `input` or `state` has the wrong length.
export function network_compute(n, input, state) {
	if (input.length !== n.input_count) {
		throw new Error("incorrect number of input elements");
	}
	const node_count = n.adjacency.length;
	const first_output = node_count - n.output_count;
	if (state.length !== node_count - n.input_count - n.output_count) {
		throw new Error("incorrect number of state elements");
	}

	// !!! impure caching !!! — one memo shared by both passes so the hidden
	// pass reuses values already computed while evaluating the outputs
	const memo = {};
	const evaluate = (index) => get_value(n, index, input, state, memo);

	// split node indices into outputs and hidden nodes (inputs are skipped)
	const output_nodes = [];
	const hidden_nodes = [];
	n.adjacency.forEach((_row, index) => {
		if (index >= first_output) {
			output_nodes.push(index);
		} else if (index >= n.input_count) {
			hidden_nodes.push(index);
		}
	});

	const outputs = Object.freeze(output_nodes.map((index) => evaluate(index)));
	const next_state = Object.freeze(hidden_nodes.map((index) => evaluate(index)));

	return Object.freeze([outputs, next_state]);
}