summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/MindTest.hs21
1 files changed, 16 insertions, 5 deletions
diff --git a/test/MindTest.hs b/test/MindTest.hs
index d8b63cf..01185f1 100644
--- a/test/MindTest.hs
+++ b/test/MindTest.hs
@@ -3,6 +3,7 @@ module MindTest (suite) where
import Test.Tasty
import Test.Tasty.HUnit
import Mind
+import Data.Either
suite :: TestTree
suite = testGroup "mind tests" $
@@ -64,7 +65,7 @@ networkTests = testGroup "network tests" $
, testCase "single input, single output" $
let
net = Network 1 [] [[Edge (Input 0, 2.0)]]
- Just (output, state) = compute net [negate 0.5] []
+ Right (output, state) = compute net [negate 0.5] []
in (output, state) @?=
( [tanh (2.0 * (negate 0.5))]
, []
@@ -77,7 +78,7 @@ networkTests = testGroup "network tests" $
, Edge (Input 2, 1.0)
, Edge (Input 3, 2.0)
]]
- Just (output, state) = compute net [1, 2, 3, 5] []
+ Right (output, state) = compute net [1, 2, 3, 5] []
in (output, state) @?=
(
[tanh
@@ -99,7 +100,7 @@ networkTests = testGroup "network tests" $
, Edge (Input 3, 1.0)
]
]
- Just (output, state) = compute net [1, 2, 3, 5] []
+ Right (output, state) = compute net [1, 2, 3, 5] []
in (output, state) @?=
(
[ tanh (2 - 1)
@@ -123,7 +124,7 @@ networkTests = testGroup "network tests" $
, Edge (Internal 1, 1.0)
]
]
- Just (output, state) = compute net [1, 2, 3, 5] [0, 0]
+ Right (output, state) = compute net [1, 2, 3, 5] [0, 0]
in (output, state) @?=
(
[ tanh ( (tanh (5-3)) - (tanh (2-1)) ) ]
@@ -138,9 +139,19 @@ networkTests = testGroup "network tests" $
[ [ Edge (Internal 0, negate 0.5) ] ]
-- output neurons
[ [ Edge (Internal 0, 2) ] ]
- Just result = compute net [] [1.0]
+ Right result = compute net [] [1.0]
in result @?=
( [ tanh $ 2 * (tanh (negate 0.5)) ]
, [ tanh (negate 0.5) ]
)
+ , testCase "computation fails for bad input length" $
+ let
+ net = Network 2 [[]] [[]]
+ result = compute net (replicate 3 1.0) [1]
+ in (isLeft result) @?= True
+ , testCase "computation fails for bad state length" $
+ let
+ net = Network 2 [[]] [[]]
+ result = compute net (replicate 2 1.0) [1, 1]
+ in (isLeft result) @?= True
]