brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in their public registry, and is provided for informational purposes only.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_metrics_test.py
CHANGED
@@ -1,611 +1,611 @@
```python
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Comprehensive tests for metrics module."""

import jax
import jax.numpy as jnp
from absl.testing import absltest, parameterized

import brainstate as bst


class MetricStateTest(absltest.TestCase):
    """Test cases for MetricState class."""

    def test_metric_state_creation(self):
        """Test creating a MetricState with different values."""
        state = bst.nn.MetricState(jnp.array(0.0))
        self.assertEqual(state.value, 0.0)

    def test_metric_state_update(self):
        """Test updating MetricState value."""
        state = bst.nn.MetricState(jnp.array(0.0))
        state.value = jnp.array(5.5)
        self.assertAlmostEqual(float(state.value), 5.5)

    def test_metric_state_module_attribute(self):
        """Test that MetricState has correct __module__ attribute."""
        state = bst.nn.MetricState(jnp.array(0.0))
        self.assertEqual(bst.nn.MetricState.__module__, "brainstate.nn")


class AverageMetricTest(parameterized.TestCase):
    """Test cases for AverageMetric class."""

    def test_average_metric_initial_state(self):
        """Test that initial average is NaN."""
        metric = bst.nn.AverageMetric()
        result = metric.compute()
        self.assertTrue(jnp.isnan(result))

    def test_average_metric_single_update(self):
        """Test average after single update."""
        metric = bst.nn.AverageMetric()
        metric.update(values=jnp.array([1, 2, 3, 4]))
        result = metric.compute()
        self.assertAlmostEqual(float(result), 2.5)

    def test_average_metric_multiple_updates(self):
        """Test average after multiple updates."""
        metric = bst.nn.AverageMetric()
        metric.update(values=jnp.array([1, 2, 3, 4]))
        metric.update(values=jnp.array([3, 2, 1, 0]))
        result = metric.compute()
        self.assertAlmostEqual(float(result), 2.0)

    def test_average_metric_scalar_values(self):
        """Test average with scalar values."""
        metric = bst.nn.AverageMetric()
        metric.update(values=5.0)
        metric.update(values=3.0)
        result = metric.compute()
        self.assertAlmostEqual(float(result), 4.0)

    def test_average_metric_reset(self):
        """Test reset functionality."""
        metric = bst.nn.AverageMetric()
        metric.update(values=jnp.array([1, 2, 3, 4]))
        metric.reset()
        result = metric.compute()
        self.assertTrue(jnp.isnan(result))

    def test_average_metric_custom_argname(self):
        """Test using custom argument name."""
        metric = bst.nn.AverageMetric('loss')
        metric.update(loss=jnp.array([1.0, 2.0, 3.0]))
        result = metric.compute()
        self.assertAlmostEqual(float(result), 2.0)

    def test_average_metric_missing_argname(self):
        """Test error when expected argument is missing."""
        metric = bst.nn.AverageMetric('loss')
        with self.assertRaises(TypeError):
            metric.update(values=jnp.array([1.0, 2.0]))

    def test_average_metric_module_attribute(self):
        """Test that AverageMetric has correct __module__ attribute."""
        self.assertEqual(bst.nn.AverageMetric.__module__, "brainstate.nn")


class WelfordMetricTest(parameterized.TestCase):
    """Test cases for WelfordMetric class."""

    def test_welford_metric_initial_state(self):
        """Test initial statistics."""
        metric = bst.nn.WelfordMetric()
        stats = metric.compute()
        self.assertEqual(float(stats.mean), 0.0)
        self.assertTrue(jnp.isnan(stats.standard_error_of_mean))
        self.assertTrue(jnp.isnan(stats.standard_deviation))

    def test_welford_metric_single_update(self):
        """Test statistics after single update."""
        metric = bst.nn.WelfordMetric()
        metric.update(values=jnp.array([1, 2, 3, 4]))
        stats = metric.compute()
        self.assertAlmostEqual(float(stats.mean), 2.5, places=5)
        # Population std of [1,2,3,4] is sqrt(1.25) ≈ 1.118
        self.assertAlmostEqual(float(stats.standard_deviation), 1.118, places=3)

    def test_welford_metric_multiple_updates(self):
        """Test statistics after multiple updates."""
        metric = bst.nn.WelfordMetric()
        metric.update(values=jnp.array([1, 2, 3, 4]))
        metric.update(values=jnp.array([3, 2, 1, 0]))
        stats = metric.compute()
        # Mean of all 8 values: [1,2,3,4,3,2,1,0] is 2.0
        self.assertAlmostEqual(float(stats.mean), 2.0, places=5)

    def test_welford_metric_reset(self):
        """Test reset functionality."""
        metric = bst.nn.WelfordMetric()
        metric.update(values=jnp.array([1, 2, 3, 4]))
        metric.reset()
        stats = metric.compute()
        self.assertEqual(float(stats.mean), 0.0)
        self.assertTrue(jnp.isnan(stats.standard_deviation))

    def test_welford_metric_custom_argname(self):
        """Test using custom argument name."""
        metric = bst.nn.WelfordMetric('data')
        metric.update(data=jnp.array([1.0, 2.0, 3.0]))
        stats = metric.compute()
        self.assertAlmostEqual(float(stats.mean), 2.0, places=5)

    def test_welford_metric_module_attribute(self):
        """Test that WelfordMetric has correct __module__ attribute."""
        self.assertEqual(bst.nn.WelfordMetric.__module__, "brainstate.nn")


class AccuracyMetricTest(parameterized.TestCase):
    """Test cases for AccuracyMetric class."""

    def test_accuracy_metric_initial_state(self):
        """Test that initial accuracy is NaN."""
        metric = bst.nn.AccuracyMetric()
        result = metric.compute()
        self.assertTrue(jnp.isnan(result))

    def test_accuracy_metric_perfect_accuracy(self):
        """Test with perfect predictions."""
        metric = bst.nn.AccuracyMetric()
        logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
        labels = jnp.array([1, 0, 1])
        metric.update(logits=logits, labels=labels)
        result = metric.compute()
        self.assertAlmostEqual(float(result), 1.0)

    def test_accuracy_metric_partial_accuracy(self):
        """Test with partial accuracy."""
        metric = bst.nn.AccuracyMetric()
        logits = jnp.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]])
        labels = jnp.array([1, 0, 1])  # First one wrong
        metric.update(logits=logits, labels=labels)
        result = metric.compute()
        self.assertAlmostEqual(float(result), 2.0 / 3.0, places=5)

    def test_accuracy_metric_multiple_updates(self):
        """Test accuracy after multiple updates."""
        metric = bst.nn.AccuracyMetric()
        logits1 = jax.random.normal(jax.random.key(0), (5, 2))
        labels1 = jnp.array([1, 1, 0, 1, 0])
        logits2 = jax.random.normal(jax.random.key(1), (5, 2))
        labels2 = jnp.array([0, 1, 1, 1, 1])

        metric.update(logits=logits1, labels=labels1)
        metric.update(logits=logits2, labels=labels2)
        result = metric.compute()
        # Result should be between 0 and 1
        self.assertGreaterEqual(float(result), 0.0)
        self.assertLessEqual(float(result), 1.0)

    def test_accuracy_metric_reset(self):
        """Test reset functionality."""
        metric = bst.nn.AccuracyMetric()
        logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
        labels = jnp.array([1, 0])
        metric.update(logits=logits, labels=labels)
        metric.reset()
        result = metric.compute()
        self.assertTrue(jnp.isnan(result))

    def test_accuracy_metric_shape_validation(self):
        """Test shape validation."""
        metric = bst.nn.AccuracyMetric()
        logits = jnp.array([[0.1, 0.9]])
        labels = jnp.array([[1, 0]])  # Wrong shape
        with self.assertRaises(ValueError):
            metric.update(logits=logits, labels=labels)

    def test_accuracy_metric_module_attribute(self):
        """Test that AccuracyMetric has correct __module__ attribute."""
        self.assertEqual(bst.nn.AccuracyMetric.__module__, "brainstate.nn")


class PrecisionMetricTest(parameterized.TestCase):
    """Test cases for PrecisionMetric class."""

    def test_precision_metric_binary_perfect(self):
        """Test binary precision with perfect predictions."""
        metric = bst.nn.PrecisionMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 1, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        self.assertAlmostEqual(float(result), 1.0)

    def test_precision_metric_binary_partial(self):
        """Test binary precision with partial accuracy."""
        metric = bst.nn.PrecisionMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 0, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        # TP=2, FP=1, Precision=2/3
        self.assertAlmostEqual(float(result), 2.0 / 3.0, places=5)

    def test_precision_metric_multiclass_macro(self):
        """Test multi-class precision with macro averaging."""
        metric = bst.nn.PrecisionMetric(num_classes=3, average='macro')
        predictions = jnp.array([0, 1, 2, 1, 0])
        labels = jnp.array([0, 1, 1, 1, 2])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        # Should compute precision for each class and average
        self.assertGreaterEqual(float(result), 0.0)
        self.assertLessEqual(float(result), 1.0)

    def test_precision_metric_multiclass_micro(self):
        """Test multi-class precision with micro averaging."""
        metric = bst.nn.PrecisionMetric(num_classes=3, average='micro')
        predictions = jnp.array([0, 1, 2, 1, 0])
        labels = jnp.array([0, 1, 1, 1, 2])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        self.assertGreaterEqual(float(result), 0.0)
        self.assertLessEqual(float(result), 1.0)

    def test_precision_metric_reset(self):
        """Test reset functionality."""
        metric = bst.nn.PrecisionMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 0, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        metric.reset()
        result = metric.compute()
        # After reset, with no data, precision should be 0
        self.assertEqual(float(result), 0.0)

    def test_precision_metric_module_attribute(self):
        """Test that PrecisionMetric has correct __module__ attribute."""
        self.assertEqual(bst.nn.PrecisionMetric.__module__, "brainstate.nn")


class RecallMetricTest(parameterized.TestCase):
    """Test cases for RecallMetric class."""

    def test_recall_metric_binary_perfect(self):
        """Test binary recall with perfect predictions."""
        metric = bst.nn.RecallMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 1, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        self.assertAlmostEqual(float(result), 1.0)

    def test_recall_metric_binary_partial(self):
        """Test binary recall with partial accuracy."""
        metric = bst.nn.RecallMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 0, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        # TP=2, FN=0, Recall=2/2=1.0
        self.assertAlmostEqual(float(result), 1.0, places=5)

    def test_recall_metric_multiclass_macro(self):
        """Test multi-class recall with macro averaging."""
        metric = bst.nn.RecallMetric(num_classes=3, average='macro')
        predictions = jnp.array([0, 1, 2, 1, 0])
        labels = jnp.array([0, 1, 1, 1, 2])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        self.assertGreaterEqual(float(result), 0.0)
        self.assertLessEqual(float(result), 1.0)

    def test_recall_metric_multiclass_micro(self):
        """Test multi-class recall with micro averaging."""
        metric = bst.nn.RecallMetric(num_classes=3, average='micro')
        predictions = jnp.array([0, 1, 2, 1, 0])
        labels = jnp.array([0, 1, 1, 1, 2])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        self.assertGreaterEqual(float(result), 0.0)
        self.assertLessEqual(float(result), 1.0)

    def test_recall_metric_reset(self):
        """Test reset functionality."""
        metric = bst.nn.RecallMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 0, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        metric.reset()
        result = metric.compute()
        self.assertEqual(float(result), 0.0)

    def test_recall_metric_module_attribute(self):
        """Test that RecallMetric has correct __module__ attribute."""
        self.assertEqual(bst.nn.RecallMetric.__module__, "brainstate.nn")


class F1ScoreMetricTest(parameterized.TestCase):
    """Test cases for F1ScoreMetric class."""

    def test_f1_score_binary_perfect(self):
        """Test binary F1 score with perfect predictions."""
        metric = bst.nn.F1ScoreMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 1, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        self.assertAlmostEqual(float(result), 1.0)

    def test_f1_score_binary_partial(self):
        """Test binary F1 score with partial accuracy."""
        metric = bst.nn.F1ScoreMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 0, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        # Precision=2/3, Recall=1.0, F1=2*(2/3*1)/(2/3+1)=0.8
        self.assertAlmostEqual(float(result), 0.8, places=5)

    def test_f1_score_multiclass(self):
        """Test multi-class F1 score."""
        metric = bst.nn.F1ScoreMetric(num_classes=3, average='macro')
        predictions = jnp.array([0, 1, 2, 1, 0])
        labels = jnp.array([0, 1, 1, 1, 2])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        self.assertGreaterEqual(float(result), 0.0)
        self.assertLessEqual(float(result), 1.0)

    def test_f1_score_reset(self):
        """Test reset functionality."""
        metric = bst.nn.F1ScoreMetric()
        predictions = jnp.array([1, 0, 1, 1, 0])
        labels = jnp.array([1, 0, 0, 1, 0])
        metric.update(predictions=predictions, labels=labels)
        metric.reset()
        result = metric.compute()
        self.assertEqual(float(result), 0.0)

    def test_f1_score_module_attribute(self):
        """Test that F1ScoreMetric has correct __module__ attribute."""
        self.assertEqual(bst.nn.F1ScoreMetric.__module__, "brainstate.nn")


class ConfusionMatrixTest(parameterized.TestCase):
    """Test cases for ConfusionMatrix class."""

    def test_confusion_matrix_basic(self):
        """Test basic confusion matrix computation."""
        metric = bst.nn.ConfusionMatrix(num_classes=3)
        predictions = jnp.array([0, 1, 2, 1, 0])
        labels = jnp.array([0, 1, 1, 1, 2])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()

        # Check shape
        self.assertEqual(result.shape, (3, 3))

        # Check specific values
        # True label 0, Predicted 0: 1
        # True label 0, Predicted 2: 1
        # True label 1, Predicted 1: 2
        self.assertEqual(int(result[0, 0]), 1)
        self.assertEqual(int(result[1, 1]), 2)

    def test_confusion_matrix_perfect(self):
        """Test confusion matrix with perfect predictions."""
        metric = bst.nn.ConfusionMatrix(num_classes=2)
        predictions = jnp.array([0, 0, 1, 1])
        labels = jnp.array([0, 0, 1, 1])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()

        # Should be diagonal matrix
        self.assertEqual(int(result[0, 0]), 2)
        self.assertEqual(int(result[0, 1]), 0)
        self.assertEqual(int(result[1, 0]), 0)
        self.assertEqual(int(result[1, 1]), 2)

    def test_confusion_matrix_reset(self):
        """Test reset functionality."""
        metric = bst.nn.ConfusionMatrix(num_classes=2)
        predictions = jnp.array([0, 1])
        labels = jnp.array([0, 1])
        metric.update(predictions=predictions, labels=labels)
        metric.reset()
        result = metric.compute()

        # Should be all zeros
        self.assertTrue(jnp.all(result == 0))

    def test_confusion_matrix_invalid_predictions(self):
        """Test error handling for invalid predictions."""
        metric = bst.nn.ConfusionMatrix(num_classes=2)
        predictions = jnp.array([0, 1, 2])  # 2 is out of range
        labels = jnp.array([0, 1, 1])
        with self.assertRaises(ValueError):
            metric.update(predictions=predictions, labels=labels)

    def test_confusion_matrix_invalid_labels(self):
        """Test error handling for invalid labels."""
        metric = bst.nn.ConfusionMatrix(num_classes=2)
        predictions = jnp.array([0, 1, 1])
        labels = jnp.array([0, 1, 3])  # 3 is out of range
        with self.assertRaises(ValueError):
            metric.update(predictions=predictions, labels=labels)

    def test_confusion_matrix_module_attribute(self):
        """Test that ConfusionMatrix has correct __module__ attribute."""
        self.assertEqual(bst.nn.ConfusionMatrix.__module__, "brainstate.nn")


class MultiMetricTest(parameterized.TestCase):
    """Test cases for MultiMetric class."""

    def test_multimetric_creation(self):
        """Test creating a MultiMetric with multiple metrics."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )
        self.assertIsNotNone(metrics.accuracy)
        self.assertIsNotNone(metrics.loss)

    def test_multimetric_compute(self):
        """Test computing all metrics."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )

        logits = jax.random.normal(jax.random.key(0), (5, 2))
        labels = jnp.array([1, 1, 0, 1, 0])
        batch_loss = jnp.array([1, 2, 3, 4])

        result = metrics.compute()
        self.assertIn('accuracy', result)
        self.assertIn('loss', result)
        self.assertTrue(jnp.isnan(result['accuracy']))
        self.assertTrue(jnp.isnan(result['loss']))

    def test_multimetric_update(self):
        """Test updating all metrics."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )

        logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
        labels = jnp.array([1, 0, 1])
        batch_loss = jnp.array([1, 2, 3])

        metrics.update(logits=logits, labels=labels, values=batch_loss)
        result = metrics.compute()

        self.assertGreaterEqual(float(result['accuracy']), 0.0)
        self.assertLessEqual(float(result['accuracy']), 1.0)
        self.assertAlmostEqual(float(result['loss']), 2.0)

    def test_multimetric_reset(self):
        """Test resetting all metrics."""
        metrics = bst.nn.MultiMetric(
            accuracy=bst.nn.AccuracyMetric(),
            loss=bst.nn.AverageMetric(),
        )

        logits = jnp.array([[0.1, 0.9], [0.8, 0.2]])
        labels = jnp.array([1, 0])
        batch_loss = jnp.array([1, 2])

        metrics.update(logits=logits, labels=labels, values=batch_loss)
        metrics.reset()
        result = metrics.compute()

        self.assertTrue(jnp.isnan(result['accuracy']))
        self.assertTrue(jnp.isnan(result['loss']))

    def test_multimetric_reserved_name_validation(self):
        """Test that reserved method names cannot be used."""
        with self.assertRaises(ValueError):
            bst.nn.MultiMetric(
                reset=bst.nn.AverageMetric(),
            )

        with self.assertRaises(ValueError):
            bst.nn.MultiMetric(
                update=bst.nn.AverageMetric(),
            )

        with self.assertRaises(ValueError):
            bst.nn.MultiMetric(
                compute=bst.nn.AverageMetric(),
            )

    def test_multimetric_type_validation(self):
        """Test that all metrics must be Metric instances."""
        with self.assertRaises(TypeError):
            bst.nn.MultiMetric(
                accuracy=bst.nn.AccuracyMetric(),
                invalid="not a metric",
            )

    def test_multimetric_module_attribute(self):
        """Test that MultiMetric has correct __module__ attribute."""
        self.assertEqual(bst.nn.MultiMetric.__module__, "brainstate.nn")


class MetricBaseClassTest(absltest.TestCase):
    """Test cases for base Metric class."""

    def test_metric_not_implemented_errors(self):
        """Test that base Metric class raises NotImplementedError."""
        metric = bst.nn.Metric()

        with self.assertRaises(NotImplementedError):
            metric.reset()

        with self.assertRaises(NotImplementedError):
            metric.update()

        with self.assertRaises(NotImplementedError):
            metric.compute()

    def test_metric_module_attribute(self):
        """Test that Metric has correct __module__ attribute."""
        self.assertEqual(bst.nn.Metric.__module__, "brainstate.nn")


class EdgeCasesTest(parameterized.TestCase):
    """Test edge cases and boundary conditions."""

    def test_average_metric_large_numbers(self):
        """Test AverageMetric with very large numbers."""
        metric = bst.nn.AverageMetric()
        metric.update(values=jnp.array([1e10, 2e10, 3e10]))
        result = metric.compute()
        self.assertAlmostEqual(float(result), 2e10, places=-5)

    def test_average_metric_small_numbers(self):
        """Test AverageMetric with very small numbers."""
        metric = bst.nn.AverageMetric()
        metric.update(values=jnp.array([1e-10, 2e-10, 3e-10]))
        result = metric.compute()
        self.assertAlmostEqual(float(result), 2e-10, places=15)

    def test_confusion_matrix_single_class(self):
        """Test ConfusionMatrix with single class."""
        metric = bst.nn.ConfusionMatrix(num_classes=1)
        predictions = jnp.array([0, 0, 0])
        labels = jnp.array([0, 0, 0])
        metric.update(predictions=predictions, labels=labels)
        result = metric.compute()
        self.assertEqual(int(result[0, 0]), 3)

    def test_precision_recall_no_positives(self):
        """Test precision/recall when there are no positive predictions."""
        precision_metric = bst.nn.PrecisionMetric()
        recall_metric = bst.nn.RecallMetric()

        predictions = jnp.array([0, 0, 0, 0, 0])
        labels = jnp.array([0, 0, 0, 0, 0])

        precision_metric.update(predictions=predictions, labels=labels)
        recall_metric.update(predictions=predictions, labels=labels)

        # Should handle gracefully without division by zero
        precision = precision_metric.compute()
        recall = recall_metric.compute()

        self.assertEqual(float(precision), 0.0)
        self.assertEqual(float(recall), 0.0)


if __name__ == '__main__':
    absltest.main()
```
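For reference, a minimal sketch of the metrics API these tests exercise. The keyword names (`logits`, `labels`, `values`) are the defaults the tests above rely on; the loss expression here is a hypothetical stand-in, not part of the package or this diff:

```python
# Sketch of the brainstate.nn metrics API as exercised by the tests above.
import jax
import jax.numpy as jnp

import brainstate as bst

# Bundle several metrics under one object; each keyword becomes an attribute.
metrics = bst.nn.MultiMetric(
    accuracy=bst.nn.AccuracyMetric(),
    loss=bst.nn.AverageMetric(),
)

# Accumulate over two "batches"; update() routes keyword arguments to the
# metrics that expect them (logits/labels -> accuracy, values -> loss).
for seed in (0, 1):
    logits = jax.random.normal(jax.random.key(seed), (4, 2))
    labels = jnp.array([0, 1, 1, 0])
    batch_loss = jnp.mean(jnp.square(logits))  # illustrative stand-in loss
    metrics.update(logits=logits, labels=labels, values=batch_loss)

print(metrics.compute())  # {'accuracy': ..., 'loss': ...}
metrics.reset()           # back to the initial state (compute() yields NaN)
```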