diff options
author | sanine <sanine.not@pm.me> | 2023-11-22 10:39:33 -0600 |
---|---|---|
committer | sanine <sanine.not@pm.me> | 2023-11-22 10:39:33 -0600 |
commit | e8d7a5237c666a40e14f3709289329fd8c2cb7d2 (patch) | |
tree | 673f843f98b00a8868de6705c491777088a86842 | |
parent | 8fb358e84770f69606f7f27c40cfdf0ce57cd026 (diff) |
show error messages on bad input/state
-rw-r--r-- | src/Mind.hs | 37 | ||||
-rw-r--r-- | test/MindTest.hs | 21 |
2 files changed, 36 insertions, 22 deletions
diff --git a/src/Mind.hs b/src/Mind.hs index 359293c..d686e24 100644 --- a/src/Mind.hs +++ b/src/Mind.hs @@ -76,12 +76,15 @@ insertEdge ns i e -- network computation -compute :: Network -> [Float] -> [Float] -> Maybe ([Float], [Float]) -compute net input state = - let - s = newState net input state - state' = map (Just) s - in Just $ (output net input state state', s) +compute :: Network -> [Float] -> [Float] -> Either String ([Float], [Float]) +compute net input state + | (length input) /= (numInput net) = Left $ "Bad input length: " ++ (show $ length input) + | (length state) /= (length $ internalNeurons net) = Left $ "Bad state length: " ++ (show $ length state) + | otherwise = + let + s = newState net input state + state' = map (Just) s + in Right $ (output net input state state', s) type InputState = ([Float], [Float]) @@ -89,7 +92,7 @@ type NewState = [Maybe Float] output :: Network -> [Float] -> [Float] -> NewState -> [Float] -output net input state state'= +output net input state state' = let numOutput = length $ outputNeurons net in @@ -97,16 +100,16 @@ output net input state state'= newState :: Network -> [Float] -> [Float] -> [Float] newState net input state = - let numInternal = length $ internalNeurons net - in fst $ - foldr - (\x (r, ns) -> - let (value, ns') = getValue net (input, state) ns (Internal x) - in (value:r, ns') - ) - ([], replicate numInternal Nothing) - [0..numInternal-1] - + let numInternal = length $ internalNeurons net + in fst $ + foldr + (\x (r, ns) -> + let (value, ns') = getValue net (input, state) ns (Internal x) + in (value:r, ns') + ) + ([], replicate numInternal Nothing) + [0..numInternal-1] + updateValue :: NewState -> Int -> Float -> NewState diff --git a/test/MindTest.hs b/test/MindTest.hs index d8b63cf..01185f1 100644 --- a/test/MindTest.hs +++ b/test/MindTest.hs @@ -3,6 +3,7 @@ module MindTest (suite) where import Test.Tasty import Test.Tasty.HUnit import Mind +import Data.Either suite :: TestTree suite = testGroup "mind tests" $ @@ -64,7 +65,7 @@ networkTests = testGroup "network tests" $ , testCase "single input, single output" $ let net = Network 1 [] [[Edge (Input 0, 2.0)]] - Just (output, state) = compute net [negate 0.5] [] + Right (output, state) = compute net [negate 0.5] [] in (output, state) @?= ( [tanh (2.0 * (negate 0.5))] , [] @@ -77,7 +78,7 @@ networkTests = testGroup "network tests" $ , Edge (Input 2, 1.0) , Edge (Input 3, 2.0) ]] - Just (output, state) = compute net [1, 2, 3, 5] [] + Right (output, state) = compute net [1, 2, 3, 5] [] in (output, state) @?= ( [tanh @@ -99,7 +100,7 @@ networkTests = testGroup "network tests" $ , Edge (Input 3, 1.0) ] ] - Just (output, state) = compute net [1, 2, 3, 5] [] + Right (output, state) = compute net [1, 2, 3, 5] [] in (output, state) @?= ( [ tanh (2 - 1) @@ -123,7 +124,7 @@ networkTests = testGroup "network tests" $ , Edge (Internal 1, 1.0) ] ] - Just (output, state) = compute net [1, 2, 3, 5] [0, 0] + Right (output, state) = compute net [1, 2, 3, 5] [0, 0] in (output, state) @?= ( [ tanh ( (tanh (5-3)) - (tanh (2-1)) ) ] @@ -138,9 +139,19 @@ networkTests = testGroup "network tests" $ [ [ Edge (Internal 0, negate 0.5) ] ] -- output neurons [ [ Edge (Internal 0, 2) ] ] - Just result = compute net [] [1.0] + Right result = compute net [] [1.0] in result @?= ( [ tanh $ 2 * (tanh (negate 0.5)) ] , [ tanh (negate 0.5) ] ) + , testCase "computation fails for bad input length" $ + let + net = Network 2 [[]] [[]] + result = compute net (replicate 3 1.0) [1] + in (isLeft result) @?= True + , testCase "computation fails for bad state length" $ + let + net = Network 2 [[]] [[]] + result = compute net (replicate 2 1.0) [1, 1] + in (isLeft result) @?= True ] |