diff options
author | sanine <sanine.not@pm.me> | 2023-11-21 21:06:20 -0600 |
---|---|---|
committer | sanine <sanine.not@pm.me> | 2023-11-21 21:06:20 -0600 |
commit | a005bcdf7b24ef60ef0a9336e27801e8d8ec70ad (patch) | |
tree | b8d2765dbc46e6df36f5fd881aef402db3019afd | |
parent | fae22e68282336fc0d7f0efda236410a294b7eb5 (diff) |
implement output computation
-rw-r--r-- | src/Mind.hs | 51 | ||||
-rw-r--r-- | test/MindTest.hs | 4 |
2 files changed, 52 insertions, 3 deletions
diff --git a/src/Mind.hs b/src/Mind.hs index 2971476..1ee70bc 100644 --- a/src/Mind.hs +++ b/src/Mind.hs @@ -11,6 +11,7 @@ module Mind ) where import Data.Ix +import Data.Maybe -- index different neuron types data NeuronIndex = Input Int | Internal Int | Output Int deriving (Show, Eq) @@ -76,4 +77,52 @@ insertEdge ns i e -- network computation compute :: Network -> [Float] -> [Float] -> Maybe ([Float], [Float]) -compute net input state = undefined +compute net input state = Just $ (output net input state, []) + + +type InputState = ([Float], [Float]) +type NewState = [Maybe Float] + + +output net input state = + let + state' = replicate (length $ internalNeurons net) Nothing + numOutput = length $ outputNeurons net + in + map ((fst . getValue net (input, state) state') . Output) [0..numOutput-1] + + +updateValue :: NewState -> Int -> Float -> NewState +updateValue state' index value = + let (front, _:back) = splitAt index state' + in front ++ (Just value):back + + +getValue :: Network -> InputState -> NewState -> NeuronIndex -> (Float, NewState) +getValue _ (input, _) state' (Input x) = (input !! x, state') +getValue net inputState state' (Internal x) = + let cached = state' !! x + in + if isJust cached then (fromJust cached, state') + else let + (value, ns) = foldEdges net inputState state' (internalNeurons net !! x) + nss = updateValue ns x value + in (value, nss) +getValue net inputState state' (Output x) = + foldEdges net inputState state' (outputNeurons net !! x) + + + +foldEdges:: Network -> InputState -> NewState -> [Edge] -> (Float, NewState) +foldEdges net (input, state) state' edges = + let + (total, ns) = foldl + (\(total, ns) (Edge (source, w)) -> + let + (value, ns') = getValue net (input, state) ns source + total' = (w * value) + total + in (total', ns') + ) + (0, state') + edges + in (tanh total, ns) diff --git a/test/MindTest.hs b/test/MindTest.hs index cf5c50a..7b208b4 100644 --- a/test/MindTest.hs +++ b/test/MindTest.hs @@ -92,8 +92,8 @@ networkTests = testGroup "network tests" $ , testCase "multiple inputs, multiple outputs" $ let net = Network 4 [] - [ [ Edge (Input 0, 1.0) - , Edge (Input 1, negate 2.0) + [ [ Edge (Input 0, negate 1.0) + , Edge (Input 1, 1.0) ] , [ Edge (Input 2, negate 1.0) , Edge (Input 3, 1.0) |