brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,611 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Comprehensive tests for metrics module."""
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from absl.testing import absltest, parameterized
21
+
22
+ import brainstate as bst
23
+
24
+
25
class MetricStateTest(absltest.TestCase):
    """Unit tests covering the MetricState container."""

    def test_metric_state_creation(self):
        """A freshly constructed MetricState exposes its initial value."""
        s = bst.nn.MetricState(jnp.array(0.0))
        self.assertEqual(s.value, 0.0)

    def test_metric_state_update(self):
        """Assigning to .value replaces the stored array."""
        s = bst.nn.MetricState(jnp.array(0.0))
        s.value = jnp.array(5.5)
        self.assertAlmostEqual(float(s.value), 5.5)

    def test_metric_state_module_attribute(self):
        """MetricState advertises brainstate.nn as its home module."""
        s = bst.nn.MetricState(jnp.array(0.0))
        self.assertEqual(bst.nn.MetricState.__module__, "brainstate.nn")
43
+
44
+
45
class AverageMetricTest(parameterized.TestCase):
    """Unit tests covering AverageMetric behaviour."""

    def test_average_metric_initial_state(self):
        """Before any update the running average is NaN."""
        m = bst.nn.AverageMetric()
        self.assertTrue(jnp.isnan(m.compute()))

    def test_average_metric_single_update(self):
        """One batch of values yields its plain mean."""
        m = bst.nn.AverageMetric()
        m.update(values=jnp.array([1, 2, 3, 4]))
        self.assertAlmostEqual(float(m.compute()), 2.5)

    def test_average_metric_multiple_updates(self):
        """Successive batches are pooled into one overall mean."""
        m = bst.nn.AverageMetric()
        for batch in (jnp.array([1, 2, 3, 4]), jnp.array([3, 2, 1, 0])):
            m.update(values=batch)
        self.assertAlmostEqual(float(m.compute()), 2.0)

    def test_average_metric_scalar_values(self):
        """Plain Python scalars are accepted as update values."""
        m = bst.nn.AverageMetric()
        m.update(values=5.0)
        m.update(values=3.0)
        self.assertAlmostEqual(float(m.compute()), 4.0)

    def test_average_metric_reset(self):
        """reset() discards accumulated data, restoring the NaN state."""
        m = bst.nn.AverageMetric()
        m.update(values=jnp.array([1, 2, 3, 4]))
        m.reset()
        self.assertTrue(jnp.isnan(m.compute()))

    def test_average_metric_custom_argname(self):
        """The constructor argument renames the expected update keyword."""
        m = bst.nn.AverageMetric('loss')
        m.update(loss=jnp.array([1.0, 2.0, 3.0]))
        self.assertAlmostEqual(float(m.compute()), 2.0)

    def test_average_metric_missing_argname(self):
        """Updating with the wrong keyword raises TypeError."""
        m = bst.nn.AverageMetric('loss')
        with self.assertRaises(TypeError):
            m.update(values=jnp.array([1.0, 2.0]))

    def test_average_metric_module_attribute(self):
        """AverageMetric advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.AverageMetric.__module__, "brainstate.nn")
101
+
102
+
103
class WelfordMetricTest(parameterized.TestCase):
    """Unit tests covering WelfordMetric's running statistics."""

    def test_welford_metric_initial_state(self):
        """With no data: mean is 0 and the spread statistics are NaN."""
        m = bst.nn.WelfordMetric()
        stats = m.compute()
        self.assertEqual(float(stats.mean), 0.0)
        self.assertTrue(jnp.isnan(stats.standard_error_of_mean))
        self.assertTrue(jnp.isnan(stats.standard_deviation))

    def test_welford_metric_single_update(self):
        """One batch gives its mean and population standard deviation."""
        m = bst.nn.WelfordMetric()
        m.update(values=jnp.array([1, 2, 3, 4]))
        stats = m.compute()
        self.assertAlmostEqual(float(stats.mean), 2.5, places=5)
        # sqrt(1.25) ~= 1.118 is the population std of [1, 2, 3, 4].
        self.assertAlmostEqual(float(stats.standard_deviation), 1.118, places=3)

    def test_welford_metric_multiple_updates(self):
        """Statistics pool across batches; mean of all eight values is 2."""
        m = bst.nn.WelfordMetric()
        m.update(values=jnp.array([1, 2, 3, 4]))
        m.update(values=jnp.array([3, 2, 1, 0]))
        self.assertAlmostEqual(float(m.compute().mean), 2.0, places=5)

    def test_welford_metric_reset(self):
        """reset() returns the metric to its pristine state."""
        m = bst.nn.WelfordMetric()
        m.update(values=jnp.array([1, 2, 3, 4]))
        m.reset()
        stats = m.compute()
        self.assertEqual(float(stats.mean), 0.0)
        self.assertTrue(jnp.isnan(stats.standard_deviation))

    def test_welford_metric_custom_argname(self):
        """The constructor argument renames the expected update keyword."""
        m = bst.nn.WelfordMetric('data')
        m.update(data=jnp.array([1.0, 2.0, 3.0]))
        self.assertAlmostEqual(float(m.compute().mean), 2.0, places=5)

    def test_welford_metric_module_attribute(self):
        """WelfordMetric advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.WelfordMetric.__module__, "brainstate.nn")
151
+
152
+
153
class AccuracyMetricTest(parameterized.TestCase):
    """Unit tests covering AccuracyMetric."""

    def test_accuracy_metric_initial_state(self):
        """Before any update the accuracy is NaN."""
        m = bst.nn.AccuracyMetric()
        self.assertTrue(jnp.isnan(m.compute()))

    def test_accuracy_metric_perfect_accuracy(self):
        """All argmax predictions matching the labels gives accuracy 1."""
        m = bst.nn.AccuracyMetric()
        m.update(
            logits=jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]),
            labels=jnp.array([1, 0, 1]),
        )
        self.assertAlmostEqual(float(m.compute()), 1.0)

    def test_accuracy_metric_partial_accuracy(self):
        """Two of three correct predictions gives accuracy 2/3."""
        m = bst.nn.AccuracyMetric()
        m.update(
            logits=jnp.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]]),
            labels=jnp.array([1, 0, 1]),  # first sample is misclassified
        )
        self.assertAlmostEqual(float(m.compute()), 2.0 / 3.0, places=5)

    def test_accuracy_metric_multiple_updates(self):
        """Accuracy pooled over random batches stays within [0, 1]."""
        m = bst.nn.AccuracyMetric()
        m.update(
            logits=jax.random.normal(jax.random.key(0), (5, 2)),
            labels=jnp.array([1, 1, 0, 1, 0]),
        )
        m.update(
            logits=jax.random.normal(jax.random.key(1), (5, 2)),
            labels=jnp.array([0, 1, 1, 1, 1]),
        )
        acc = float(m.compute())
        self.assertGreaterEqual(acc, 0.0)
        self.assertLessEqual(acc, 1.0)

    def test_accuracy_metric_reset(self):
        """reset() discards counts, restoring the NaN state."""
        m = bst.nn.AccuracyMetric()
        m.update(
            logits=jnp.array([[0.1, 0.9], [0.8, 0.2]]),
            labels=jnp.array([1, 0]),
        )
        m.reset()
        self.assertTrue(jnp.isnan(m.compute()))

    def test_accuracy_metric_shape_validation(self):
        """Labels whose rank does not match the logits are rejected."""
        m = bst.nn.AccuracyMetric()
        with self.assertRaises(ValueError):
            m.update(
                logits=jnp.array([[0.1, 0.9]]),
                labels=jnp.array([[1, 0]]),  # 2-D labels: wrong shape
            )

    def test_accuracy_metric_module_attribute(self):
        """AccuracyMetric advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.AccuracyMetric.__module__, "brainstate.nn")
216
+
217
+
218
class PrecisionMetricTest(parameterized.TestCase):
    """Unit tests covering PrecisionMetric."""

    def test_precision_metric_binary_perfect(self):
        """Identical predictions and labels give precision 1."""
        m = bst.nn.PrecisionMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 1, 1, 0]),
        )
        self.assertAlmostEqual(float(m.compute()), 1.0)

    def test_precision_metric_binary_partial(self):
        """TP=2 and FP=1 gives precision 2/3."""
        m = bst.nn.PrecisionMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 0, 1, 0]),
        )
        self.assertAlmostEqual(float(m.compute()), 2.0 / 3.0, places=5)

    def test_precision_metric_multiclass_macro(self):
        """Macro-averaged multi-class precision stays within [0, 1]."""
        m = bst.nn.PrecisionMetric(num_classes=3, average='macro')
        m.update(
            predictions=jnp.array([0, 1, 2, 1, 0]),
            labels=jnp.array([0, 1, 1, 1, 2]),
        )
        got = float(m.compute())
        self.assertGreaterEqual(got, 0.0)
        self.assertLessEqual(got, 1.0)

    def test_precision_metric_multiclass_micro(self):
        """Micro-averaged multi-class precision stays within [0, 1]."""
        m = bst.nn.PrecisionMetric(num_classes=3, average='micro')
        m.update(
            predictions=jnp.array([0, 1, 2, 1, 0]),
            labels=jnp.array([0, 1, 1, 1, 2]),
        )
        got = float(m.compute())
        self.assertGreaterEqual(got, 0.0)
        self.assertLessEqual(got, 1.0)

    def test_precision_metric_reset(self):
        """After reset(), with no data, precision falls back to 0."""
        m = bst.nn.PrecisionMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 0, 1, 0]),
        )
        m.reset()
        self.assertEqual(float(m.compute()), 0.0)

    def test_precision_metric_module_attribute(self):
        """PrecisionMetric advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.PrecisionMetric.__module__, "brainstate.nn")
275
+
276
+
277
class RecallMetricTest(parameterized.TestCase):
    """Unit tests covering RecallMetric."""

    def test_recall_metric_binary_perfect(self):
        """Identical predictions and labels give recall 1."""
        m = bst.nn.RecallMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 1, 1, 0]),
        )
        self.assertAlmostEqual(float(m.compute()), 1.0)

    def test_recall_metric_binary_partial(self):
        """TP=2 with FN=0 still gives recall 1 despite one false positive."""
        m = bst.nn.RecallMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 0, 1, 0]),
        )
        self.assertAlmostEqual(float(m.compute()), 1.0, places=5)

    def test_recall_metric_multiclass_macro(self):
        """Macro-averaged multi-class recall stays within [0, 1]."""
        m = bst.nn.RecallMetric(num_classes=3, average='macro')
        m.update(
            predictions=jnp.array([0, 1, 2, 1, 0]),
            labels=jnp.array([0, 1, 1, 1, 2]),
        )
        got = float(m.compute())
        self.assertGreaterEqual(got, 0.0)
        self.assertLessEqual(got, 1.0)

    def test_recall_metric_multiclass_micro(self):
        """Micro-averaged multi-class recall stays within [0, 1]."""
        m = bst.nn.RecallMetric(num_classes=3, average='micro')
        m.update(
            predictions=jnp.array([0, 1, 2, 1, 0]),
            labels=jnp.array([0, 1, 1, 1, 2]),
        )
        got = float(m.compute())
        self.assertGreaterEqual(got, 0.0)
        self.assertLessEqual(got, 1.0)

    def test_recall_metric_reset(self):
        """After reset(), with no data, recall falls back to 0."""
        m = bst.nn.RecallMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 0, 1, 0]),
        )
        m.reset()
        self.assertEqual(float(m.compute()), 0.0)

    def test_recall_metric_module_attribute(self):
        """RecallMetric advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.RecallMetric.__module__, "brainstate.nn")
332
+
333
+
334
class F1ScoreMetricTest(parameterized.TestCase):
    """Unit tests covering F1ScoreMetric."""

    def test_f1_score_binary_perfect(self):
        """Identical predictions and labels give an F1 of 1."""
        m = bst.nn.F1ScoreMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 1, 1, 0]),
        )
        self.assertAlmostEqual(float(m.compute()), 1.0)

    def test_f1_score_binary_partial(self):
        """Precision 2/3 with recall 1 gives F1 = 0.8."""
        m = bst.nn.F1ScoreMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 0, 1, 0]),
        )
        # F1 = 2 * (2/3 * 1) / (2/3 + 1) = 0.8
        self.assertAlmostEqual(float(m.compute()), 0.8, places=5)

    def test_f1_score_multiclass(self):
        """Macro-averaged multi-class F1 stays within [0, 1]."""
        m = bst.nn.F1ScoreMetric(num_classes=3, average='macro')
        m.update(
            predictions=jnp.array([0, 1, 2, 1, 0]),
            labels=jnp.array([0, 1, 1, 1, 2]),
        )
        got = float(m.compute())
        self.assertGreaterEqual(got, 0.0)
        self.assertLessEqual(got, 1.0)

    def test_f1_score_reset(self):
        """After reset(), with no data, the F1 score falls back to 0."""
        m = bst.nn.F1ScoreMetric()
        m.update(
            predictions=jnp.array([1, 0, 1, 1, 0]),
            labels=jnp.array([1, 0, 0, 1, 0]),
        )
        m.reset()
        self.assertEqual(float(m.compute()), 0.0)

    def test_f1_score_module_attribute(self):
        """F1ScoreMetric advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.F1ScoreMetric.__module__, "brainstate.nn")
379
+
380
+
381
class ConfusionMatrixTest(parameterized.TestCase):
    """Unit tests covering ConfusionMatrix."""

    def test_confusion_matrix_basic(self):
        """Counts land in the (true label, predicted label) cells."""
        m = bst.nn.ConfusionMatrix(num_classes=3)
        m.update(
            predictions=jnp.array([0, 1, 2, 1, 0]),
            labels=jnp.array([0, 1, 1, 1, 2]),
        )
        cm = m.compute()
        self.assertEqual(cm.shape, (3, 3))
        # Label 0 predicted as 0 once; label 1 predicted as 1 twice.
        self.assertEqual(int(cm[0, 0]), 1)
        self.assertEqual(int(cm[1, 1]), 2)

    def test_confusion_matrix_perfect(self):
        """Perfect predictions yield a purely diagonal matrix."""
        m = bst.nn.ConfusionMatrix(num_classes=2)
        m.update(
            predictions=jnp.array([0, 0, 1, 1]),
            labels=jnp.array([0, 0, 1, 1]),
        )
        cm = m.compute()
        self.assertEqual(int(cm[0, 0]), 2)
        self.assertEqual(int(cm[0, 1]), 0)
        self.assertEqual(int(cm[1, 0]), 0)
        self.assertEqual(int(cm[1, 1]), 2)

    def test_confusion_matrix_reset(self):
        """reset() zeroes every cell of the matrix."""
        m = bst.nn.ConfusionMatrix(num_classes=2)
        m.update(predictions=jnp.array([0, 1]), labels=jnp.array([0, 1]))
        m.reset()
        self.assertTrue(jnp.all(m.compute() == 0))

    def test_confusion_matrix_invalid_predictions(self):
        """Predicted class ids outside [0, num_classes) are rejected."""
        m = bst.nn.ConfusionMatrix(num_classes=2)
        with self.assertRaises(ValueError):
            m.update(
                predictions=jnp.array([0, 1, 2]),  # 2 exceeds num_classes - 1
                labels=jnp.array([0, 1, 1]),
            )

    def test_confusion_matrix_invalid_labels(self):
        """Label ids outside [0, num_classes) are rejected."""
        m = bst.nn.ConfusionMatrix(num_classes=2)
        with self.assertRaises(ValueError):
            m.update(
                predictions=jnp.array([0, 1, 1]),
                labels=jnp.array([0, 1, 3]),  # 3 exceeds num_classes - 1
            )

    def test_confusion_matrix_module_attribute(self):
        """ConfusionMatrix advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.ConfusionMatrix.__module__, "brainstate.nn")
447
+
448
+
449
class MultiMetricTest(parameterized.TestCase):
    """Test cases for MultiMetric class."""

    def test_multimetric_creation(self):
        """A MultiMetric exposes each wrapped metric as an attribute."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )
        self.assertIsNotNone(metrics.accuracy)
        self.assertIsNotNone(metrics.loss)

    def test_multimetric_compute(self):
        """compute() on a pristine MultiMetric reports NaN for each metric.

        No update() is performed here on purpose: the test pins the
        initial state.  (The previous version built logits/labels/loss
        fixtures that were never used; those dead locals are removed.)
        """
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )
        result = metrics.compute()
        self.assertIn('accuracy', result)
        self.assertIn('loss', result)
        self.assertTrue(jnp.isnan(result['accuracy']))
        self.assertTrue(jnp.isnan(result['loss']))

    def test_multimetric_update(self):
        """update() forwards keyword data to every wrapped metric."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )

        logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
        labels = jnp.array([1, 0, 1])
        batch_loss = jnp.array([1, 2, 3])

        metrics.update(logits=logits, labels=labels, values=batch_loss)
        result = metrics.compute()

        # Accuracy must be a valid proportion; the loss average is exact.
        self.assertGreaterEqual(float(result['accuracy']), 0.0)
        self.assertLessEqual(float(result['accuracy']), 1.0)
        self.assertAlmostEqual(float(result['loss']), 2.0)

    def test_multimetric_reset(self):
        """reset() restores every wrapped metric to its NaN state."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )

        logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
        labels = jnp.array([1, 0])
        batch_loss = jnp.array([1, 2])

        metrics.update(logits=logits, labels=labels, values=batch_loss)
        metrics.reset()
        result = metrics.compute()

        self.assertTrue(jnp.isnan(result['accuracy']))
        self.assertTrue(jnp.isnan(result['loss']))

    def test_multimetric_reserved_name_validation(self):
        """Metric names colliding with MultiMetric's own API are rejected."""
        # Each reserved method name must raise on its own.
        for reserved in ('reset', 'update', 'compute'):
            with self.assertRaises(ValueError):
                bst.nn.MultiMetric(**{reserved: bst.nn.AverageMetric()})

    def test_multimetric_type_validation(self):
        """Every keyword argument must be a Metric instance."""
        with self.assertRaises(TypeError):
            bst.nn.MultiMetric(
                accuracy=bst.nn.AccuracyMetric(),
                invalid="not a metric",
            )

    def test_multimetric_module_attribute(self):
        """MultiMetric advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.MultiMetric.__module__, "brainstate.nn")
542
+
543
+
544
class MetricBaseClassTest(absltest.TestCase):
    """Unit tests covering the abstract Metric base class."""

    def test_metric_not_implemented_errors(self):
        """reset/update/compute on the bare base class all raise."""
        base = bst.nn.Metric()
        for method in (base.reset, base.update, base.compute):
            with self.assertRaises(NotImplementedError):
                method()

    def test_metric_module_attribute(self):
        """Metric advertises brainstate.nn as its home module."""
        self.assertEqual(bst.nn.Metric.__module__, "brainstate.nn")
563
+
564
+
565
class EdgeCasesTest(parameterized.TestCase):
    """Boundary-condition tests across several metrics."""

    def test_average_metric_large_numbers(self):
        """Averaging values around 1e10 keeps adequate precision."""
        m = bst.nn.AverageMetric()
        m.update(values=jnp.array([1e10, 2e10, 3e10]))
        self.assertAlmostEqual(float(m.compute()), 2e10, places=-5)

    def test_average_metric_small_numbers(self):
        """Averaging values around 1e-10 keeps adequate precision."""
        m = bst.nn.AverageMetric()
        m.update(values=jnp.array([1e-10, 2e-10, 3e-10]))
        self.assertAlmostEqual(float(m.compute()), 2e-10, places=15)

    def test_confusion_matrix_single_class(self):
        """A one-class confusion matrix accumulates in its only cell."""
        m = bst.nn.ConfusionMatrix(num_classes=1)
        m.update(predictions=jnp.array([0, 0, 0]), labels=jnp.array([0, 0, 0]))
        self.assertEqual(int(m.compute()[0, 0]), 3)

    def test_precision_recall_no_positives(self):
        """All-negative data must not trigger a division by zero."""
        preds = jnp.array([0, 0, 0, 0, 0])
        labels = jnp.array([0, 0, 0, 0, 0])

        p_metric = bst.nn.PrecisionMetric()
        r_metric = bst.nn.RecallMetric()
        p_metric.update(predictions=preds, labels=labels)
        r_metric.update(predictions=preds, labels=labels)

        # With zero positives both metrics should degrade gracefully to 0.
        self.assertEqual(float(p_metric.compute()), 0.0)
        self.assertEqual(float(r_metric.compute()), 0.0)
608
+
609
+
610
if __name__ == '__main__':
    # Delegate test discovery and execution to absl's runner.
    absltest.main()