-- | A small neural network whose neurons are split into three layers:
-- inputs, internal neurons and outputs. Edges are stored on their sink
-- neuron (as its list of incoming edges). Internal neurons carry a state
-- value from one call of 'compute' to the next.
module Mind
  ( NeuronIndex (..)
  , getNeuronIndex
  , Edge (..)
  , Network (..)
  , createEmptyNetwork
  , connectNeurons
  , compute
  ) where

import Data.Ix
import Data.List (foldl')
import Data.Maybe

-- | Index of a neuron, tagged with the layer it belongs to.
data NeuronIndex
  = Input Int
  | Internal Int
  | Output Int
  deriving (Show, Eq, Ord)

-- | The position of a neuron within its own layer (the layer tag is dropped).
getNeuronIndex :: NeuronIndex -> Int
getNeuronIndex (Input i) = i
getNeuronIndex (Internal i) = i
getNeuronIndex (Output i) = i

-- | An incoming (incident) edge: the source neuron and the edge weight.
newtype Edge = Edge (NeuronIndex, Float)
  deriving (Show, Eq)

-- | A network. Every internal and output neuron is represented by the list of
-- its incoming edges; input neurons only need to be counted.
data Network = Network
  { numInput :: Int
  , internalNeurons :: [[Edge]]
  , outputNeurons :: [[Edge]]
  }
  deriving (Show, Eq)

-- | @createEmptyNetwork i h o@ creates a network with @i@ input neurons,
-- @h@ internal neurons and @o@ output neurons, and no edges.
createEmptyNetwork :: Int -> Int -> Int -> Network
createEmptyNetwork i h o = Network i (replicate h []) (replicate o [])

-- | @connectNeurons net source sink weight@ adds an edge of the given weight
-- from @source@ to @sink@.
--
-- Returns 'Nothing' when:
--
-- * @source@ is an output neuron or out of range,
-- * @sink@ is out of range, or
-- * @sink@ is an input neuron (inputs never have incoming edges).
connectNeurons :: Network -> NeuronIndex -> NeuronIndex -> Float -> Maybe Network
connectNeurons net source sink weight
  | not (validSource net source) = Nothing
  | otherwise =
      case sink of
        -- Previously this case was missing and the call crashed.
        Input _ -> Nothing
        Internal x -> do
          newH <- insertEdge (internalNeurons net) x edge
          return net {internalNeurons = newH}
        Output x -> do
          newO <- insertEdge (outputNeurons net) x edge
          return net {outputNeurons = newO}
  where
    edge = Edge (source, weight)

-- | Whether a neuron may be used as the source of an edge: any in-range input
-- or internal neuron. Output neurons are never sources.
validSource :: Network -> NeuronIndex -> Bool
validSource _ (Output _) = False
-- Valid indices are 0 .. count-1 ('inRange' is inclusive on both ends).
validSource net (Input x) = inRange (0, numInput net - 1) x
validSource net (Internal x) = inRange (0, length (internalNeurons net) - 1) x

-- | Prepend an edge to the edge list of neuron @i@, or return 'Nothing' if
-- @i@ is not a valid index into the list.
insertEdge :: [[Edge]] -> Int -> Edge -> Maybe [[Edge]]
insertEdge ns i e =
  case splitAt i ns of
    -- 'splitAt' treats a negative index as 0, hence the explicit check.
    (front, es : back) | i >= 0 -> Just (front ++ (e : es) : back)
    _ -> Nothing

-- | @compute net input state@ runs the network for one step.
--
-- @input@ must hold one value per input neuron, and @state@ one value per
-- internal neuron (read by self-loop edges as the previous value). Returns
-- the output values together with the new internal state, or a 'Left'
-- message describing a length mismatch.
--
-- NOTE(review): a cycle through two or more internal neurons (other than a
-- self-loop) makes evaluation recurse without terminating.
compute :: Network -> [Float] -> [Float] -> Either String ([Float], [Float])
compute net input state
  | length input /= numInput net =
      Left $ "Bad input length: " ++ show (length input)
  | length state /= length (internalNeurons net) =
      Left $ "Bad state length: " ++ show (length state)
  | otherwise =
      let s = newState net input state
          state' = map Just s
       in Right (output net input state state', s)

-- | The external input values and the internal state from the previous step.
type InputState = ([Float], [Float])

-- | Internal neuron values for the current step; 'Nothing' means the value
-- has not been computed yet (acts as a memo table).
type NewState = [Maybe Float]

-- | Values of every output neuron, given the (possibly partial) current
-- internal values.
output :: Network -> [Float] -> [Float] -> NewState -> [Float]
output net input state state' =
  map
    (fst . getValue net (input, state) state' . Output)
    [0 .. length (outputNeurons net) - 1]

-- | New values of all internal neurons, in index order. The memo table is
-- threaded through the fold so that each neuron is evaluated at most once.
newState :: Network -> [Float] -> [Float] -> [Float]
newState net input state =
  fst $ foldr step ([], replicate numInternal Nothing) [0 .. numInternal - 1]
  where
    numInternal = length (internalNeurons net)
    step x (r, ns) =
      let (value, ns') = getValue net (input, state) ns (Internal x)
       in (value : r, ns')

-- | Record @value@ as the computed value of internal neuron @i@.
updateValue :: NewState -> Int -> Float -> NewState
updateValue state' i value =
  [if j == i then Just value else old | (j, old) <- zip [0 ..] state']

-- | Value of a neuron for the current step, plus the updated memo table.
-- Input values are read directly, internal values are memoised, and output
-- values are recomputed on each request.
getValue :: Network -> InputState -> NewState -> NeuronIndex -> (Float, NewState)
getValue _ (input, _) state' (Input x) = (input !! x, state')
getValue net inputState state' (Internal x) =
  case state' !! x of
    Just cached -> (cached, state')
    Nothing ->
      let (value, ns) =
            foldEdges net inputState state' (Internal x) (internalNeurons net !! x)
       in (value, updateValue ns x value)
getValue net inputState state' (Output x) =
  foldEdges net inputState state' (Output x) (outputNeurons net !! x)

-- | @tanh@ of the weighted sum over a neuron's incoming edges, threading the
-- memo table through the evaluation of each source. A self-loop edge
-- (source == sink) reads the neuron's value from the previous state.
foldEdges :: Network -> InputState -> NewState -> NeuronIndex -> [Edge] -> (Float, NewState)
foldEdges net (input, state) state' sink edges =
  let (t, ns) = foldl' addEdge (0, state') edges
   in (tanh t, ns)
  where
    addEdge (total, nss) (Edge (source, w)) =
      let (value, nss') =
            if sink == source
              -- Keep the running memo table 'nss'. The original returned the
              -- fold's own final result here, which made it self-referential
              -- and caused '<<loop>>'.
              then (state !! getNeuronIndex source, nss)
              else getValue net (input, state) nss source
       in (w * value + total, nss')