brainstate-0.2.1-py2.py3-none-any.whl → brainstate-0.2.2-py2.py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_metrics_test.py
@@ -1,611 +1,611 @@
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Comprehensive tests for metrics module."""
-
- import jax
- import jax.numpy as jnp
- from absl.testing import absltest, parameterized
-
- import brainstate as bst
-
-
- class MetricStateTest(absltest.TestCase):
-     """Test cases for MetricState class."""
-
-     def test_metric_state_creation(self):
-         """Test creating a MetricState with different values."""
-         state = bst.nn.MetricState(jnp.array(0.0))
-         self.assertEqual(state.value, 0.0)
-
-     def test_metric_state_update(self):
-         """Test updating MetricState value."""
-         state = bst.nn.MetricState(jnp.array(0.0))
-         state.value = jnp.array(5.5)
-         self.assertAlmostEqual(float(state.value), 5.5)
-
-     def test_metric_state_module_attribute(self):
-         """Test that MetricState has correct __module__ attribute."""
-         state = bst.nn.MetricState(jnp.array(0.0))
-         self.assertEqual(bst.nn.MetricState.__module__, "brainstate.nn")
-
-
- class AverageMetricTest(parameterized.TestCase):
-     """Test cases for AverageMetric class."""
-
-     def test_average_metric_initial_state(self):
-         """Test that initial average is NaN."""
-         metric = bst.nn.AverageMetric()
-         result = metric.compute()
-         self.assertTrue(jnp.isnan(result))
-
-     def test_average_metric_single_update(self):
-         """Test average after single update."""
-         metric = bst.nn.AverageMetric()
-         metric.update(values=jnp.array([1, 2, 3, 4]))
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 2.5)
-
-     def test_average_metric_multiple_updates(self):
-         """Test average after multiple updates."""
-         metric = bst.nn.AverageMetric()
-         metric.update(values=jnp.array([1, 2, 3, 4]))
-         metric.update(values=jnp.array([3, 2, 1, 0]))
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 2.0)
-
-     def test_average_metric_scalar_values(self):
-         """Test average with scalar values."""
-         metric = bst.nn.AverageMetric()
-         metric.update(values=5.0)
-         metric.update(values=3.0)
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 4.0)
-
-     def test_average_metric_reset(self):
-         """Test reset functionality."""
-         metric = bst.nn.AverageMetric()
-         metric.update(values=jnp.array([1, 2, 3, 4]))
-         metric.reset()
-         result = metric.compute()
-         self.assertTrue(jnp.isnan(result))
-
-     def test_average_metric_custom_argname(self):
-         """Test using custom argument name."""
-         metric = bst.nn.AverageMetric('loss')
-         metric.update(loss=jnp.array([1.0, 2.0, 3.0]))
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 2.0)
-
-     def test_average_metric_missing_argname(self):
-         """Test error when expected argument is missing."""
-         metric = bst.nn.AverageMetric('loss')
-         with self.assertRaises(TypeError):
-             metric.update(values=jnp.array([1.0, 2.0]))
-
-     def test_average_metric_module_attribute(self):
-         """Test that AverageMetric has correct __module__ attribute."""
-         self.assertEqual(bst.nn.AverageMetric.__module__, "brainstate.nn")
-
-
- class WelfordMetricTest(parameterized.TestCase):
-     """Test cases for WelfordMetric class."""
-
-     def test_welford_metric_initial_state(self):
-         """Test initial statistics."""
-         metric = bst.nn.WelfordMetric()
-         stats = metric.compute()
-         self.assertEqual(float(stats.mean), 0.0)
-         self.assertTrue(jnp.isnan(stats.standard_error_of_mean))
-         self.assertTrue(jnp.isnan(stats.standard_deviation))
-
-     def test_welford_metric_single_update(self):
-         """Test statistics after single update."""
-         metric = bst.nn.WelfordMetric()
-         metric.update(values=jnp.array([1, 2, 3, 4]))
-         stats = metric.compute()
-         self.assertAlmostEqual(float(stats.mean), 2.5, places=5)
-         # Population std of [1,2,3,4] is sqrt(1.25) ≈ 1.118
-         self.assertAlmostEqual(float(stats.standard_deviation), 1.118, places=3)
-
-     def test_welford_metric_multiple_updates(self):
-         """Test statistics after multiple updates."""
-         metric = bst.nn.WelfordMetric()
-         metric.update(values=jnp.array([1, 2, 3, 4]))
-         metric.update(values=jnp.array([3, 2, 1, 0]))
-         stats = metric.compute()
-         # Mean of all 8 values: [1,2,3,4,3,2,1,0] is 2.0
-         self.assertAlmostEqual(float(stats.mean), 2.0, places=5)
-
-     def test_welford_metric_reset(self):
-         """Test reset functionality."""
-         metric = bst.nn.WelfordMetric()
-         metric.update(values=jnp.array([1, 2, 3, 4]))
-         metric.reset()
-         stats = metric.compute()
-         self.assertEqual(float(stats.mean), 0.0)
-         self.assertTrue(jnp.isnan(stats.standard_deviation))
-
-     def test_welford_metric_custom_argname(self):
-         """Test using custom argument name."""
-         metric = bst.nn.WelfordMetric('data')
-         metric.update(data=jnp.array([1.0, 2.0, 3.0]))
-         stats = metric.compute()
-         self.assertAlmostEqual(float(stats.mean), 2.0, places=5)
-
-     def test_welford_metric_module_attribute(self):
-         """Test that WelfordMetric has correct __module__ attribute."""
-         self.assertEqual(bst.nn.WelfordMetric.__module__, "brainstate.nn")
-
-
- class AccuracyMetricTest(parameterized.TestCase):
-     """Test cases for AccuracyMetric class."""
-
-     def test_accuracy_metric_initial_state(self):
-         """Test that initial accuracy is NaN."""
-         metric = bst.nn.AccuracyMetric()
-         result = metric.compute()
-         self.assertTrue(jnp.isnan(result))
-
-     def test_accuracy_metric_perfect_accuracy(self):
-         """Test with perfect predictions."""
-         metric = bst.nn.AccuracyMetric()
-         logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
-         labels = jnp.array([1, 0, 1])
-         metric.update(logits=logits, labels=labels)
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 1.0)
-
-     def test_accuracy_metric_partial_accuracy(self):
-         """Test with partial accuracy."""
-         metric = bst.nn.AccuracyMetric()
-         logits = jnp.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]])
-         labels = jnp.array([1, 0, 1])  # First one wrong
-         metric.update(logits=logits, labels=labels)
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 2.0 / 3.0, places=5)
-
-     def test_accuracy_metric_multiple_updates(self):
-         """Test accuracy after multiple updates."""
-         metric = bst.nn.AccuracyMetric()
-         logits1 = jax.random.normal(jax.random.key(0), (5, 2))
-         labels1 = jnp.array([1, 1, 0, 1, 0])
-         logits2 = jax.random.normal(jax.random.key(1), (5, 2))
-         labels2 = jnp.array([0, 1, 1, 1, 1])
-
-         metric.update(logits=logits1, labels=labels1)
-         metric.update(logits=logits2, labels=labels2)
-         result = metric.compute()
-         # Result should be between 0 and 1
-         self.assertGreaterEqual(float(result), 0.0)
-         self.assertLessEqual(float(result), 1.0)
-
-     def test_accuracy_metric_reset(self):
-         """Test reset functionality."""
-         metric = bst.nn.AccuracyMetric()
-         logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
-         labels = jnp.array([1, 0])
-         metric.update(logits=logits, labels=labels)
-         metric.reset()
-         result = metric.compute()
-         self.assertTrue(jnp.isnan(result))
-
-     def test_accuracy_metric_shape_validation(self):
-         """Test shape validation."""
-         metric = bst.nn.AccuracyMetric()
-         logits = jnp.array([[0.1, 0.9]])
-         labels = jnp.array([[1, 0]])  # Wrong shape
-         with self.assertRaises(ValueError):
-             metric.update(logits=logits, labels=labels)
-
-     def test_accuracy_metric_module_attribute(self):
-         """Test that AccuracyMetric has correct __module__ attribute."""
-         self.assertEqual(bst.nn.AccuracyMetric.__module__, "brainstate.nn")
-
-
- class PrecisionMetricTest(parameterized.TestCase):
-     """Test cases for PrecisionMetric class."""
-
-     def test_precision_metric_binary_perfect(self):
-         """Test binary precision with perfect predictions."""
-         metric = bst.nn.PrecisionMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 1, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 1.0)
-
-     def test_precision_metric_binary_partial(self):
-         """Test binary precision with partial accuracy."""
-         metric = bst.nn.PrecisionMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 0, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         # TP=2, FP=1, Precision=2/3
-         self.assertAlmostEqual(float(result), 2.0 / 3.0, places=5)
-
-     def test_precision_metric_multiclass_macro(self):
-         """Test multi-class precision with macro averaging."""
-         metric = bst.nn.PrecisionMetric(num_classes=3, average='macro')
-         predictions = jnp.array([0, 1, 2, 1, 0])
-         labels = jnp.array([0, 1, 1, 1, 2])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         # Should compute precision for each class and average
-         self.assertGreaterEqual(float(result), 0.0)
-         self.assertLessEqual(float(result), 1.0)
-
-     def test_precision_metric_multiclass_micro(self):
-         """Test multi-class precision with micro averaging."""
-         metric = bst.nn.PrecisionMetric(num_classes=3, average='micro')
-         predictions = jnp.array([0, 1, 2, 1, 0])
-         labels = jnp.array([0, 1, 1, 1, 2])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         self.assertGreaterEqual(float(result), 0.0)
-         self.assertLessEqual(float(result), 1.0)
-
-     def test_precision_metric_reset(self):
-         """Test reset functionality."""
-         metric = bst.nn.PrecisionMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 0, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         metric.reset()
-         result = metric.compute()
-         # After reset, with no data, precision should be 0
-         self.assertEqual(float(result), 0.0)
-
-     def test_precision_metric_module_attribute(self):
-         """Test that PrecisionMetric has correct __module__ attribute."""
-         self.assertEqual(bst.nn.PrecisionMetric.__module__, "brainstate.nn")
-
-
- class RecallMetricTest(parameterized.TestCase):
-     """Test cases for RecallMetric class."""
-
-     def test_recall_metric_binary_perfect(self):
-         """Test binary recall with perfect predictions."""
-         metric = bst.nn.RecallMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 1, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 1.0)
-
-     def test_recall_metric_binary_partial(self):
-         """Test binary recall with partial accuracy."""
-         metric = bst.nn.RecallMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 0, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         # TP=2, FN=0, Recall=2/2=1.0
-         self.assertAlmostEqual(float(result), 1.0, places=5)
-
-     def test_recall_metric_multiclass_macro(self):
-         """Test multi-class recall with macro averaging."""
-         metric = bst.nn.RecallMetric(num_classes=3, average='macro')
-         predictions = jnp.array([0, 1, 2, 1, 0])
-         labels = jnp.array([0, 1, 1, 1, 2])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         self.assertGreaterEqual(float(result), 0.0)
-         self.assertLessEqual(float(result), 1.0)
-
-     def test_recall_metric_multiclass_micro(self):
-         """Test multi-class recall with micro averaging."""
-         metric = bst.nn.RecallMetric(num_classes=3, average='micro')
-         predictions = jnp.array([0, 1, 2, 1, 0])
-         labels = jnp.array([0, 1, 1, 1, 2])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         self.assertGreaterEqual(float(result), 0.0)
-         self.assertLessEqual(float(result), 1.0)
-
-     def test_recall_metric_reset(self):
-         """Test reset functionality."""
-         metric = bst.nn.RecallMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 0, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         metric.reset()
-         result = metric.compute()
-         self.assertEqual(float(result), 0.0)
-
-     def test_recall_metric_module_attribute(self):
-         """Test that RecallMetric has correct __module__ attribute."""
-         self.assertEqual(bst.nn.RecallMetric.__module__, "brainstate.nn")
-
-
- class F1ScoreMetricTest(parameterized.TestCase):
-     """Test cases for F1ScoreMetric class."""
-
-     def test_f1_score_binary_perfect(self):
-         """Test binary F1 score with perfect predictions."""
-         metric = bst.nn.F1ScoreMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 1, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 1.0)
-
-     def test_f1_score_binary_partial(self):
-         """Test binary F1 score with partial accuracy."""
-         metric = bst.nn.F1ScoreMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 0, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         # Precision=2/3, Recall=1.0, F1=2*(2/3*1)/(2/3+1)=0.8
-         self.assertAlmostEqual(float(result), 0.8, places=5)
-
-     def test_f1_score_multiclass(self):
-         """Test multi-class F1 score."""
-         metric = bst.nn.F1ScoreMetric(num_classes=3, average='macro')
-         predictions = jnp.array([0, 1, 2, 1, 0])
-         labels = jnp.array([0, 1, 1, 1, 2])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         self.assertGreaterEqual(float(result), 0.0)
-         self.assertLessEqual(float(result), 1.0)
-
-     def test_f1_score_reset(self):
-         """Test reset functionality."""
-         metric = bst.nn.F1ScoreMetric()
-         predictions = jnp.array([1, 0, 1, 1, 0])
-         labels = jnp.array([1, 0, 0, 1, 0])
-         metric.update(predictions=predictions, labels=labels)
-         metric.reset()
-         result = metric.compute()
-         self.assertEqual(float(result), 0.0)
-
-     def test_f1_score_module_attribute(self):
-         """Test that F1ScoreMetric has correct __module__ attribute."""
-         self.assertEqual(bst.nn.F1ScoreMetric.__module__, "brainstate.nn")
-
-
- class ConfusionMatrixTest(parameterized.TestCase):
-     """Test cases for ConfusionMatrix class."""
-
-     def test_confusion_matrix_basic(self):
-         """Test basic confusion matrix computation."""
-         metric = bst.nn.ConfusionMatrix(num_classes=3)
-         predictions = jnp.array([0, 1, 2, 1, 0])
-         labels = jnp.array([0, 1, 1, 1, 2])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-
-         # Check shape
-         self.assertEqual(result.shape, (3, 3))
-
-         # Check specific values
-         # True label 0, Predicted 0: 1
-         # True label 0, Predicted 2: 1
-         # True label 1, Predicted 1: 2
-         self.assertEqual(int(result[0, 0]), 1)
-         self.assertEqual(int(result[1, 1]), 2)
-
-     def test_confusion_matrix_perfect(self):
-         """Test confusion matrix with perfect predictions."""
-         metric = bst.nn.ConfusionMatrix(num_classes=2)
-         predictions = jnp.array([0, 0, 1, 1])
-         labels = jnp.array([0, 0, 1, 1])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-
-         # Should be diagonal matrix
-         self.assertEqual(int(result[0, 0]), 2)
-         self.assertEqual(int(result[0, 1]), 0)
-         self.assertEqual(int(result[1, 0]), 0)
-         self.assertEqual(int(result[1, 1]), 2)
-
-     def test_confusion_matrix_reset(self):
-         """Test reset functionality."""
-         metric = bst.nn.ConfusionMatrix(num_classes=2)
-         predictions = jnp.array([0, 1])
-         labels = jnp.array([0, 1])
-         metric.update(predictions=predictions, labels=labels)
-         metric.reset()
-         result = metric.compute()
-
-         # Should be all zeros
-         self.assertTrue(jnp.all(result == 0))
-
-     def test_confusion_matrix_invalid_predictions(self):
-         """Test error handling for invalid predictions."""
-         metric = bst.nn.ConfusionMatrix(num_classes=2)
-         predictions = jnp.array([0, 1, 2])  # 2 is out of range
-         labels = jnp.array([0, 1, 1])
-         with self.assertRaises(ValueError):
-             metric.update(predictions=predictions, labels=labels)
-
-     def test_confusion_matrix_invalid_labels(self):
-         """Test error handling for invalid labels."""
-         metric = bst.nn.ConfusionMatrix(num_classes=2)
-         predictions = jnp.array([0, 1, 1])
-         labels = jnp.array([0, 1, 3])  # 3 is out of range
-         with self.assertRaises(ValueError):
-             metric.update(predictions=predictions, labels=labels)
-
-     def test_confusion_matrix_module_attribute(self):
-         """Test that ConfusionMatrix has correct __module__ attribute."""
-         self.assertEqual(bst.nn.ConfusionMatrix.__module__, "brainstate.nn")
-
-
- class MultiMetricTest(parameterized.TestCase):
-     """Test cases for MultiMetric class."""
-
-     def test_multimetric_creation(self):
-         """Test creating a MultiMetric with multiple metrics."""
-         metrics = bst.nn.MultiMetric(
-             accuracy=bst.nn.AccuracyMetric(),
-             loss=bst.nn.AverageMetric(),
-         )
-         self.assertIsNotNone(metrics.accuracy)
-         self.assertIsNotNone(metrics.loss)
-
-     def test_multimetric_compute(self):
-         """Test computing all metrics."""
-         metrics = bst.nn.MultiMetric(
-             accuracy=bst.nn.AccuracyMetric(),
-             loss=bst.nn.AverageMetric(),
-         )
-
-         logits = jax.random.normal(jax.random.key(0), (5, 2))
-         labels = jnp.array([1, 1, 0, 1, 0])
-         batch_loss = jnp.array([1, 2, 3, 4])
-
-         result = metrics.compute()
-         self.assertIn('accuracy', result)
-         self.assertIn('loss', result)
-         self.assertTrue(jnp.isnan(result['accuracy']))
-         self.assertTrue(jnp.isnan(result['loss']))
-
-     def test_multimetric_update(self):
-         """Test updating all metrics."""
-         metrics = bst.nn.MultiMetric(
-             accuracy=bst.nn.AccuracyMetric(),
-             loss=bst.nn.AverageMetric(),
-         )
-
-         logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
-         labels = jnp.array([1, 0, 1])
-         batch_loss = jnp.array([1, 2, 3])
-
-         metrics.update(logits=logits, labels=labels, values=batch_loss)
-         result = metrics.compute()
-
-         self.assertGreaterEqual(float(result['accuracy']), 0.0)
-         self.assertLessEqual(float(result['accuracy']), 1.0)
-         self.assertAlmostEqual(float(result['loss']), 2.0)
-
-     def test_multimetric_reset(self):
-         """Test resetting all metrics."""
-         metrics = bst.nn.MultiMetric(
-             accuracy=bst.nn.AccuracyMetric(),
-             loss=bst.nn.AverageMetric(),
-         )
-
-         logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
-         labels = jnp.array([1, 0])
-         batch_loss = jnp.array([1, 2])
-
-         metrics.update(logits=logits, labels=labels, values=batch_loss)
-         metrics.reset()
-         result = metrics.compute()
-
-         self.assertTrue(jnp.isnan(result['accuracy']))
-         self.assertTrue(jnp.isnan(result['loss']))
-
-     def test_multimetric_reserved_name_validation(self):
-         """Test that reserved method names cannot be used."""
-         with self.assertRaises(ValueError):
-             bst.nn.MultiMetric(
-                 reset=bst.nn.AverageMetric(),
-             )
-
-         with self.assertRaises(ValueError):
-             bst.nn.MultiMetric(
-                 update=bst.nn.AverageMetric(),
-             )
-
-         with self.assertRaises(ValueError):
-             bst.nn.MultiMetric(
-                 compute=bst.nn.AverageMetric(),
-             )
-
-     def test_multimetric_type_validation(self):
-         """Test that all metrics must be Metric instances."""
-         with self.assertRaises(TypeError):
-             bst.nn.MultiMetric(
-                 accuracy=bst.nn.AccuracyMetric(),
-                 invalid="not a metric",
-             )
-
-     def test_multimetric_module_attribute(self):
-         """Test that MultiMetric has correct __module__ attribute."""
-         self.assertEqual(bst.nn.MultiMetric.__module__, "brainstate.nn")
-
-
- class MetricBaseClassTest(absltest.TestCase):
-     """Test cases for base Metric class."""
-
-     def test_metric_not_implemented_errors(self):
-         """Test that base Metric class raises NotImplementedError."""
-         metric = bst.nn.Metric()
-
-         with self.assertRaises(NotImplementedError):
-             metric.reset()
-
-         with self.assertRaises(NotImplementedError):
-             metric.update()
-
-         with self.assertRaises(NotImplementedError):
-             metric.compute()
-
-     def test_metric_module_attribute(self):
-         """Test that Metric has correct __module__ attribute."""
-         self.assertEqual(bst.nn.Metric.__module__, "brainstate.nn")
-
-
- class EdgeCasesTest(parameterized.TestCase):
-     """Test edge cases and boundary conditions."""
-
-     def test_average_metric_large_numbers(self):
-         """Test AverageMetric with very large numbers."""
-         metric = bst.nn.AverageMetric()
-         metric.update(values=jnp.array([1e10, 2e10, 3e10]))
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 2e10, places=-5)
-
-     def test_average_metric_small_numbers(self):
-         """Test AverageMetric with very small numbers."""
-         metric = bst.nn.AverageMetric()
-         metric.update(values=jnp.array([1e-10, 2e-10, 3e-10]))
-         result = metric.compute()
-         self.assertAlmostEqual(float(result), 2e-10, places=15)
-
-     def test_confusion_matrix_single_class(self):
-         """Test ConfusionMatrix with single class."""
-         metric = bst.nn.ConfusionMatrix(num_classes=1)
-         predictions = jnp.array([0, 0, 0])
-         labels = jnp.array([0, 0, 0])
-         metric.update(predictions=predictions, labels=labels)
-         result = metric.compute()
-         self.assertEqual(int(result[0, 0]), 3)
-
-     def test_precision_recall_no_positives(self):
-         """Test precision/recall when there are no positive predictions."""
-         precision_metric = bst.nn.PrecisionMetric()
-         recall_metric = bst.nn.RecallMetric()
-
-         predictions = jnp.array([0, 0, 0, 0, 0])
-         labels = jnp.array([0, 0, 0, 0, 0])
-
-         precision_metric.update(predictions=predictions, labels=labels)
-         recall_metric.update(predictions=predictions, labels=labels)
-
-         # Should handle gracefully without division by zero
-         precision = precision_metric.compute()
-         recall = recall_metric.compute()
-
-         self.assertEqual(float(precision), 0.0)
-         self.assertEqual(float(recall), 0.0)
-
-
- if __name__ == '__main__':
-     absltest.main()
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Comprehensive tests for metrics module."""
+
+ import jax
+ import jax.numpy as jnp
+ from absl.testing import absltest, parameterized
+
+ import brainstate as bst
+
+
+ class MetricStateTest(absltest.TestCase):
+     """Test cases for MetricState class."""
+
+     def test_metric_state_creation(self):
+         """Test creating a MetricState with different values."""
+         state = bst.nn.MetricState(jnp.array(0.0))
+         self.assertEqual(state.value, 0.0)
+
+     def test_metric_state_update(self):
+         """Test updating MetricState value."""
+         state = bst.nn.MetricState(jnp.array(0.0))
+         state.value = jnp.array(5.5)
+         self.assertAlmostEqual(float(state.value), 5.5)
+
+     def test_metric_state_module_attribute(self):
+         """Test that MetricState has correct __module__ attribute."""
+         state = bst.nn.MetricState(jnp.array(0.0))
+         self.assertEqual(bst.nn.MetricState.__module__, "brainstate.nn")
+
+
+ class AverageMetricTest(parameterized.TestCase):
+     """Test cases for AverageMetric class."""
+
+     def test_average_metric_initial_state(self):
+         """Test that initial average is NaN."""
+         metric = bst.nn.AverageMetric()
+         result = metric.compute()
+         self.assertTrue(jnp.isnan(result))
+
+     def test_average_metric_single_update(self):
+         """Test average after single update."""
+         metric = bst.nn.AverageMetric()
+         metric.update(values=jnp.array([1, 2, 3, 4]))
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 2.5)
+
+     def test_average_metric_multiple_updates(self):
+         """Test average after multiple updates."""
+         metric = bst.nn.AverageMetric()
+         metric.update(values=jnp.array([1, 2, 3, 4]))
+         metric.update(values=jnp.array([3, 2, 1, 0]))
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 2.0)
+
+     def test_average_metric_scalar_values(self):
+         """Test average with scalar values."""
+         metric = bst.nn.AverageMetric()
+         metric.update(values=5.0)
+         metric.update(values=3.0)
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 4.0)
+
+     def test_average_metric_reset(self):
+         """Test reset functionality."""
+         metric = bst.nn.AverageMetric()
+         metric.update(values=jnp.array([1, 2, 3, 4]))
+         metric.reset()
+         result = metric.compute()
+         self.assertTrue(jnp.isnan(result))
+
+     def test_average_metric_custom_argname(self):
+         """Test using custom argument name."""
+         metric = bst.nn.AverageMetric('loss')
+         metric.update(loss=jnp.array([1.0, 2.0, 3.0]))
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 2.0)
+
+     def test_average_metric_missing_argname(self):
+         """Test error when expected argument is missing."""
+         metric = bst.nn.AverageMetric('loss')
+         with self.assertRaises(TypeError):
+             metric.update(values=jnp.array([1.0, 2.0]))
+
+     def test_average_metric_module_attribute(self):
+         """Test that AverageMetric has correct __module__ attribute."""
+         self.assertEqual(bst.nn.AverageMetric.__module__, "brainstate.nn")
+
+
+ class WelfordMetricTest(parameterized.TestCase):
+     """Test cases for WelfordMetric class."""
+
+     def test_welford_metric_initial_state(self):
+         """Test initial statistics."""
+         metric = bst.nn.WelfordMetric()
+         stats = metric.compute()
+         self.assertEqual(float(stats.mean), 0.0)
+         self.assertTrue(jnp.isnan(stats.standard_error_of_mean))
+         self.assertTrue(jnp.isnan(stats.standard_deviation))
+
+     def test_welford_metric_single_update(self):
+         """Test statistics after single update."""
+         metric = bst.nn.WelfordMetric()
+         metric.update(values=jnp.array([1, 2, 3, 4]))
+         stats = metric.compute()
+         self.assertAlmostEqual(float(stats.mean), 2.5, places=5)
+         # Population std of [1,2,3,4] is sqrt(1.25) ≈ 1.118
+         self.assertAlmostEqual(float(stats.standard_deviation), 1.118, places=3)
+
+     def test_welford_metric_multiple_updates(self):
+         """Test statistics after multiple updates."""
+         metric = bst.nn.WelfordMetric()
+         metric.update(values=jnp.array([1, 2, 3, 4]))
+         metric.update(values=jnp.array([3, 2, 1, 0]))
+         stats = metric.compute()
+         # Mean of all 8 values: [1,2,3,4,3,2,1,0] is 2.0
+         self.assertAlmostEqual(float(stats.mean), 2.0, places=5)
+
+     def test_welford_metric_reset(self):
+         """Test reset functionality."""
+         metric = bst.nn.WelfordMetric()
+         metric.update(values=jnp.array([1, 2, 3, 4]))
+         metric.reset()
+         stats = metric.compute()
+         self.assertEqual(float(stats.mean), 0.0)
+         self.assertTrue(jnp.isnan(stats.standard_deviation))
+
+     def test_welford_metric_custom_argname(self):
+         """Test using custom argument name."""
+         metric = bst.nn.WelfordMetric('data')
+         metric.update(data=jnp.array([1.0, 2.0, 3.0]))
+         stats = metric.compute()
+         self.assertAlmostEqual(float(stats.mean), 2.0, places=5)
+
+     def test_welford_metric_module_attribute(self):
+         """Test that WelfordMetric has correct __module__ attribute."""
+         self.assertEqual(bst.nn.WelfordMetric.__module__, "brainstate.nn")
+
+
+ class AccuracyMetricTest(parameterized.TestCase):
+     """Test cases for AccuracyMetric class."""
+
+     def test_accuracy_metric_initial_state(self):
+         """Test that initial accuracy is NaN."""
+         metric = bst.nn.AccuracyMetric()
+         result = metric.compute()
+         self.assertTrue(jnp.isnan(result))
+
+     def test_accuracy_metric_perfect_accuracy(self):
+         """Test with perfect predictions."""
+         metric = bst.nn.AccuracyMetric()
+         logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
+         labels = jnp.array([1, 0, 1])
+         metric.update(logits=logits, labels=labels)
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 1.0)
+
+     def test_accuracy_metric_partial_accuracy(self):
+         """Test with partial accuracy."""
+         metric = bst.nn.AccuracyMetric()
+         logits = jnp.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]])
+         labels = jnp.array([1, 0, 1])  # First one wrong
+         metric.update(logits=logits, labels=labels)
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 2.0 / 3.0, places=5)
+
+     def test_accuracy_metric_multiple_updates(self):
+         """Test accuracy after multiple updates."""
+         metric = bst.nn.AccuracyMetric()
+         logits1 = jax.random.normal(jax.random.key(0), (5, 2))
+         labels1 = jnp.array([1, 1, 0, 1, 0])
+         logits2 = jax.random.normal(jax.random.key(1), (5, 2))
+         labels2 = jnp.array([0, 1, 1, 1, 1])
+
+         metric.update(logits=logits1, labels=labels1)
+         metric.update(logits=logits2, labels=labels2)
+         result = metric.compute()
+         # Result should be between 0 and 1
+         self.assertGreaterEqual(float(result), 0.0)
+         self.assertLessEqual(float(result), 1.0)
+
+     def test_accuracy_metric_reset(self):
+         """Test reset functionality."""
+         metric = bst.nn.AccuracyMetric()
+         logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
+         labels = jnp.array([1, 0])
+         metric.update(logits=logits, labels=labels)
+         metric.reset()
+         result = metric.compute()
+         self.assertTrue(jnp.isnan(result))
+
+     def test_accuracy_metric_shape_validation(self):
+         """Test shape validation."""
+         metric = bst.nn.AccuracyMetric()
+         logits = jnp.array([[0.1, 0.9]])
+         labels = jnp.array([[1, 0]])  # Wrong shape
+         with self.assertRaises(ValueError):
+             metric.update(logits=logits, labels=labels)
+
+     def test_accuracy_metric_module_attribute(self):
+         """Test that AccuracyMetric has correct __module__ attribute."""
+         self.assertEqual(bst.nn.AccuracyMetric.__module__, "brainstate.nn")
+
+
+ class PrecisionMetricTest(parameterized.TestCase):
+     """Test cases for PrecisionMetric class."""
+
+     def test_precision_metric_binary_perfect(self):
+         """Test binary precision with perfect predictions."""
+         metric = bst.nn.PrecisionMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 1, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 1.0)
+
+     def test_precision_metric_binary_partial(self):
+         """Test binary precision with partial accuracy."""
+         metric = bst.nn.PrecisionMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 0, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         # TP=2, FP=1, Precision=2/3
+         self.assertAlmostEqual(float(result), 2.0 / 3.0, places=5)
+
+     def test_precision_metric_multiclass_macro(self):
+         """Test multi-class precision with macro averaging."""
+         metric = bst.nn.PrecisionMetric(num_classes=3, average='macro')
+         predictions = jnp.array([0, 1, 2, 1, 0])
+         labels = jnp.array([0, 1, 1, 1, 2])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         # Should compute precision for each class and average
+         self.assertGreaterEqual(float(result), 0.0)
+         self.assertLessEqual(float(result), 1.0)
+
+     def test_precision_metric_multiclass_micro(self):
+         """Test multi-class precision with micro averaging."""
+         metric = bst.nn.PrecisionMetric(num_classes=3, average='micro')
+         predictions = jnp.array([0, 1, 2, 1, 0])
+         labels = jnp.array([0, 1, 1, 1, 2])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         self.assertGreaterEqual(float(result), 0.0)
+         self.assertLessEqual(float(result), 1.0)
+
+     def test_precision_metric_reset(self):
+         """Test reset functionality."""
+         metric = bst.nn.PrecisionMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 0, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         metric.reset()
+         result = metric.compute()
+         # After reset, with no data, precision should be 0
+         self.assertEqual(float(result), 0.0)
+
+     def test_precision_metric_module_attribute(self):
+         """Test that PrecisionMetric has correct __module__ attribute."""
+         self.assertEqual(bst.nn.PrecisionMetric.__module__, "brainstate.nn")
+
+
+ class RecallMetricTest(parameterized.TestCase):
+     """Test cases for RecallMetric class."""
+
+     def test_recall_metric_binary_perfect(self):
+         """Test binary recall with perfect predictions."""
+         metric = bst.nn.RecallMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 1, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 1.0)
+
+     def test_recall_metric_binary_partial(self):
+         """Test binary recall with partial accuracy."""
+         metric = bst.nn.RecallMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 0, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         # TP=2, FN=0, Recall=2/2=1.0
+         self.assertAlmostEqual(float(result), 1.0, places=5)
+
+     def test_recall_metric_multiclass_macro(self):
+         """Test multi-class recall with macro averaging."""
+         metric = bst.nn.RecallMetric(num_classes=3, average='macro')
+         predictions = jnp.array([0, 1, 2, 1, 0])
+         labels = jnp.array([0, 1, 1, 1, 2])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         self.assertGreaterEqual(float(result), 0.0)
+         self.assertLessEqual(float(result), 1.0)
+
+     def test_recall_metric_multiclass_micro(self):
+         """Test multi-class recall with micro averaging."""
+         metric = bst.nn.RecallMetric(num_classes=3, average='micro')
+         predictions = jnp.array([0, 1, 2, 1, 0])
+         labels = jnp.array([0, 1, 1, 1, 2])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         self.assertGreaterEqual(float(result), 0.0)
+         self.assertLessEqual(float(result), 1.0)
+
+     def test_recall_metric_reset(self):
+         """Test reset functionality."""
+         metric = bst.nn.RecallMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 0, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         metric.reset()
+         result = metric.compute()
+         self.assertEqual(float(result), 0.0)
+
+     def test_recall_metric_module_attribute(self):
+         """Test that RecallMetric has correct __module__ attribute."""
+         self.assertEqual(bst.nn.RecallMetric.__module__, "brainstate.nn")
+
+
+ class F1ScoreMetricTest(parameterized.TestCase):
+     """Test cases for F1ScoreMetric class."""
+
+     def test_f1_score_binary_perfect(self):
+         """Test binary F1 score with perfect predictions."""
+         metric = bst.nn.F1ScoreMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 1, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 1.0)
+
+     def test_f1_score_binary_partial(self):
+         """Test binary F1 score with partial accuracy."""
+         metric = bst.nn.F1ScoreMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 0, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         # Precision=2/3, Recall=1.0, F1=2*(2/3*1)/(2/3+1)=0.8
+         self.assertAlmostEqual(float(result), 0.8, places=5)
+
+     def test_f1_score_multiclass(self):
+         """Test multi-class F1 score."""
+         metric = bst.nn.F1ScoreMetric(num_classes=3, average='macro')
+         predictions = jnp.array([0, 1, 2, 1, 0])
+         labels = jnp.array([0, 1, 1, 1, 2])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         self.assertGreaterEqual(float(result), 0.0)
+         self.assertLessEqual(float(result), 1.0)
+
+     def test_f1_score_reset(self):
+         """Test reset functionality."""
+         metric = bst.nn.F1ScoreMetric()
+         predictions = jnp.array([1, 0, 1, 1, 0])
+         labels = jnp.array([1, 0, 0, 1, 0])
+         metric.update(predictions=predictions, labels=labels)
+         metric.reset()
+         result = metric.compute()
+         self.assertEqual(float(result), 0.0)
+
+     def test_f1_score_module_attribute(self):
+         """Test that F1ScoreMetric has correct __module__ attribute."""
+         self.assertEqual(bst.nn.F1ScoreMetric.__module__, "brainstate.nn")
+
+
+ class ConfusionMatrixTest(parameterized.TestCase):
+     """Test cases for ConfusionMatrix class."""
+
+     def test_confusion_matrix_basic(self):
+         """Test basic confusion matrix computation."""
+         metric = bst.nn.ConfusionMatrix(num_classes=3)
+         predictions = jnp.array([0, 1, 2, 1, 0])
+         labels = jnp.array([0, 1, 1, 1, 2])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+
+         # Check shape
+         self.assertEqual(result.shape, (3, 3))
+
+         # Check specific values
+         # True label 0, Predicted 0: 1
+         # True label 0, Predicted 2: 1
+         # True label 1, Predicted 1: 2
+         self.assertEqual(int(result[0, 0]), 1)
+         self.assertEqual(int(result[1, 1]), 2)
+
+     def test_confusion_matrix_perfect(self):
+         """Test confusion matrix with perfect predictions."""
+         metric = bst.nn.ConfusionMatrix(num_classes=2)
+         predictions = jnp.array([0, 0, 1, 1])
+         labels = jnp.array([0, 0, 1, 1])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+
+         # Should be diagonal matrix
+         self.assertEqual(int(result[0, 0]), 2)
+         self.assertEqual(int(result[0, 1]), 0)
+         self.assertEqual(int(result[1, 0]), 0)
+         self.assertEqual(int(result[1, 1]), 2)
+
+     def test_confusion_matrix_reset(self):
+         """Test reset functionality."""
+         metric = bst.nn.ConfusionMatrix(num_classes=2)
+         predictions = jnp.array([0, 1])
+         labels = jnp.array([0, 1])
+         metric.update(predictions=predictions, labels=labels)
+         metric.reset()
+         result = metric.compute()
+
+         # Should be all zeros
+         self.assertTrue(jnp.all(result == 0))
+
+     def test_confusion_matrix_invalid_predictions(self):
+         """Test error handling for invalid predictions."""
+         metric = bst.nn.ConfusionMatrix(num_classes=2)
+         predictions = jnp.array([0, 1, 2])  # 2 is out of range
+         labels = jnp.array([0, 1, 1])
+         with self.assertRaises(ValueError):
+             metric.update(predictions=predictions, labels=labels)
+
+     def test_confusion_matrix_invalid_labels(self):
+         """Test error handling for invalid labels."""
+         metric = bst.nn.ConfusionMatrix(num_classes=2)
+         predictions = jnp.array([0, 1, 1])
+         labels = jnp.array([0, 1, 3])  # 3 is out of range
+         with self.assertRaises(ValueError):
+             metric.update(predictions=predictions, labels=labels)
+
+     def test_confusion_matrix_module_attribute(self):
+         """Test that ConfusionMatrix has correct __module__ attribute."""
+         self.assertEqual(bst.nn.ConfusionMatrix.__module__, "brainstate.nn")
+
+
+ class MultiMetricTest(parameterized.TestCase):
+     """Test cases for MultiMetric class."""
+
+     def test_multimetric_creation(self):
+         """Test creating a MultiMetric with multiple metrics."""
+         metrics = bst.nn.MultiMetric(
+             accuracy=bst.nn.AccuracyMetric(),
+             loss=bst.nn.AverageMetric(),
+         )
+         self.assertIsNotNone(metrics.accuracy)
+         self.assertIsNotNone(metrics.loss)
+
+     def test_multimetric_compute(self):
+         """Test computing all metrics."""
+         metrics = bst.nn.MultiMetric(
+             accuracy=bst.nn.AccuracyMetric(),
+             loss=bst.nn.AverageMetric(),
+         )
+
+         logits = jax.random.normal(jax.random.key(0), (5, 2))
+         labels = jnp.array([1, 1, 0, 1, 0])
+         batch_loss = jnp.array([1, 2, 3, 4])
+
+         result = metrics.compute()
+         self.assertIn('accuracy', result)
+         self.assertIn('loss', result)
+         self.assertTrue(jnp.isnan(result['accuracy']))
+         self.assertTrue(jnp.isnan(result['loss']))
+
+     def test_multimetric_update(self):
+         """Test updating all metrics."""
+         metrics = bst.nn.MultiMetric(
+             accuracy=bst.nn.AccuracyMetric(),
+             loss=bst.nn.AverageMetric(),
+         )
+
+         logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
+         labels = jnp.array([1, 0, 1])
+         batch_loss = jnp.array([1, 2, 3])
+
+         metrics.update(logits=logits, labels=labels, values=batch_loss)
+         result = metrics.compute()
+
+         self.assertGreaterEqual(float(result['accuracy']), 0.0)
+         self.assertLessEqual(float(result['accuracy']), 1.0)
+         self.assertAlmostEqual(float(result['loss']), 2.0)
+
+     def test_multimetric_reset(self):
+         """Test resetting all metrics."""
+         metrics = bst.nn.MultiMetric(
+             accuracy=bst.nn.AccuracyMetric(),
+             loss=bst.nn.AverageMetric(),
+         )
+
+         logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
+         labels = jnp.array([1, 0])
+         batch_loss = jnp.array([1, 2])
+
+         metrics.update(logits=logits, labels=labels, values=batch_loss)
+         metrics.reset()
+         result = metrics.compute()
+
+         self.assertTrue(jnp.isnan(result['accuracy']))
+         self.assertTrue(jnp.isnan(result['loss']))
+
+     def test_multimetric_reserved_name_validation(self):
+         """Test that reserved method names cannot be used."""
+         with self.assertRaises(ValueError):
+             bst.nn.MultiMetric(
+                 reset=bst.nn.AverageMetric(),
+             )
+
+         with self.assertRaises(ValueError):
+             bst.nn.MultiMetric(
+                 update=bst.nn.AverageMetric(),
+             )
+
+         with self.assertRaises(ValueError):
+             bst.nn.MultiMetric(
+                 compute=bst.nn.AverageMetric(),
+             )
+
+     def test_multimetric_type_validation(self):
+         """Test that all metrics must be Metric instances."""
+         with self.assertRaises(TypeError):
+             bst.nn.MultiMetric(
+                 accuracy=bst.nn.AccuracyMetric(),
+                 invalid="not a metric",
+             )
+
+     def test_multimetric_module_attribute(self):
+         """Test that MultiMetric has correct __module__ attribute."""
+         self.assertEqual(bst.nn.MultiMetric.__module__, "brainstate.nn")
+
+
+ class MetricBaseClassTest(absltest.TestCase):
+     """Test cases for base Metric class."""
+
+     def test_metric_not_implemented_errors(self):
+         """Test that base Metric class raises NotImplementedError."""
+         metric = bst.nn.Metric()
+
+         with self.assertRaises(NotImplementedError):
+             metric.reset()
+
+         with self.assertRaises(NotImplementedError):
+             metric.update()
+
+         with self.assertRaises(NotImplementedError):
+             metric.compute()
+
+     def test_metric_module_attribute(self):
+         """Test that Metric has correct __module__ attribute."""
+         self.assertEqual(bst.nn.Metric.__module__, "brainstate.nn")
+
+
+ class EdgeCasesTest(parameterized.TestCase):
+     """Test edge cases and boundary conditions."""
+
+     def test_average_metric_large_numbers(self):
+         """Test AverageMetric with very large numbers."""
+         metric = bst.nn.AverageMetric()
+         metric.update(values=jnp.array([1e10, 2e10, 3e10]))
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 2e10, places=-5)
+
+     def test_average_metric_small_numbers(self):
+         """Test AverageMetric with very small numbers."""
+         metric = bst.nn.AverageMetric()
+         metric.update(values=jnp.array([1e-10, 2e-10, 3e-10]))
+         result = metric.compute()
+         self.assertAlmostEqual(float(result), 2e-10, places=15)
+
+     def test_confusion_matrix_single_class(self):
+         """Test ConfusionMatrix with single class."""
+         metric = bst.nn.ConfusionMatrix(num_classes=1)
+         predictions = jnp.array([0, 0, 0])
+         labels = jnp.array([0, 0, 0])
+         metric.update(predictions=predictions, labels=labels)
+         result = metric.compute()
+         self.assertEqual(int(result[0, 0]), 3)
+
+     def test_precision_recall_no_positives(self):
+         """Test precision/recall when there are no positive predictions."""
+         precision_metric = bst.nn.PrecisionMetric()
+         recall_metric = bst.nn.RecallMetric()
+
+         predictions = jnp.array([0, 0, 0, 0, 0])
+         labels = jnp.array([0, 0, 0, 0, 0])
+
+         precision_metric.update(predictions=predictions, labels=labels)
+         recall_metric.update(predictions=predictions, labels=labels)
+
+         # Should handle gracefully without division by zero
+         precision = precision_metric.compute()
+         recall = recall_metric.compute()
+
+         self.assertEqual(float(precision), 0.0)
+         self.assertEqual(float(recall), 0.0)
+
+
+ if __name__ == '__main__':
+     absltest.main()
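
For context on the file above: the sketch below reconstructs, purely from calls that appear in _metrics_test.py itself (bst.nn.MultiMetric, bst.nn.AccuracyMetric, bst.nn.AverageMetric, and their update/compute/reset methods), how the metrics API appears intended to be used across batches. The loop structure and batch data are illustrative assumptions, not code from the package.

import jax
import jax.numpy as jnp

import brainstate as bst

# Collect several metrics behind one update()/compute()/reset() interface,
# mirroring test_multimetric_update above.
metrics = bst.nn.MultiMetric(
    accuracy=bst.nn.AccuracyMetric(),
    loss=bst.nn.AverageMetric(),
)

for step in range(3):
    # Stand-ins for model outputs and per-example losses on one batch
    # (hypothetical data, chosen only to exercise the API).
    logits = jax.random.normal(jax.random.key(step), (5, 2))
    labels = jnp.argmax(logits, axis=-1)
    batch_loss = jnp.abs(jax.random.normal(jax.random.key(100 + step), (5,)))

    # Keyword names route values to the matching metric: AccuracyMetric
    # consumes logits/labels, AverageMetric consumes values.
    metrics.update(logits=logits, labels=labels, values=batch_loss)

results = metrics.compute()  # e.g. {'accuracy': Array(...), 'loss': Array(...)}
metrics.reset()              # clear accumulated state before the next epoch

The standalone metrics (WelfordMetric, PrecisionMetric, RecallMetric, F1ScoreMetric, ConfusionMatrix) follow the same reset/update/compute protocol, which the tests pin down via the NotImplementedError checks on the base Metric class.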