summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsanine <sanine.not@pm.me>2023-11-22 10:39:33 -0600
committersanine <sanine.not@pm.me>2023-11-22 10:39:33 -0600
commite8d7a5237c666a40e14f3709289329fd8c2cb7d2 (patch)
tree673f843f98b00a8868de6705c491777088a86842
parent8fb358e84770f69606f7f27c40cfdf0ce57cd026 (diff)
show error messages on bad input/state
-rw-r--r--src/Mind.hs37
-rw-r--r--test/MindTest.hs21
2 files changed, 36 insertions, 22 deletions
diff --git a/src/Mind.hs b/src/Mind.hs
index 359293c..d686e24 100644
--- a/src/Mind.hs
+++ b/src/Mind.hs
@@ -76,12 +76,15 @@ insertEdge ns i e
-- network computation
-compute :: Network -> [Float] -> [Float] -> Maybe ([Float], [Float])
-compute net input state =
- let
- s = newState net input state
- state' = map (Just) s
- in Just $ (output net input state state', s)
+compute :: Network -> [Float] -> [Float] -> Either String ([Float], [Float])
+compute net input state
+ | (length input) /= (numInput net) = Left $ "Bad input length: " ++ (show $ length input)
+ | (length state) /= (length $ internalNeurons net) = Left $ "Bad state length: " ++ (show $ length state)
+ | otherwise =
+ let
+ s = newState net input state
+ state' = map (Just) s
+ in Right $ (output net input state state', s)
type InputState = ([Float], [Float])
@@ -89,7 +92,7 @@ type NewState = [Maybe Float]
output :: Network -> [Float] -> [Float] -> NewState -> [Float]
-output net input state state'=
+output net input state state' =
let
numOutput = length $ outputNeurons net
in
@@ -97,16 +100,16 @@ output net input state state'=
newState :: Network -> [Float] -> [Float] -> [Float]
newState net input state =
- let numInternal = length $ internalNeurons net
- in fst $
- foldr
- (\x (r, ns) ->
- let (value, ns') = getValue net (input, state) ns (Internal x)
- in (value:r, ns')
- )
- ([], replicate numInternal Nothing)
- [0..numInternal-1]
-
+ let numInternal = length $ internalNeurons net
+ in fst $
+ foldr
+ (\x (r, ns) ->
+ let (value, ns') = getValue net (input, state) ns (Internal x)
+ in (value:r, ns')
+ )
+ ([], replicate numInternal Nothing)
+ [0..numInternal-1]
+
updateValue :: NewState -> Int -> Float -> NewState
diff --git a/test/MindTest.hs b/test/MindTest.hs
index d8b63cf..01185f1 100644
--- a/test/MindTest.hs
+++ b/test/MindTest.hs
@@ -3,6 +3,7 @@ module MindTest (suite) where
import Test.Tasty
import Test.Tasty.HUnit
import Mind
+import Data.Either
suite :: TestTree
suite = testGroup "mind tests" $
@@ -64,7 +65,7 @@ networkTests = testGroup "network tests" $
, testCase "single input, single output" $
let
net = Network 1 [] [[Edge (Input 0, 2.0)]]
- Just (output, state) = compute net [negate 0.5] []
+ Right (output, state) = compute net [negate 0.5] []
in (output, state) @?=
( [tanh (2.0 * (negate 0.5))]
, []
@@ -77,7 +78,7 @@ networkTests = testGroup "network tests" $
, Edge (Input 2, 1.0)
, Edge (Input 3, 2.0)
]]
- Just (output, state) = compute net [1, 2, 3, 5] []
+ Right (output, state) = compute net [1, 2, 3, 5] []
in (output, state) @?=
(
[tanh
@@ -99,7 +100,7 @@ networkTests = testGroup "network tests" $
, Edge (Input 3, 1.0)
]
]
- Just (output, state) = compute net [1, 2, 3, 5] []
+ Right (output, state) = compute net [1, 2, 3, 5] []
in (output, state) @?=
(
[ tanh (2 - 1)
@@ -123,7 +124,7 @@ networkTests = testGroup "network tests" $
, Edge (Internal 1, 1.0)
]
]
- Just (output, state) = compute net [1, 2, 3, 5] [0, 0]
+ Right (output, state) = compute net [1, 2, 3, 5] [0, 0]
in (output, state) @?=
(
[ tanh ( (tanh (5-3)) - (tanh (2-1)) ) ]
@@ -138,9 +139,19 @@ networkTests = testGroup "network tests" $
[ [ Edge (Internal 0, negate 0.5) ] ]
-- output neurons
[ [ Edge (Internal 0, 2) ] ]
- Just result = compute net [] [1.0]
+ Right result = compute net [] [1.0]
in result @?=
( [ tanh $ 2 * (tanh (negate 0.5)) ]
, [ tanh (negate 0.5) ]
)
+ , testCase "computation fails for bad input length" $
+ let
+ net = Network 2 [[]] [[]]
+ result = compute net (replicate 3 1.0) [1]
+ in (isLeft result) @?= True
+ , testCase "computation fails for bad state length" $
+ let
+ net = Network 2 [[]] [[]]
+ result = compute net (replicate 2 1.0) [1, 1]
+ in (isLeft result) @?= True
]