summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Mind.hs12
-rw-r--r--test/MindTest.hs12
2 files changed, 19 insertions, 5 deletions
diff --git a/src/Mind.hs b/src/Mind.hs
index cb78a1e..70b7494 100644
--- a/src/Mind.hs
+++ b/src/Mind.hs
@@ -121,21 +121,23 @@ getValue net inputState state' (Internal x) =
in
if isJust cached then (fromJust cached, state')
else let
- (value, ns) = foldEdges net inputState state' (internalNeurons net !! x)
+ (value, ns) = foldEdges net inputState state' (Internal x) (internalNeurons net !! x)
nss = updateValue ns x value
in (value, nss)
getValue net inputState state' (Output x) =
- foldEdges net inputState state' (outputNeurons net !! x)
+ foldEdges net inputState state' (Output x) (outputNeurons net !! x)
-foldEdges:: Network -> InputState -> NewState -> [Edge] -> (Float, NewState)
-foldEdges net (input, state) state' edges =
+foldEdges:: Network -> InputState -> NewState -> NeuronIndex -> [Edge] -> (Float, NewState)
+foldEdges net (input, state) state' sink edges =
let
(total, ns) = foldl
(\(total, ns) (Edge (source, w)) ->
let
- (value, ns') = getValue net (input, state) ns source
+ (value, ns') = if (sink == source)
+ then (state !! (getNeuronIndex source), ns)
+ else getValue net (input, state) ns source
total' = (w * value) + total
in (total', ns')
)
diff --git a/test/MindTest.hs b/test/MindTest.hs
index 7b208b4..efa79ef 100644
--- a/test/MindTest.hs
+++ b/test/MindTest.hs
@@ -131,4 +131,16 @@ networkTests = testGroup "network tests" $
, tanh (5-3)
]
)
+ , testCase "computing with self-connection" $
+ let
+ net = Network 0
+ -- hidden neurons
+ [ [ Edge (Internal 0, negate 0.5) ] ]
+ -- output neurons
+ [ [ Edge (Internal 0, 2) ] ]
+ Just result = compute net [] [1.0]
+ in result @?=
+ ( [ tanh $ 2 * (tanh (negate 0.5)) ]
+ , [ tanh (negate 0.5) ]
+ )
]