diff options
author | sanine <sanine.not@pm.me> | 2023-11-21 23:41:38 -0600 |
---|---|---|
committer | sanine <sanine.not@pm.me> | 2023-11-21 23:41:38 -0600 |
commit | 435c52c7330bcd49328a8facfc5a11b00e4a41bf (patch) | |
tree | 3431819acd8d6e2e2efd11c476510990abe739ef | |
parent | 26be9f9dc7d57aa76e6c5eb4dd95681df3bde309 (diff) |
implement self-connected computation
-rw-r--r-- | src/Mind.hs | 12 | ||||
-rw-r--r-- | test/MindTest.hs | 12 |
2 files changed, 19 insertions, 5 deletions
diff --git a/src/Mind.hs b/src/Mind.hs index cb78a1e..70b7494 100644 --- a/src/Mind.hs +++ b/src/Mind.hs @@ -121,21 +121,23 @@ getValue net inputState state' (Internal x) = in if isJust cached then (fromJust cached, state') else let - (value, ns) = foldEdges net inputState state' (internalNeurons net !! x) + (value, ns) = foldEdges net inputState state' (Internal x) (internalNeurons net !! x) nss = updateValue ns x value in (value, nss) getValue net inputState state' (Output x) = - foldEdges net inputState state' (outputNeurons net !! x) + foldEdges net inputState state' (Output x) (outputNeurons net !! x) -foldEdges:: Network -> InputState -> NewState -> [Edge] -> (Float, NewState) -foldEdges net (input, state) state' edges = +foldEdges:: Network -> InputState -> NewState -> NeuronIndex -> [Edge] -> (Float, NewState) +foldEdges net (input, state) state' sink edges = let (total, ns) = foldl (\(total, ns) (Edge (source, w)) -> let - (value, ns') = getValue net (input, state) ns source + (value, ns') = if (sink == source) + then (state !! (getNeuronIndex source), ns) + else getValue net (input, state) ns source total' = (w * value) + total in (total', ns') ) diff --git a/test/MindTest.hs b/test/MindTest.hs index 7b208b4..efa79ef 100644 --- a/test/MindTest.hs +++ b/test/MindTest.hs @@ -131,4 +131,16 @@ networkTests = testGroup "network tests" $ , tanh (5-3) ] ) + , testCase "computing with self-connection" $ + let + net = Network 0 + -- hidden neurons + [ [ Edge (Internal 0, negate 0.5) ] ] + -- output neurons + [ [ Edge (Internal 0, 2) ] ] + Just result = compute net [] [1.0] + in result @?= + ( [ tanh $ 2 * (tanh (negate 0.5)) ] + , [ tanh (negate 0.5) ] + ) ] |