brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_metrics_test.py ADDED
@@ -0,0 +1,611 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Comprehensive tests for metrics module."""
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from absl.testing import absltest, parameterized
21
+
22
+ import brainstate as bst
23
+
24
+
25
class MetricStateTest(absltest.TestCase):
    """Test cases for the ``MetricState`` state container."""

    def test_metric_state_creation(self):
        """A freshly created MetricState exposes the value it was constructed with."""
        state = bst.nn.MetricState(jnp.array(0.0))
        # Convert to a Python float so the comparison does not rely on the
        # truthiness of a 0-d array produced by ``==``.
        self.assertEqual(float(state.value), 0.0)

    def test_metric_state_update(self):
        """Assigning to ``.value`` replaces the stored value."""
        state = bst.nn.MetricState(jnp.array(0.0))
        state.value = jnp.array(5.5)
        self.assertAlmostEqual(float(state.value), 5.5)

    def test_metric_state_module_attribute(self):
        """MetricState reports ``brainstate.nn`` as its ``__module__``."""
        # ``__module__`` is a class attribute, so no instance is needed here.
        self.assertEqual(bst.nn.MetricState.__module__, "brainstate.nn")
43
+
44
+
45
class AverageMetricTest(parameterized.TestCase):
    """Behavioural tests for ``AverageMetric``, the running average of fed values."""

    @staticmethod
    def _running_mean(*batches):
        """Feed each batch to a default ``AverageMetric`` and return ``compute()``."""
        tracker = bst.nn.AverageMetric()
        for batch in batches:
            tracker.update(values=batch)
        return tracker.compute()

    def test_average_metric_initial_state(self):
        """With no data seen, the reported average is NaN."""
        self.assertTrue(jnp.isnan(self._running_mean()))

    def test_average_metric_single_update(self):
        """One batch: the average is the plain mean of that batch."""
        mean_value = self._running_mean(jnp.array([1, 2, 3, 4]))
        self.assertAlmostEqual(float(mean_value), 2.5)

    def test_average_metric_multiple_updates(self):
        """Two batches: the average covers every value seen so far."""
        mean_value = self._running_mean(
            jnp.array([1, 2, 3, 4]),
            jnp.array([3, 2, 1, 0]),
        )
        self.assertAlmostEqual(float(mean_value), 2.0)

    def test_average_metric_scalar_values(self):
        """Plain Python scalars are accepted as updates."""
        mean_value = self._running_mean(5.0, 3.0)
        self.assertAlmostEqual(float(mean_value), 4.0)

    def test_average_metric_reset(self):
        """After ``reset()`` the metric behaves as if freshly created."""
        tracker = bst.nn.AverageMetric()
        tracker.update(values=jnp.array([1, 2, 3, 4]))
        tracker.reset()
        self.assertTrue(jnp.isnan(tracker.compute()))

    def test_average_metric_custom_argname(self):
        """A custom argument name replaces the default ``values`` keyword."""
        loss_tracker = bst.nn.AverageMetric('loss')
        loss_tracker.update(loss=jnp.array([1.0, 2.0, 3.0]))
        self.assertAlmostEqual(float(loss_tracker.compute()), 2.0)

    def test_average_metric_missing_argname(self):
        """Updating without the configured keyword raises ``TypeError``."""
        loss_tracker = bst.nn.AverageMetric('loss')
        with self.assertRaises(TypeError):
            loss_tracker.update(values=jnp.array([1.0, 2.0]))

    def test_average_metric_module_attribute(self):
        """AverageMetric reports ``brainstate.nn`` as its ``__module__``."""
        expected_module = "brainstate.nn"
        self.assertEqual(bst.nn.AverageMetric.__module__, expected_module)
101
+
102
+
103
class WelfordMetricTest(parameterized.TestCase):
    """Behavioural tests for ``WelfordMetric`` and the statistics it reports."""

    def test_welford_metric_initial_state(self):
        """Before any update: mean is 0 and the spread statistics are NaN."""
        summary = bst.nn.WelfordMetric().compute()
        self.assertEqual(float(summary.mean), 0.0)
        for undefined_stat in (summary.standard_error_of_mean,
                               summary.standard_deviation):
            self.assertTrue(jnp.isnan(undefined_stat))

    def test_welford_metric_single_update(self):
        """One batch of [1, 2, 3, 4]: mean 2.5, population std ~1.118."""
        welford = bst.nn.WelfordMetric()
        welford.update(values=jnp.array([1, 2, 3, 4]))
        summary = welford.compute()
        self.assertAlmostEqual(float(summary.mean), 2.5, places=5)
        # sqrt(1.25) ~= 1.118 is the population standard deviation of [1, 2, 3, 4].
        self.assertAlmostEqual(float(summary.standard_deviation), 1.118, places=3)

    def test_welford_metric_multiple_updates(self):
        """The mean spans all batches: [1,2,3,4] + [3,2,1,0] averages to 2.0."""
        welford = bst.nn.WelfordMetric()
        for chunk in (jnp.array([1, 2, 3, 4]), jnp.array([3, 2, 1, 0])):
            welford.update(values=chunk)
        summary = welford.compute()
        self.assertAlmostEqual(float(summary.mean), 2.0, places=5)

    def test_welford_metric_reset(self):
        """``reset()`` restores the initial statistics."""
        welford = bst.nn.WelfordMetric()
        welford.update(values=jnp.array([1, 2, 3, 4]))
        welford.reset()
        summary = welford.compute()
        self.assertEqual(float(summary.mean), 0.0)
        self.assertTrue(jnp.isnan(summary.standard_deviation))

    def test_welford_metric_custom_argname(self):
        """A custom keyword name is honoured by ``update``."""
        welford = bst.nn.WelfordMetric('data')
        welford.update(data=jnp.array([1.0, 2.0, 3.0]))
        summary = welford.compute()
        self.assertAlmostEqual(float(summary.mean), 2.0, places=5)

    def test_welford_metric_module_attribute(self):
        """WelfordMetric reports ``brainstate.nn`` as its ``__module__``."""
        expected_module = "brainstate.nn"
        self.assertEqual(bst.nn.WelfordMetric.__module__, expected_module)
151
+
152
+
153
class AccuracyMetricTest(parameterized.TestCase):
    """Test cases for AccuracyMetric class."""

    def test_accuracy_metric_initial_state(self):
        """Accuracy is NaN before any batch has been seen."""
        metric = bst.nn.AccuracyMetric()
        result = metric.compute()
        self.assertTrue(jnp.isnan(result))

    def test_accuracy_metric_perfect_accuracy(self):
        """Every row's largest logit matches its label, so accuracy is 1."""
        metric = bst.nn.AccuracyMetric()
        logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
        labels = jnp.array([1, 0, 1])
        metric.update(logits=logits, labels=labels)
        result = metric.compute()
        self.assertAlmostEqual(float(result), 1.0)

    def test_accuracy_metric_partial_accuracy(self):
        """One mismatching row out of three gives accuracy 2/3."""
        metric = bst.nn.AccuracyMetric()
        logits = jnp.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]])
        labels = jnp.array([1, 0, 1])  # Row 0's largest logit is class 0, so it is wrong.
        metric.update(logits=logits, labels=labels)
        result = metric.compute()
        self.assertAlmostEqual(float(result), 2.0 / 3.0, places=5)

    def test_accuracy_metric_multiple_updates(self):
        """Accuracy accumulated over two batches matches an argmax reference."""
        metric = bst.nn.AccuracyMetric()
        logits1 = jax.random.normal(jax.random.key(0), (5, 2))
        labels1 = jnp.array([1, 1, 0, 1, 0])
        logits2 = jax.random.normal(jax.random.key(1), (5, 2))
        labels2 = jnp.array([0, 1, 1, 1, 1])

        metric.update(logits=logits1, labels=labels1)
        metric.update(logits=logits2, labels=labels2)
        result = float(metric.compute())

        # Reference: fraction of rows whose argmax equals the label, over all
        # 10 samples. The perfect/partial tests above pin these argmax
        # semantics. Both batches have the same size, so this equals the mean
        # of the per-batch accuracies as well.
        hits = jnp.concatenate([
            jnp.argmax(logits1, axis=-1) == labels1,
            jnp.argmax(logits2, axis=-1) == labels2,
        ])
        expected = float(jnp.mean(hits))

        self.assertGreaterEqual(result, 0.0)
        self.assertLessEqual(result, 1.0)
        self.assertAlmostEqual(result, expected, places=5)

    def test_accuracy_metric_reset(self):
        """``reset()`` returns the metric to its NaN initial state."""
        metric = bst.nn.AccuracyMetric()
        logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
        labels = jnp.array([1, 0])
        metric.update(logits=logits, labels=labels)
        metric.reset()
        result = metric.compute()
        self.assertTrue(jnp.isnan(result))

    def test_accuracy_metric_shape_validation(self):
        """Labels shaped like the logits (instead of one per row) raise ValueError."""
        metric = bst.nn.AccuracyMetric()
        logits = jnp.array([[0.1, 0.9]])
        labels = jnp.array([[1, 0]])  # Wrong shape
        with self.assertRaises(ValueError):
            metric.update(logits=logits, labels=labels)

    def test_accuracy_metric_module_attribute(self):
        """AccuracyMetric reports ``brainstate.nn`` as its ``__module__``."""
        self.assertEqual(bst.nn.AccuracyMetric.__module__, "brainstate.nn")
216
+
217
+
218
class PrecisionMetricTest(parameterized.TestCase):
    """Behavioural tests for ``PrecisionMetric`` (binary and multi-class)."""

    # Binary predictions shared by several tests.
    _BINARY_PREDS = (1, 0, 1, 1, 0)
    # Differs from _BINARY_PREDS only at index 2, i.e. one false positive.
    _BINARY_LABELS_ONE_FP = (1, 0, 0, 1, 0)
    # Three-class fixture for the averaging tests.
    _MULTI_PREDS = (0, 1, 2, 1, 0)
    _MULTI_LABELS = (0, 1, 1, 1, 2)

    def _score(self, scorer, preds, labels):
        """Push one batch into ``scorer`` and return ``scorer.compute()``."""
        scorer.update(predictions=jnp.array(preds), labels=jnp.array(labels))
        return scorer.compute()

    def _assert_unit_interval(self, value):
        """Assert ``0 <= value <= 1``."""
        self.assertGreaterEqual(float(value), 0.0)
        self.assertLessEqual(float(value), 1.0)

    def test_precision_metric_binary_perfect(self):
        """Identical predictions and labels give precision 1."""
        value = self._score(bst.nn.PrecisionMetric(),
                            self._BINARY_PREDS, self._BINARY_PREDS)
        self.assertAlmostEqual(float(value), 1.0)

    def test_precision_metric_binary_partial(self):
        """TP=2 and FP=1 give precision 2/3."""
        value = self._score(bst.nn.PrecisionMetric(),
                            self._BINARY_PREDS, self._BINARY_LABELS_ONE_FP)
        self.assertAlmostEqual(float(value), 2.0 / 3.0, places=5)

    def test_precision_metric_multiclass_macro(self):
        """Macro-averaged precision over three classes lies in [0, 1]."""
        scorer = bst.nn.PrecisionMetric(num_classes=3, average='macro')
        self._assert_unit_interval(
            self._score(scorer, self._MULTI_PREDS, self._MULTI_LABELS))

    def test_precision_metric_multiclass_micro(self):
        """Micro-averaged precision over three classes lies in [0, 1]."""
        scorer = bst.nn.PrecisionMetric(num_classes=3, average='micro')
        self._assert_unit_interval(
            self._score(scorer, self._MULTI_PREDS, self._MULTI_LABELS))

    def test_precision_metric_reset(self):
        """After ``reset()`` with no new data, precision reads 0."""
        scorer = bst.nn.PrecisionMetric()
        scorer.update(predictions=jnp.array(self._BINARY_PREDS),
                      labels=jnp.array(self._BINARY_LABELS_ONE_FP))
        scorer.reset()
        self.assertEqual(float(scorer.compute()), 0.0)

    def test_precision_metric_module_attribute(self):
        """PrecisionMetric reports ``brainstate.nn`` as its ``__module__``."""
        expected_module = "brainstate.nn"
        self.assertEqual(bst.nn.PrecisionMetric.__module__, expected_module)
275
+
276
+
277
class RecallMetricTest(parameterized.TestCase):
    """Behavioural tests for ``RecallMetric`` in binary and multi-class modes."""

    def _recall_of(self, scorer, predicted, actual):
        """Update ``scorer`` once with the given sequences and return its result."""
        scorer.update(predictions=jnp.array(predicted), labels=jnp.array(actual))
        return scorer.compute()

    def _check_in_unit_interval(self, score):
        """Assert ``0 <= score <= 1``."""
        self.assertGreaterEqual(float(score), 0.0)
        self.assertLessEqual(float(score), 1.0)

    def test_recall_metric_binary_perfect(self):
        """Recall is 1 when every prediction matches its label."""
        matching = [1, 0, 1, 1, 0]
        score = self._recall_of(bst.nn.RecallMetric(), matching, matching)
        self.assertAlmostEqual(float(score), 1.0)

    def test_recall_metric_binary_partial(self):
        """A false positive does not lower recall."""
        # The positive labels sit at indices 0 and 3 and both are predicted:
        # TP=2, FN=0, so recall = 2/2 = 1.0.
        score = self._recall_of(bst.nn.RecallMetric(),
                                [1, 0, 1, 1, 0], [1, 0, 0, 1, 0])
        self.assertAlmostEqual(float(score), 1.0, places=5)

    def test_recall_metric_multiclass_macro(self):
        """Macro-averaged recall over three classes lies in [0, 1]."""
        scorer = bst.nn.RecallMetric(num_classes=3, average='macro')
        self._check_in_unit_interval(
            self._recall_of(scorer, [0, 1, 2, 1, 0], [0, 1, 1, 1, 2]))

    def test_recall_metric_multiclass_micro(self):
        """Micro-averaged recall over three classes lies in [0, 1]."""
        scorer = bst.nn.RecallMetric(num_classes=3, average='micro')
        self._check_in_unit_interval(
            self._recall_of(scorer, [0, 1, 2, 1, 0], [0, 1, 1, 1, 2]))

    def test_recall_metric_reset(self):
        """After ``reset()`` with no new data, recall reads 0."""
        scorer = bst.nn.RecallMetric()
        scorer.update(predictions=jnp.array([1, 0, 1, 1, 0]),
                      labels=jnp.array([1, 0, 0, 1, 0]))
        scorer.reset()
        self.assertEqual(float(scorer.compute()), 0.0)

    def test_recall_metric_module_attribute(self):
        """RecallMetric reports ``brainstate.nn`` as its ``__module__``."""
        expected_module = "brainstate.nn"
        self.assertEqual(bst.nn.RecallMetric.__module__, expected_module)
332
+
333
+
334
class F1ScoreMetricTest(parameterized.TestCase):
    """Behavioural tests for ``F1ScoreMetric``."""

    @staticmethod
    def _f1(scorer, guesses, truth):
        """Feed one batch to ``scorer`` and return the resulting F1 value."""
        scorer.update(predictions=jnp.array(guesses), labels=jnp.array(truth))
        return scorer.compute()

    def test_f1_score_binary_perfect(self):
        """Perfect binary predictions give F1 = 1."""
        sequence = [1, 0, 1, 1, 0]
        f1_value = self._f1(bst.nn.F1ScoreMetric(), sequence, sequence)
        self.assertAlmostEqual(float(f1_value), 1.0)

    def test_f1_score_binary_partial(self):
        """One false positive: P=2/3, R=1, so F1 = 2*P*R/(P+R) = 0.8."""
        f1_value = self._f1(bst.nn.F1ScoreMetric(),
                            [1, 0, 1, 1, 0], [1, 0, 0, 1, 0])
        self.assertAlmostEqual(float(f1_value), 0.8, places=5)

    def test_f1_score_multiclass(self):
        """Macro-averaged F1 over three classes lies in [0, 1]."""
        scorer = bst.nn.F1ScoreMetric(num_classes=3, average='macro')
        f1_value = self._f1(scorer, [0, 1, 2, 1, 0], [0, 1, 1, 1, 2])
        self.assertGreaterEqual(float(f1_value), 0.0)
        self.assertLessEqual(float(f1_value), 1.0)

    def test_f1_score_reset(self):
        """After ``reset()`` with no new data, F1 reads 0."""
        scorer = bst.nn.F1ScoreMetric()
        scorer.update(predictions=jnp.array([1, 0, 1, 1, 0]),
                      labels=jnp.array([1, 0, 0, 1, 0]))
        scorer.reset()
        self.assertEqual(float(scorer.compute()), 0.0)

    def test_f1_score_module_attribute(self):
        """F1ScoreMetric reports ``brainstate.nn`` as its ``__module__``."""
        expected_module = "brainstate.nn"
        self.assertEqual(bst.nn.F1ScoreMetric.__module__, expected_module)
379
+
380
+
381
class ConfusionMatrixTest(parameterized.TestCase):
    """Test cases for ConfusionMatrix class."""

    def test_confusion_matrix_basic(self):
        """Counts for a small three-class batch land in the expected cells."""
        metric = bst.nn.ConfusionMatrix(num_classes=3)
        predictions = jnp.array([0, 1, 2, 1, 0])
        labels = jnp.array([0, 1, 1, 1, 2])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()

        # Check shape
        self.assertEqual(result.shape, (3, 3))

        # (label, prediction) pairs in this batch:
        #   (0, 0) x1, (1, 1) x2, (1, 2) x1, (2, 0) x1
        # Diagonal cells do not depend on the row/column convention.
        self.assertEqual(int(result[0, 0]), 1)
        self.assertEqual(int(result[1, 1]), 2)
        self.assertEqual(int(result[2, 2]), 0)
        # Off-diagonal cells depend on whether rows index labels or
        # predictions; summing each transposed pair keeps these checks
        # independent of that convention.
        self.assertEqual(int(result[1, 2]) + int(result[2, 1]), 1)
        self.assertEqual(int(result[0, 2]) + int(result[2, 0]), 1)
        self.assertEqual(int(result[0, 1]) + int(result[1, 0]), 0)
        # Each of the 5 samples is counted exactly once.
        self.assertEqual(int(jnp.sum(result)), 5)

    def test_confusion_matrix_perfect(self):
        """Perfect predictions yield a purely diagonal matrix."""
        metric = bst.nn.ConfusionMatrix(num_classes=2)
        predictions = jnp.array([0, 0, 1, 1])
        labels = jnp.array([0, 0, 1, 1])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()

        # Should be diagonal matrix
        self.assertEqual(int(result[0, 0]), 2)
        self.assertEqual(int(result[0, 1]), 0)
        self.assertEqual(int(result[1, 0]), 0)
        self.assertEqual(int(result[1, 1]), 2)

    def test_confusion_matrix_reset(self):
        """``reset()`` clears every cell back to zero."""
        metric = bst.nn.ConfusionMatrix(num_classes=2)
        predictions = jnp.array([0, 1])
        labels = jnp.array([0, 1])
        metric.update(predictions=predictions, labels=labels)
        metric.reset()
        result = metric.compute()

        # Should be all zeros
        self.assertTrue(jnp.all(result == 0))

    def test_confusion_matrix_invalid_predictions(self):
        """A prediction outside ``[0, num_classes)`` raises ValueError."""
        metric = bst.nn.ConfusionMatrix(num_classes=2)
        predictions = jnp.array([0, 1, 2])  # 2 is out of range
        labels = jnp.array([0, 1, 1])
        with self.assertRaises(ValueError):
            metric.update(predictions=predictions, labels=labels)

    def test_confusion_matrix_invalid_labels(self):
        """A label outside ``[0, num_classes)`` raises ValueError."""
        metric = bst.nn.ConfusionMatrix(num_classes=2)
        predictions = jnp.array([0, 1, 1])
        labels = jnp.array([0, 1, 3])  # 3 is out of range
        with self.assertRaises(ValueError):
            metric.update(predictions=predictions, labels=labels)

    def test_confusion_matrix_module_attribute(self):
        """ConfusionMatrix reports ``brainstate.nn`` as its ``__module__``."""
        self.assertEqual(bst.nn.ConfusionMatrix.__module__, "brainstate.nn")
447
+
448
+
449
class MultiMetricTest(parameterized.TestCase):
    """Test cases for MultiMetric class."""

    def test_multimetric_creation(self):
        """Each keyword metric is reachable as an attribute of the container."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )
        self.assertIsNotNone(metrics.accuracy)
        self.assertIsNotNone(metrics.loss)

    def test_multimetric_compute(self):
        """compute() returns one entry per metric; all are NaN before any update."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )

        result = metrics.compute()
        self.assertIn('accuracy', result)
        self.assertIn('loss', result)
        self.assertTrue(jnp.isnan(result['accuracy']))
        self.assertTrue(jnp.isnan(result['loss']))

    def test_multimetric_update(self):
        """A single update() call forwards its keyword arguments to every metric."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )

        # The largest logit in each row is at [1, 0, 1], matching ``labels``
        # exactly; these are the same inputs AccuracyMetricTest uses for its
        # perfect-accuracy case.
        logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
        labels = jnp.array([1, 0, 1])
        batch_loss = jnp.array([1, 2, 3])

        metrics.update(logits=logits, labels=labels, values=batch_loss)
        result = metrics.compute()

        self.assertAlmostEqual(float(result['accuracy']), 1.0)
        self.assertAlmostEqual(float(result['loss']), 2.0)

    def test_multimetric_reset(self):
        """reset() resets every contained metric."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )

        logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
        labels = jnp.array([1, 0])
        batch_loss = jnp.array([1, 2])

        metrics.update(logits=logits, labels=labels, values=batch_loss)
        metrics.reset()
        result = metrics.compute()

        self.assertTrue(jnp.isnan(result['accuracy']))
        self.assertTrue(jnp.isnan(result['loss']))

    def test_multimetric_reserved_name_validation(self):
        """Names that clash with MultiMetric's own methods raise ValueError."""
        with self.assertRaises(ValueError):
            bst.nn.MultiMetric(
                reset=bst.nn.AverageMetric(),
            )

        with self.assertRaises(ValueError):
            bst.nn.MultiMetric(
                update=bst.nn.AverageMetric(),
            )

        with self.assertRaises(ValueError):
            bst.nn.MultiMetric(
                compute=bst.nn.AverageMetric(),
            )

    def test_multimetric_type_validation(self):
        """A non-Metric value raises TypeError."""
        with self.assertRaises(TypeError):
            bst.nn.MultiMetric(
                accuracy=bst.nn.AccuracyMetric(),
                invalid="not a metric",
            )

    def test_multimetric_module_attribute(self):
        """MultiMetric reports ``brainstate.nn`` as its ``__module__``."""
        self.assertEqual(bst.nn.MultiMetric.__module__, "brainstate.nn")
542
+
543
+
544
class MetricBaseClassTest(absltest.TestCase):
    """The plain ``Metric`` base class leaves reset/update/compute unimplemented."""

    def test_metric_not_implemented_errors(self):
        """Each core method of the base class raises ``NotImplementedError``."""
        base_metric = bst.nn.Metric()
        for method_name in ('reset', 'update', 'compute'):
            with self.assertRaises(NotImplementedError):
                getattr(base_metric, method_name)()

    def test_metric_module_attribute(self):
        """Metric reports ``brainstate.nn`` as its ``__module__``."""
        expected_module = "brainstate.nn"
        self.assertEqual(bst.nn.Metric.__module__, expected_module)
563
+
564
+
565
class EdgeCasesTest(parameterized.TestCase):
    """Boundary conditions: extreme magnitudes, one class, no positive samples."""

    @staticmethod
    def _mean_of(batch):
        """Average a single batch with a fresh ``AverageMetric``."""
        tracker = bst.nn.AverageMetric()
        tracker.update(values=batch)
        return float(tracker.compute())

    def test_average_metric_large_numbers(self):
        """Averaging values around 1e10 stays accurate to within ~1e5."""
        averaged = self._mean_of(jnp.array([1e10, 2e10, 3e10]))
        self.assertAlmostEqual(averaged, 2e10, places=-5)

    def test_average_metric_small_numbers(self):
        """Averaging values around 1e-10 stays accurate to 15 decimal places."""
        averaged = self._mean_of(jnp.array([1e-10, 2e-10, 3e-10]))
        self.assertAlmostEqual(averaged, 2e-10, places=15)

    def test_confusion_matrix_single_class(self):
        """With a single class, every sample lands in cell (0, 0)."""
        matrix_metric = bst.nn.ConfusionMatrix(num_classes=1)
        matrix_metric.update(predictions=jnp.array([0, 0, 0]),
                             labels=jnp.array([0, 0, 0]))
        counts = matrix_metric.compute()
        self.assertEqual(int(counts[0, 0]), 3)

    def test_precision_recall_no_positives(self):
        """With no positive predictions or labels, precision and recall are 0 (no division by zero)."""
        all_negative = jnp.array([0, 0, 0, 0, 0])
        for metric_cls in (bst.nn.PrecisionMetric, bst.nn.RecallMetric):
            scorer = metric_cls()
            scorer.update(predictions=all_negative, labels=all_negative)
            self.assertEqual(float(scorer.compute()), 0.0)
608
+
609
+
610
# Run the absltest runner when this file is executed directly.
if __name__ == "__main__":
    absltest.main()
brainstate/nn/_module.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -31,11 +31,12 @@ from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING, Callable
31
31
 
32
32
  import numpy as np
33
33
 
34
+ from brainstate._error import BrainStateError
34
35
  from brainstate._state import State
35
36
  from brainstate.graph import Node, states, nodes, flatten
36
37
  from brainstate.mixin import ParamDescriber, ParamDesc
37
38
  from brainstate.typing import PathParts, Size
38
- from brainstate.util import FlattedDict, NestedDict, BrainStateError
39
+ from brainstate.util import FlattedDict, NestedDict
39
40
 
40
41
  # maximum integer
41
42
  max_int = np.iinfo(np.int32).max
@@ -94,7 +95,10 @@ class Module(Node, ParamDesc):
94
95
  def in_size(self, in_size: Sequence[int] | int):
95
96
  if isinstance(in_size, int):
96
97
  in_size = (in_size,)
97
- assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {type(in_size)}"
98
+ elif isinstance(in_size, np.generic):
99
+ if np.issubdtype(in_size, np.integer) and in_size.ndim == 0:
100
+ in_size = (int(in_size),)
101
+ assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {in_size} {type(in_size)}"
98
102
  self._in_size = tuple(in_size)
99
103
 
100
104
  @property
@@ -105,6 +109,9 @@ class Module(Node, ParamDesc):
105
109
  def out_size(self, out_size: Sequence[int] | int):
106
110
  if isinstance(out_size, int):
107
111
  out_size = (out_size,)
112
+ elif isinstance(out_size, np.ndarray):
113
+ if np.issubdtype(out_size, np.integer) and out_size.ndim == 0:
114
+ out_size = (int(out_size),)
108
115
  assert isinstance(out_size, (tuple, list)), f"Invalid type of out_size: {type(out_size)}"
109
116
  self._out_size = tuple(out_size)
110
117
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.