module Mind
  ( NeuronIndex (..)
  , getNeuronIndex

  , Edge (..)
  
  , Network (..)
  , createEmptyNetwork
  , connectNeurons
  , compute
  ) where

import Data.Ix
import Data.Maybe

-- | Address of a neuron. Each kind of neuron is numbered on its own,
-- starting from 0 within that kind.
data NeuronIndex
  = Input Int     -- ^ position in the input vector
  | Internal Int  -- ^ position in 'internalNeurons'
  | Output Int    -- ^ position in 'outputNeurons'
  deriving (Show, Eq)

-- | Drop the neuron kind and keep only its numeric position.
getNeuronIndex :: NeuronIndex -> Int
getNeuronIndex ni = case ni of
  Input n    -> n
  Internal n -> n
  Output n   -> n

-- | An incoming connection of a neuron: the neuron the signal comes from,
-- paired with the weight that signal is multiplied by.
newtype Edge = Edge (NeuronIndex, Float)
  deriving (Show, Eq)

-- | A network stored as incoming-edge lists.
--
-- Every internal and output neuron owns the list of edges that feed into it.
-- Input neurons never receive edges, so only their count is kept.
data Network = Network
  { numInput        :: Int       -- ^ number of input neurons
  , internalNeurons :: [[Edge]]  -- ^ incoming edges, one list per internal neuron
  , outputNeurons   :: [[Edge]]  -- ^ incoming edges, one list per output neuron
  }
  deriving (Show, Eq)

-- | Build a network with the given numbers of input, internal and output
-- neurons and no edges between them.
createEmptyNetwork :: Int -> Int -> Int -> Network
createEmptyNetwork nIn nHidden nOut =
  Network
    { numInput        = nIn
    , internalNeurons = replicate nHidden []
    , outputNeurons   = replicate nOut []
    }


-- | Add a weighted edge from @source@ into @sink@.
--
-- Yields 'Nothing' when the sink is an input neuron, when 'validSource'
-- rejects the source, or when 'insertEdge' rejects the sink index.
connectNeurons :: Network -> NeuronIndex -> NeuronIndex -> Float -> Maybe Network
connectNeurons net@Network{} source sink weight =
  case sink of
    Input _    -> Nothing
    Internal k -> attach k internalNeurons (\es -> net { internalNeurons = es })
    Output k   -> attach k outputNeurons (\es -> net { outputNeurons = es })
  where
    -- Validate the source, prepend the new edge to the sink's list selected
    -- by @field@, and rebuild the network around the updated lists.
    attach k field rebuild
      | validSource net source =
          rebuild <$> insertEdge (field net) k (Edge (source, weight))
      | otherwise = Nothing



-- helpers for connectNeurons

-- | Check whether a neuron may be used as the source of an edge.
--
-- Input and internal neurons qualify when their 0-based index addresses an
-- existing neuron, i.e. lies in @[0, count - 1]@. Output neurons are never
-- valid sources.
--
-- 'inRange' is inclusive on both ends, so the upper bound must be
-- @count - 1@. Accepting @count@ itself would create edges whose source is
-- later read with '!!' one position past the end.
validSource :: Network -> NeuronIndex -> Bool
validSource _ (Output _) = False
validSource net (Input x) = inRange (0, numInput net - 1) x
validSource net (Internal x) = inRange (0, length (internalNeurons net) - 1) x

-- | Prepend an element to the inner list at position @i@.
--
-- Returns 'Nothing' when @i@ does not address an existing inner list, either
-- because it is negative or because it is not below the outer list's length.
-- The type is polymorphic; in this module it is used with 'Edge' lists.
insertEdge :: [[a]] -> Int -> a -> Maybe [[a]]
insertEdge ns i e
  -- splitAt treats a negative count as 0, which would wrongly target the
  -- first list, so reject negative indices explicitly.
  | i < 0 = Nothing
  | otherwise = case splitAt i ns of
      (front, es : back) -> Just (front ++ (e : es) : back)
      (_, [])            -> Nothing


-- network computation
-- | Advance the network by one step.
--
-- Takes the input vector and the internal-neuron values of the previous step.
-- It returns the output vector together with the freshly computed
-- internal-neuron values. The outputs are evaluated against a cache that
-- already holds every new internal value. This implementation always yields
-- 'Just'.
compute :: Network -> [Float] -> [Float] -> Maybe ([Float], [Float])
compute net input state = Just (outs, next)
  where
    next = newState net input state
    outs = output net input state (fmap Just next)


-- | Cache of internal-neuron values for the step being computed.
-- An entry is 'Just' once that neuron has been evaluated and 'Nothing'
-- while it is still pending.
type NewState = [Maybe Float]
-- | The input vector paired with the previous step's internal-neuron values.
type InputState = ([Float], [Float])


-- | Evaluate every output neuron, in index order.
--
-- @state'@ is the cache of internal-neuron values for the current step.
-- Cache updates made while evaluating outputs are discarded, because output
-- neurons can never be used as sources (see 'validSource').
output :: Network -> [Float] -> [Float] -> NewState -> [Float]
output net input state state' =
  [ fst (getValue net (input, state) state' (Output k))
  | k <- [0 .. length (outputNeurons net) - 1]
  ]

-- | Compute the new value of every internal neuron, in index order.
--
-- A single cache, initially all 'Nothing', is threaded through the
-- evaluations so that each neuron is computed at most once.
newState :: Network -> [Float] -> [Float] -> [Float]
newState net input state = values
  where
    count = length (internalNeurons net)
    (values, _) = foldr step ([], replicate count Nothing) [0 .. count - 1]
    -- Evaluate one neuron against the cache produced by the neurons after it.
    step idx (acc, cache) =
      let (v, cache') = getValue net (input, state) cache (Internal idx)
       in (v : acc, cache')



-- | Store @value@ as the cached entry at position @index@ (used on a
-- 'NewState' cache).
--
-- Total: an index outside the list leaves the list unchanged. The previous
-- @splitAt@-based version crashed lazily past the end and silently overwrote
-- element 0 for negative indices.
updateValue :: [Maybe a] -> Int -> a -> [Maybe a]
updateValue state' index value = zipWith replace [0 ..] state'
  where
    replace i old
      | i == index = Just value
      | otherwise  = old


-- | Compute the value of one neuron, threading the per-step cache.
--
-- * Input neurons are read directly from the input vector.
-- * Internal neurons are memoised: a cached value is returned as-is.
--   Otherwise the incoming edges are folded and the result is cached.
-- * Output neurons are computed from their incoming edges and not cached.
--
-- Returns the neuron's value together with the updated cache.
getValue :: Network -> InputState -> NewState -> NeuronIndex -> (Float, NewState)
getValue _ (input, _) state' (Input x) = (input !! x, state')
getValue net inputState@(_, state) state' (Internal x) =
  case state' !! x of
    Just cached -> (cached, state')
    Nothing ->
      let
        -- While x is being computed, expose its previous-step value in the
        -- cache. A recurrent path that leads back to x (x -> y -> x) then
        -- reads that old value instead of re-entering this branch forever.
        -- Direct self-loops already read the previous value in 'foldEdges',
        -- so the seeding does not change results for acyclic networks or
        -- self-loops.
        seeded = updateValue state' x (state !! x)
        (value, ns) =
          foldEdges net inputState seeded (Internal x) (internalNeurons net !! x)
      in (value, updateValue ns x value)
getValue net inputState state' (Output x) =
  foldEdges net inputState state' (Output x) (outputNeurons net !! x)



-- | Apply @tanh@ to the weighted sum of a neuron's incoming edges.
--
-- Edges are visited left to right and the cache is threaded through each
-- source evaluation. An edge whose source is the sink itself (a self-loop)
-- uses the previous-step internal value instead of recursing.
foldEdges :: Network -> InputState -> NewState -> NeuronIndex -> [Edge] -> (Float, NewState)
foldEdges net inputState@(_, previous) cache sink edges = (tanh weightedSum, cacheOut)
  where
    (weightedSum, cacheOut) = foldl step (0, cache) edges
    -- Add one weighted contribution, carrying the possibly-extended cache.
    step (acc, memo) (Edge (src, weight)) =
      let (activation, memo')
            | src == sink = (previous !! getNeuronIndex src, memo)
            | otherwise   = getValue net inputState memo src
       in (weight * activation + acc, memo')