summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsanine <sanine.not@pm.me>2023-11-21 21:06:20 -0600
committersanine <sanine.not@pm.me>2023-11-21 21:06:20 -0600
commita005bcdf7b24ef60ef0a9336e27801e8d8ec70ad (patch)
treeb8d2765dbc46e6df36f5fd881aef402db3019afd
parentfae22e68282336fc0d7f0efda236410a294b7eb5 (diff)
implement output computation
-rw-r--r--src/Mind.hs51
-rw-r--r--test/MindTest.hs4
2 files changed, 52 insertions, 3 deletions
diff --git a/src/Mind.hs b/src/Mind.hs
index 2971476..1ee70bc 100644
--- a/src/Mind.hs
+++ b/src/Mind.hs
@@ -11,6 +11,7 @@ module Mind
) where
import Data.Ix
+import Data.Maybe
-- index different neuron types
data NeuronIndex = Input Int | Internal Int | Output Int deriving (Show, Eq)
@@ -76,4 +77,52 @@ insertEdge ns i e
-- network computation
compute :: Network -> [Float] -> [Float] -> Maybe ([Float], [Float])
-compute net input state = undefined
+compute net input state = Just $ (output net input state, [])
+
+
+type InputState = ([Float], [Float])
+type NewState = [Maybe Float]
+
+
+output net input state =
+ let
+ state' = replicate (length $ internalNeurons net) Nothing
+ numOutput = length $ outputNeurons net
+ in
+ map ((fst . getValue net (input, state) state') . Output) [0..numOutput-1]
+
+
+updateValue :: NewState -> Int -> Float -> NewState
+updateValue state' index value =
+ let (front, _:back) = splitAt index state'
+ in front ++ (Just value):back
+
+
+getValue :: Network -> InputState -> NewState -> NeuronIndex -> (Float, NewState)
+getValue _ (input, _) state' (Input x) = (input !! x, state')
+getValue net inputState state' (Internal x) =
+ let cached = state' !! x
+ in
+ if isJust cached then (fromJust cached, state')
+ else let
+ (value, ns) = foldEdges net inputState state' (internalNeurons net !! x)
+ nss = updateValue ns x value
+ in (value, nss)
+getValue net inputState state' (Output x) =
+ foldEdges net inputState state' (outputNeurons net !! x)
+
+
+
+foldEdges:: Network -> InputState -> NewState -> [Edge] -> (Float, NewState)
+foldEdges net (input, state) state' edges =
+ let
+ (total, ns) = foldl
+ (\(total, ns) (Edge (source, w)) ->
+ let
+ (value, ns') = getValue net (input, state) ns source
+ total' = (w * value) + total
+ in (total', ns')
+ )
+ (0, state')
+ edges
+ in (tanh total, ns)
diff --git a/test/MindTest.hs b/test/MindTest.hs
index cf5c50a..7b208b4 100644
--- a/test/MindTest.hs
+++ b/test/MindTest.hs
@@ -92,8 +92,8 @@ networkTests = testGroup "network tests" $
, testCase "multiple inputs, multiple outputs" $
let
net = Network 4 []
- [ [ Edge (Input 0, 1.0)
- , Edge (Input 1, negate 2.0)
+ [ [ Edge (Input 0, negate 1.0)
+ , Edge (Input 1, 1.0)
]
, [ Edge (Input 2, negate 1.0)
, Edge (Input 3, 1.0)