brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1070 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import typing as tp
|
18
|
+
from dataclasses import dataclass
|
19
|
+
from functools import partial
|
20
|
+
|
21
|
+
import jax
|
22
|
+
import jax.numpy as jnp
|
23
|
+
import numpy as np
|
24
|
+
|
25
|
+
from brainstate._state import LongTermState
|
26
|
+
|
27
|
+
# Public names of this module, i.e. what ``from <module> import *`` exposes.
__all__ = [
    # state container and abstract base class
    'MetricState',
    'Metric',
    # streaming statistics
    'AverageMetric',
    'WelfordMetric',
    # remaining metrics
    'AccuracyMetric',
    'MultiMetric',
    'PrecisionMetric',
    'RecallMetric',
    'F1ScoreMetric',
    'ConfusionMatrix',
]
|
39
|
+
|
40
|
+
|
41
|
+
class MetricState(LongTermState):
    """
    State container for the running quantities of a metric.

    ``MetricState`` subclasses ``LongTermState`` without adding any behavior.
    It gives metric accumulators (sums, counts, moments, ...) their own state
    type, distinct from a plain ``LongTermState``.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> state = brainstate.nn.MetricState(jnp.array(0.0))
        >>> state.value
        Array(0., dtype=float32)
        >>> state.value = jnp.array(1.5)
        >>> state.value
        Array(1.5, dtype=float32)
    """
    __module__ = "brainstate.nn"
|
62
|
+
|
63
|
+
|
64
|
+
class Metric:
    """
    Abstract interface shared by every metric in this module.

    A concrete metric keeps its accumulators in state objects and exposes
    three operations, all of which a subclass has to provide.

    Methods
    -------
    reset()
        Return the accumulators to their freshly-constructed values.
    update(**kwargs)
        Fold a new batch of data into the accumulators.
    compute()
        Produce the metric value from the current accumulators.

    Notes
    -----
    The default implementation of each method only raises
    ``NotImplementedError``; instantiate a subclass instead.
    """
    __module__ = "brainstate.nn"

    def reset(self) -> None:
        """
        Reset the accumulators in place to their initial values.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        message = 'Must override `reset()` method.'
        raise NotImplementedError(message)

    def update(self, **kwargs) -> None:
        """
        Fold new data into the accumulators in place.

        Parameters
        ----------
        **kwargs
            Metric-specific keyword arguments carrying the new data.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        message = 'Must override `update()` method.'
        raise NotImplementedError(message)

    def compute(self):
        """
        Return the metric value derived from the current accumulators.

        Returns
        -------
        metric_value
            Subclass-specific result.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        message = 'Must override `compute()` method.'
        raise NotImplementedError(message)
|
133
|
+
|
134
|
+
|
135
|
+
class AverageMetric(Metric):
    """
    Average metric for computing the running mean of values.

    This metric maintains a running sum and element count to compute the
    average of all elements passed to it via the ``update`` method.

    Parameters
    ----------
    argname : str, optional
        The keyword argument name that ``update`` will use to derive the new value.
        Defaults to ``'values'``.

    Attributes
    ----------
    argname : str
        The keyword argument name for updates.
    total : MetricState
        Cumulative sum of all values (float32 accumulator).
    count : MetricState
        Total number of elements processed (int32 counter).

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> batch_loss = jnp.array([1, 2, 3, 4])
        >>> batch_loss2 = jnp.array([3, 2, 1, 0])
        >>> metrics = brainstate.nn.AverageMetric()
        >>> metrics.compute()
        Array(nan, dtype=float32)
        >>> metrics.update(values=batch_loss)
        >>> metrics.compute()
        Array(2.5, dtype=float32)
        >>> metrics.update(values=batch_loss2)
        >>> metrics.compute()
        Array(2., dtype=float32)
        >>> metrics.reset()
        >>> metrics.compute()
        Array(nan, dtype=float32)

    Notes
    -----
    The metric returns NaN when no values have been added (count = 0).
    Inputs may be Python scalars, sequences, NumPy arrays or JAX arrays; every
    element counts once. Each batch sum is cast to the dtype of ``total`` so the
    accumulator keeps a fixed dtype across updates.
    """
    __module__ = "brainstate.nn"

    def __init__(self, argname: str = 'values'):
        self.argname = argname
        self.total = MetricState(jnp.array(0, dtype=jnp.float32))
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self) -> None:
        """
        Reset the metric state to zero.

        This sets both the total sum and count to zero.
        """
        self.total.value = jnp.array(0, dtype=jnp.float32)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def update(self, **kwargs) -> None:
        """
        Update the metric with new values.

        Parameters
        ----------
        **kwargs
            Must contain ``self.argname`` as a key, mapping to the values
            to be averaged. Values can be scalars, sequences, or arrays.

        Raises
        ------
        TypeError
            If the expected keyword argument is not provided.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate
            >>> metric = brainstate.nn.AverageMetric('loss')
            >>> metric.update(loss=jnp.array([1.0, 2.0, 3.0]))
            >>> metric.compute()
            Array(2., dtype=float32)
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: jax.Array = jnp.asarray(kwargs[self.argname])
        # Cast to the accumulator dtype: e.g. float64 input under
        # ``jax_enable_x64`` would otherwise promote ``total`` and change the
        # state's dtype between updates (which can break jit/scan carries).
        batch_total = values.sum().astype(self.total.value.dtype)
        self.total.value = self.total.value + batch_total
        # ``size`` is a static Python int, so ``count`` stays int32.
        self.count.value = self.count.value + values.size

    def compute(self) -> jax.Array:
        """
        Compute and return the average.

        Returns
        -------
        jax.Array
            The average of all values provided to ``update``.
            Returns NaN if no values have been added.
        """
        return self.total.value / self.count.value
|
244
|
+
|
245
|
+
|
246
|
+
@dataclass
class Statistics:
    """
    Summary statistics produced by ``WelfordMetric.compute``.

    Instances are registered as JAX pytrees (all three fields are data
    leaves), so they can be returned from and passed through JAX transforms.

    Attributes
    ----------
    mean : float32
        Estimated mean.
    standard_error_of_mean : float32
        Standard error of the mean (SEM).
    standard_deviation : float32
        Standard deviation.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> stats = brainstate.nn.Statistics(
        ...     mean=jnp.float32(2.5),
        ...     standard_error_of_mean=jnp.float32(0.5),
        ...     standard_deviation=jnp.float32(1.0)
        ... )
        >>> stats.mean
        Array(2.5, dtype=float32)
    """
    __module__ = "brainstate.nn"
    mean: jnp.float32
    standard_error_of_mean: jnp.float32
    standard_deviation: jnp.float32


# Register as a pytree whose leaves are the three fields, in declaration order.
Statistics = jax.tree_util.register_dataclass(
    Statistics,
    data_fields=['mean', 'standard_error_of_mean', 'standard_deviation'],
    meta_fields=[],
)
|
283
|
+
|
284
|
+
|
285
|
+
class WelfordMetric(Metric):
    """
    Welford's algorithm for computing mean and variance of streaming data.

    This metric uses Welford's online algorithm (in its batched/parallel form)
    to compute running statistics (mean, variance, standard deviation) in a
    numerically stable way.

    Parameters
    ----------
    argname : str, optional
        The keyword argument name that ``update`` will use to derive the new value.
        Defaults to ``'values'``.

    Attributes
    ----------
    argname : str
        The keyword argument name for updates.
    count : MetricState
        Total number of elements processed (int32).
    mean : MetricState
        Running mean estimate (float32).
    m2 : MetricState
        Running sum of squared deviations from the mean (float32).

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> batch_loss = jnp.array([1, 2, 3, 4])
        >>> batch_loss2 = jnp.array([3, 2, 1, 0])
        >>> metrics = brainstate.nn.WelfordMetric()
        >>> metrics.compute()
        Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
        >>> metrics.update(values=batch_loss)
        >>> metrics.compute()
        Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
        >>> metrics.update(values=batch_loss2)
        >>> metrics.compute()
        Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
        >>> metrics.reset()
        >>> metrics.compute()
        Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))

    Notes
    -----
    The variance reported is the population variance (``m2 / count``).
    Each batch's mean and variance are merged into the running values, so the
    statistics are computed in a single pass. Empty batches are ignored.

    References
    ----------
    .. [1] Welford, B. P. (1962). "Note on a method for calculating corrected sums
           of squares and products". Technometrics. 4 (3): 419-420.
    """
    __module__ = "brainstate.nn"

    def __init__(self, argname: str = 'values'):
        self.argname = argname
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))
        self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
        self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))

    def reset(self) -> None:
        """
        Reset the metric state to zero.

        This resets count, mean, and the sum of squared deviations (m2), using
        the same dtypes as ``__init__`` so the state's dtype never changes.
        """
        # Previously ``count`` was reset to uint32 while ``__init__`` used int32.
        self.count.value = jnp.array(0, dtype=jnp.int32)
        self.mean.value = jnp.array(0, dtype=jnp.float32)
        self.m2.value = jnp.array(0, dtype=jnp.float32)

    def update(self, **kwargs) -> None:
        """
        Update the metric using Welford's algorithm.

        Parameters
        ----------
        **kwargs
            Must contain ``self.argname`` as a key, mapping to the values
            to be processed. Values can be scalars, sequences, or arrays.

        Raises
        ------
        TypeError
            If the expected keyword argument is not provided.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate
            >>> metric = brainstate.nn.WelfordMetric('data')
            >>> metric.update(data=jnp.array([1.0, 2.0, 3.0]))
            >>> stats = metric.compute()
            >>> stats.mean
            Array(2., dtype=float32)
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: jax.Array = jnp.asarray(kwargs[self.argname])
        batch_count = values.size  # static Python int
        if batch_count == 0:
            # mean()/var() of an empty batch are NaN and would poison the
            # running statistics; an empty batch carries no information.
            return

        # Cast batch moments to the accumulator dtype so the state dtype stays
        # fixed (e.g. float64 input under ``jax_enable_x64``).
        dtype = self.mean.value.dtype
        batch_mean = values.mean().astype(dtype)
        batch_m2 = values.var().astype(dtype) * batch_count  # 0 for a scalar

        old_count = self.count.value
        new_count = old_count + batch_count
        delta = batch_mean - self.mean.value

        # Chan et al. pairwise merge of (count, mean, m2) with the batch.
        self.count.value = new_count
        self.mean.value = self.mean.value + delta * batch_count / new_count
        self.m2.value = (
            self.m2.value + batch_m2 + delta * delta * batch_count * old_count / new_count
        )

    def compute(self) -> Statistics:
        """
        Compute and return statistical measurements.

        Returns
        -------
        Statistics
            A dataclass containing mean, standard error of mean, and
            (population) standard deviation. The error metrics are NaN when
            count is 0.
        """
        variance = self.m2.value / self.count.value
        standard_deviation = variance ** 0.5
        sem = standard_deviation / (self.count.value ** 0.5)
        return Statistics(
            mean=self.mean.value,
            standard_error_of_mean=sem,
            standard_deviation=standard_deviation,
        )
|
418
|
+
|
419
|
+
|
420
|
+
class AccuracyMetric(AverageMetric):
    """
    Accuracy metric for classification tasks.

    This metric computes the accuracy by comparing predicted labels (derived from
    logits using argmax) with ground truth labels. It inherits from ``AverageMetric``
    and shares the same ``__init__``, ``reset`` and ``compute`` implementations;
    the per-sample correctness is forwarded under the configured ``argname``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax, jax.numpy as jnp
        >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
        >>> labels = jnp.array([1, 1, 0, 1, 0])
        >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
        >>> labels2 = jnp.array([0, 1, 1, 1, 1])
        >>> metrics = brainstate.nn.AccuracyMetric()
        >>> metrics.compute()
        Array(nan, dtype=float32)
        >>> metrics.update(logits=logits, labels=labels)
        >>> metrics.compute()
        Array(0.6, dtype=float32)
        >>> metrics.update(logits=logits2, labels=labels2)
        >>> metrics.compute()
        Array(0.7, dtype=float32)
        >>> metrics.reset()
        >>> metrics.compute()
        Array(nan, dtype=float32)

    Notes
    -----
    The accuracy is computed as the fraction of correct predictions:
    accuracy = (number of correct predictions) / (total predictions)

    Logits are converted to predictions using argmax along the last dimension.
    """
    __module__ = "brainstate.nn"

    def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the accuracy metric with predictions and labels.

        Parameters
        ----------
        logits : jax.Array
            Predicted activations/logits with shape (..., num_classes).
            The last dimension represents class scores.
        labels : jax.Array
            Ground truth integer labels with shape (...,), i.e. equal to
            ``logits.shape[:-1]``. Any integer dtype is accepted and cast
            to int32.
        **_
            Additional keyword arguments are ignored.

        Raises
        ------
        ValueError
            If logits and labels have incompatible shapes, or if labels do
            not have an integer dtype.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate
            >>> logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
            >>> labels = jnp.array([1, 0, 1])
            >>> metric = brainstate.nn.AccuracyMetric()
            >>> metric.update(logits=logits, labels=labels)
            >>> metric.compute()
            Array(1., dtype=float32)
        """
        logits = jnp.asarray(logits)
        labels = jnp.asarray(labels)
        if logits.ndim != labels.ndim + 1:
            raise ValueError(
                f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
            )
        if logits.shape[:-1] != labels.shape:
            # Without this check, e.g. labels of shape (1,) would silently
            # broadcast against every row of logits.
            raise ValueError(
                f'Expected logits.shape[:-1]==labels.shape, '
                f'got {logits.shape} and {labels.shape}'
            )
        if not jnp.issubdtype(labels.dtype, jnp.integer):
            raise ValueError(
                f'Expected integer labels.dtype (e.g. jnp.int32), got {labels.dtype}'
            )
        labels = labels.astype(jnp.int32)

        correct = logits.argmax(axis=-1) == labels
        # Forward under the configured ``argname`` (set by
        # ``AverageMetric.__init__``); hard-coding ``values=`` raised a
        # TypeError whenever a custom ``argname`` was used.
        super().update(**{self.argname: correct})
|
504
|
+
|
505
|
+
|
506
|
+
class PrecisionMetric(Metric):
    """
    Precision metric for binary and multi-class classification.

    Precision is the ratio of true positives to all positive predictions:
    precision = TP / (TP + FP)

    Parameters
    ----------
    num_classes : int, optional
        Number of classes. If None, assumes binary classification. Default is None.
    average : str, optional
        Type of averaging for multi-class: 'micro', 'macro', or 'weighted'.
        Any other value makes ``compute`` return the per-class precision array.
        Default is 'macro'. Ignored for binary classification.

    Attributes
    ----------
    num_classes : int or None
        Number of classes.
    average : str
        Averaging method for multi-class.
    true_positives : MetricState
        Count of true positive predictions (scalar for binary, shape
        ``(num_classes,)`` for multi-class).
    false_positives : MetricState
        Count of false positive predictions (same shape as ``true_positives``).
    support : MetricState
        Multi-class only: number of ground-truth samples of each class, used
        as the weights for ``average='weighted'``.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> predictions = jnp.array([1, 0, 1, 1, 0])
        >>> labels = jnp.array([1, 0, 0, 1, 0])
        >>> metric = brainstate.nn.PrecisionMetric()
        >>> metric.update(predictions=predictions, labels=labels)
        >>> metric.compute()
        Array(0.6666667, dtype=float32)

    Notes
    -----
    For multi-class classification, the metric supports different averaging strategies:

    - 'micro': Calculate metrics globally by counting total TP and FP
    - 'macro': Calculate metrics for each class and find their unweighted mean
    - 'weighted': Calculate metrics for each class and find their mean weighted
      by each class's support (number of true instances)

    A class with no positive predictions has precision 0. In binary mode a
    prediction of 1 is positive; a label of 1 makes it a TP, a label of 0 a FP.
    """
    __module__ = "brainstate.nn"

    def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro'):
        self.num_classes = num_classes
        self.average = average
        self.true_positives = MetricState(self._zeros())
        self.false_positives = MetricState(self._zeros())
        if num_classes is not None:
            self.support = MetricState(self._zeros())

    def _zeros(self) -> jax.Array:
        """Return an int32 zero counter: scalar (binary) or ``(num_classes,)``."""
        shape = () if self.num_classes is None else (self.num_classes,)
        return jnp.zeros(shape, dtype=jnp.int32)

    def reset(self) -> None:
        """Reset the metric state to zero."""
        self.true_positives.value = self._zeros()
        self.false_positives.value = self._zeros()
        if self.num_classes is not None:
            self.support.value = self._zeros()

    def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the precision metric.

        Parameters
        ----------
        predictions : jax.Array
            Predicted class labels (integers).
        labels : jax.Array
            Ground truth class labels (integers), broadcast-compatible with
            ``predictions``.
        **_
            Additional keyword arguments are ignored.
        """
        predictions, labels = jnp.broadcast_arrays(
            jnp.asarray(predictions), jnp.asarray(labels)
        )
        # Sums are pinned to int32 so the counters keep their dtype
        # (the default integer sum becomes int64 under ``jax_enable_x64``).
        if self.num_classes is None:
            # Binary classification
            predicted_pos = predictions == 1
            self.true_positives.value = self.true_positives.value + jnp.sum(
                predicted_pos & (labels == 1), dtype=jnp.int32
            )
            self.false_positives.value = self.false_positives.value + jnp.sum(
                predicted_pos & (labels == 0), dtype=jnp.int32
            )
        else:
            # Multi-class classification: column c of each (N, num_classes)
            # mask is ``x == c``, so all classes are counted in one pass
            # instead of one scatter per class. Values outside
            # [0, num_classes) match no column and are ignored.
            classes = jnp.arange(self.num_classes)
            pred_mask = predictions.reshape(-1, 1) == classes
            label_mask = labels.reshape(-1, 1) == classes
            self.true_positives.value = self.true_positives.value + jnp.sum(
                pred_mask & label_mask, axis=0, dtype=jnp.int32
            )
            self.false_positives.value = self.false_positives.value + jnp.sum(
                pred_mask & ~label_mask, axis=0, dtype=jnp.int32
            )
            self.support.value = self.support.value + jnp.sum(
                label_mask, axis=0, dtype=jnp.int32
            )

    def compute(self) -> jax.Array:
        """
        Compute and return the precision.

        Returns
        -------
        jax.Array
            The precision value(s). For binary classification, returns a scalar.
            For multi-class, returns the 'macro', 'micro' or 'weighted' average
            as a scalar, or the per-class precision array for any other
            ``average`` value.
        """
        tp = self.true_positives.value
        fp = self.false_positives.value
        predicted = tp + fp
        # ``maximum(.., 1)`` keeps the masked-out branch finite (no 0/0).
        precision = jnp.where(
            predicted > 0, tp / jnp.maximum(predicted, 1), 0.0
        ).astype(jnp.float32)

        if self.num_classes is None:
            return precision
        if self.average == 'macro':
            return jnp.mean(precision)
        if self.average == 'micro':
            total_tp = jnp.sum(tp)
            total_predicted = total_tp + jnp.sum(fp)
            return jnp.where(
                total_predicted > 0,
                total_tp / jnp.maximum(total_predicted, 1),
                jnp.float32(0.0),
            )
        if self.average == 'weighted':
            support = self.support.value
            total_support = jnp.sum(support)
            return jnp.where(
                total_support > 0,
                jnp.sum(precision * support) / jnp.maximum(total_support, 1),
                jnp.float32(0.0),
            )
        return precision
|
629
|
+
|
630
|
+
|
631
|
+
class RecallMetric(Metric):
    """
    Recall (sensitivity) metric for binary and multi-class classification.

    Recall is the ratio of true positives to all actual positives:
    recall = TP / (TP + FN)

    Parameters
    ----------
    num_classes : int, optional
        Number of classes. If None, assumes binary classification. Default is None.
    average : str, optional
        Type of averaging for multi-class: 'micro', 'macro', or 'weighted'.
        Default is 'macro'. Ignored for binary classification.

        - ``'macro'``: unweighted mean of the per-class recalls.
        - ``'micro'``: global recall, ``sum(TP) / sum(TP + FN)``.
        - ``'weighted'``: mean of the per-class recalls weighted by each
          class's support (``TP + FN``, the number of true samples).
        - Any other value (e.g. ``None``) returns the per-class recalls.

    Attributes
    ----------
    num_classes : int or None
        Number of classes.
    average : str
        Averaging method for multi-class.
    true_positives : MetricState
        Count of true positive predictions.
    false_negatives : MetricState
        Count of false negative predictions.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> predictions = jnp.array([1, 0, 1, 1, 0])
        >>> labels = jnp.array([1, 0, 0, 1, 0])
        >>> metric = brainstate.nn.RecallMetric()
        >>> metric.update(predictions=predictions, labels=labels)
        >>> metric.compute()
        Array(1., dtype=float32)

    Notes
    -----
    Recall measures the fraction of actual positive cases that were correctly identified.
    Also known as sensitivity or true positive rate (TPR).
    """
    __module__ = "brainstate.nn"

    def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro'):
        self.num_classes = num_classes
        self.average = average
        if num_classes is None:
            self.true_positives = MetricState(jnp.array(0, dtype=jnp.int32))
            self.false_negatives = MetricState(jnp.array(0, dtype=jnp.int32))
        else:
            self.true_positives = MetricState(jnp.zeros(num_classes, dtype=jnp.int32))
            self.false_negatives = MetricState(jnp.zeros(num_classes, dtype=jnp.int32))

    def reset(self) -> None:
        """Reset the metric state to zero."""
        if self.num_classes is None:
            self.true_positives.value = jnp.array(0, dtype=jnp.int32)
            self.false_negatives.value = jnp.array(0, dtype=jnp.int32)
        else:
            self.true_positives.value = jnp.zeros(self.num_classes, dtype=jnp.int32)
            self.false_negatives.value = jnp.zeros(self.num_classes, dtype=jnp.int32)

    def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the recall metric.

        Parameters
        ----------
        predictions : jax.Array
            Predicted class labels (integers).
        labels : jax.Array
            Ground truth class labels (integers).
        **_
            Additional keyword arguments are ignored.
        """
        if self.num_classes is None:
            # Binary classification: class 1 is the positive class.
            self.true_positives.value += jnp.sum((predictions == 1) & (labels == 1))
            self.false_negatives.value += jnp.sum((predictions == 0) & (labels == 1))
        else:
            # Multi-class classification, vectorized over classes: append a
            # trailing class axis so every sample is compared against every
            # class id at once (instead of a Python loop of 2 * num_classes
            # scatter ops). Broadcasting between predictions and labels is the
            # same as in an element-wise comparison.
            classes = jnp.arange(self.num_classes)
            pred_is_class = jnp.asarray(predictions)[..., None] == classes
            label_is_class = jnp.asarray(labels)[..., None] == classes
            # Sum over every axis except the trailing class axis.
            tp_counts = jnp.sum(
                (pred_is_class & label_is_class).reshape(-1, self.num_classes), axis=0
            )
            fn_counts = jnp.sum(
                (~pred_is_class & label_is_class).reshape(-1, self.num_classes), axis=0
            )
            # Keep the accumulator dtype stable (e.g. int32 even with x64 enabled).
            self.true_positives.value = (
                self.true_positives.value + tp_counts.astype(self.true_positives.value.dtype)
            )
            self.false_negatives.value = (
                self.false_negatives.value + fn_counts.astype(self.false_negatives.value.dtype)
            )

    def compute(self) -> jax.Array:
        """
        Compute and return the recall.

        Returns
        -------
        jax.Array
            The recall value(s). For binary classification, returns a scalar.
            For multi-class, returns the averaged recall for ``'macro'``,
            ``'micro'`` and ``'weighted'``, and the per-class recalls for any
            other ``average`` value. Classes (or totals) with no true samples
            contribute a recall of 0.
        """
        tp = self.true_positives.value
        fn = self.false_negatives.value
        support = tp + fn
        recall = jnp.where(
            support > 0,
            tp / support,
            jnp.zeros_like(support, dtype=jnp.float32)
        )

        if self.num_classes is None:
            return recall
        if self.average == 'macro':
            return jnp.mean(recall)
        if self.average == 'micro':
            total_tp = jnp.sum(tp)
            total_fn = jnp.sum(fn)
            return jnp.where(
                total_tp + total_fn > 0,
                total_tp / (total_tp + total_fn),
                jnp.float32(0.0)
            )
        if self.average == 'weighted':
            # Support-weighted mean of the per-class recalls.
            total_support = jnp.sum(support)
            return jnp.where(
                total_support > 0,
                jnp.sum(recall * support) / total_support,
                jnp.float32(0.0)
            )
        return recall
|
752
|
+
|
753
|
+
|
754
|
+
class F1ScoreMetric(Metric):
    """
    F1 score metric for binary and multi-class classification.

    F1 score is the harmonic mean of precision and recall:
    F1 = 2 * (precision * recall) / (precision + recall)

    Parameters
    ----------
    num_classes : int, optional
        Number of classes. If None, assumes binary classification. Default is None.
    average : str, optional
        Type of averaging for multi-class: 'micro', 'macro', or 'weighted'.
        Default is 'macro'. Ignored for binary classification.

        - ``'macro'``: harmonic mean of the macro-averaged precision and the
          macro-averaged recall.
        - ``'micro'``: harmonic mean of the micro-averaged (global) precision
          and recall.
        - ``'weighted'``: mean of the per-class F1 scores weighted by each
          class's support (``TP + FN``, the number of true samples).
        - Any other value (e.g. ``None``) returns the per-class F1 scores.

    Attributes
    ----------
    precision_metric : PrecisionMetric
        Internal precision metric.
    recall_metric : RecallMetric
        Internal recall metric.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> predictions = jnp.array([1, 0, 1, 1, 0])
        >>> labels = jnp.array([1, 0, 0, 1, 0])
        >>> metric = brainstate.nn.F1ScoreMetric()
        >>> metric.update(predictions=predictions, labels=labels)
        >>> metric.compute()
        Array(0.8, dtype=float32)

    Notes
    -----
    The F1 score balances precision and recall, providing a single metric that
    considers both false positives and false negatives.
    """
    __module__ = "brainstate.nn"

    def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro'):
        self.precision_metric = PrecisionMetric(num_classes, average)
        self.recall_metric = RecallMetric(num_classes, average)

    def reset(self) -> None:
        """Reset the metric state to zero."""
        self.precision_metric.reset()
        self.recall_metric.reset()

    def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the F1 score metric.

        Parameters
        ----------
        predictions : jax.Array
            Predicted class labels (integers).
        labels : jax.Array
            Ground truth class labels (integers).
        **_
            Additional keyword arguments are ignored.
        """
        self.precision_metric.update(predictions=predictions, labels=labels)
        self.recall_metric.update(predictions=predictions, labels=labels)

    def compute(self) -> jax.Array:
        """
        Compute and return the F1 score.

        Returns
        -------
        jax.Array
            The F1 score value(s). Returns 0 when both precision and recall are 0.
            A scalar for binary classification and for the ``'macro'``,
            ``'micro'`` and ``'weighted'`` averages; per-class scores otherwise.
        """
        if (self.precision_metric.num_classes is not None
                and self.precision_metric.average == 'weighted'):
            # Built from the raw counts so the result does not depend on how
            # the sub-metrics' own ``compute`` handles 'weighted'.
            return self._weighted_f1()
        precision = self.precision_metric.compute()
        recall = self.recall_metric.compute()
        return self._harmonic_mean(precision, recall)

    def _weighted_f1(self) -> jax.Array:
        """Support-weighted mean of the per-class F1 scores."""
        tp = self.precision_metric.true_positives.value
        fp = self.precision_metric.false_positives.value
        fn = self.recall_metric.false_negatives.value
        per_class_f1 = self._harmonic_mean(
            self._safe_ratio(tp, tp + fp),
            self._safe_ratio(tp, tp + fn),
        )
        support = tp + fn
        total_support = jnp.sum(support)
        return jnp.where(
            total_support > 0,
            jnp.sum(per_class_f1 * support) / total_support,
            jnp.float32(0.0)
        )

    @staticmethod
    def _safe_ratio(numerator: jax.Array, denominator: jax.Array) -> jax.Array:
        """Element-wise ``numerator / denominator`` with 0 where the denominator is 0."""
        return jnp.where(
            denominator > 0,
            numerator / denominator,
            jnp.zeros_like(denominator, dtype=jnp.float32)
        )

    @staticmethod
    def _harmonic_mean(precision: jax.Array, recall: jax.Array) -> jax.Array:
        """F1 from precision and recall; 0 where both are 0."""
        denominator = precision + recall
        return jnp.where(
            denominator > 0,
            2 * precision * recall / denominator,
            jnp.float32(0.0)
        )
|
838
|
+
|
839
|
+
|
840
|
+
class ConfusionMatrix(Metric):
    """
    Confusion matrix metric for multi-class classification.

    A confusion matrix shows the counts of predicted vs. actual class labels,
    where rows represent true labels and columns represent predicted labels.

    Parameters
    ----------
    num_classes : int
        Number of classes in the classification task.

    Attributes
    ----------
    num_classes : int
        Number of classes.
    matrix : MetricState
        The confusion matrix of shape (num_classes, num_classes).

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> predictions = jnp.array([0, 1, 2, 1, 0])
        >>> labels = jnp.array([0, 1, 1, 1, 2])
        >>> metric = brainstate.nn.ConfusionMatrix(num_classes=3)
        >>> metric.update(predictions=predictions, labels=labels)
        >>> metric.compute()
        Array([[1, 0, 0],
               [0, 2, 1],
               [1, 0, 0]], dtype=int32)

    Notes
    -----
    The confusion matrix is useful for understanding which classes are being
    confused with each other and for computing class-specific metrics.
    """
    __module__ = "brainstate.nn"

    def __init__(self, num_classes: int):
        self.num_classes = num_classes
        self.matrix = MetricState(jnp.zeros((num_classes, num_classes), dtype=jnp.int32))

    def reset(self) -> None:
        """Reset the confusion matrix to zeros."""
        self.matrix.value = jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)

    def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the confusion matrix.

        Parameters
        ----------
        predictions : jax.Array
            Predicted class labels (integers) with shape (batch_size,).
            Inputs are cast to int32 and flattened.
        labels : jax.Array
            Ground truth class labels (integers) with shape (batch_size,).
            Inputs are cast to int32 and flattened.
        **_
            Additional keyword arguments are ignored.

        Raises
        ------
        ValueError
            If predictions or labels contain values outside [0, num_classes).
        """
        predictions = jnp.asarray(predictions, dtype=jnp.int32).flatten()
        labels = jnp.asarray(labels, dtype=jnp.int32).flatten()

        # NOTE: these checks turn a JAX array into a Python bool, so they need
        # concrete (non-traced) inputs; ``update`` cannot run inside jax.jit.
        if jnp.any((predictions < 0) | (predictions >= self.num_classes)):
            raise ValueError(f"Predictions contain values outside [0, {self.num_classes})")
        if jnp.any((labels < 0) | (labels >= self.num_classes)):
            raise ValueError(f"Labels contain values outside [0, {self.num_classes})")

        # One scatter-add: each (true, predicted) pair increments its cell.
        # ``.at[...].add`` accumulates repeated index pairs, so this yields the
        # same counts as per-cell counting with O(batch) instead of
        # O(batch * num_classes**2) work.
        self.matrix.value = self.matrix.value.at[labels, predictions].add(1)

    def compute(self) -> jax.Array:
        """
        Compute and return the confusion matrix.

        Returns
        -------
        jax.Array
            The confusion matrix of shape (num_classes, num_classes).
            Element [i, j] represents the count of samples with true label i
            that were predicted as label j.
        """
        return self.matrix.value
|
932
|
+
|
933
|
+
|
934
|
+
class MultiMetric(Metric):
    """
    Container for multiple metrics updated simultaneously.

    This class allows you to group multiple metrics together and update them
    all with a single call. It's useful for tracking multiple evaluation metrics
    (e.g., accuracy, loss, F1 score) during training or evaluation.

    Parameters
    ----------
    **metrics
        Keyword arguments where keys are metric names (strings) and values
        are Metric instances.

    Attributes
    ----------
    _metric_names : list of str
        List of metric names in the order they were added.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax, jax.numpy as jnp
        >>> metrics = brainstate.nn.MultiMetric(
        ...     accuracy=brainstate.nn.AccuracyMetric(),
        ...     loss=brainstate.nn.AverageMetric(),
        ... )
        >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
        >>> labels = jnp.array([1, 1, 0, 1, 0])
        >>> batch_loss = jnp.array([1, 2, 3, 4])
        >>> metrics.compute()
        {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
        >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
        >>> metrics.compute()
        {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
        >>> metrics.reset()
        >>> metrics.compute()
        {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}

    Notes
    -----
    All keyword arguments passed to ``update`` are forwarded to all underlying metrics.
    Each metric will extract the arguments it needs based on its implementation.

    Reserved names ('reset', 'update', 'compute', '_metric_names') cannot be used
    as metric names.
    """
    __module__ = "brainstate.nn"

    def __init__(self, **metrics):
        # Names that would shadow this container's methods, or overwrite the
        # internal name list that reset/update/compute iterate over.
        reserved_names = {'reset', 'update', 'compute', '_metric_names'}
        for metric_name in metrics.keys():
            if metric_name in reserved_names:
                raise ValueError(
                    f"Metric name '{metric_name}' is reserved for class methods and attributes. "
                    f"Please use a different name. Reserved names: {reserved_names}"
                )

        self._metric_names = []
        for metric_name, metric in metrics.items():
            if not isinstance(metric, Metric):
                raise TypeError(
                    f"All metrics must be instances of Metric, got {type(metric)} "
                    f"for metric '{metric_name}'"
                )
            self._metric_names.append(metric_name)
            # Stored directly in the instance __dict__, bypassing __setattr__.
            vars(self)[metric_name] = metric

    def reset(self) -> None:
        """
        Reset all underlying metrics.

        This calls the ``reset`` method on each metric in the collection.
        """
        for metric_name in self._metric_names:
            getattr(self, metric_name).reset()

    def update(self, **updates) -> None:
        """
        Update all underlying metrics.

        All keyword arguments are passed to the ``update`` method of each metric.
        Individual metrics will extract the arguments they need.

        Parameters
        ----------
        **updates
            Keyword arguments to be passed to all underlying metrics.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate
            >>> metrics = brainstate.nn.MultiMetric(
            ...     accuracy=brainstate.nn.AccuracyMetric(),
            ...     loss=brainstate.nn.AverageMetric('loss_value'),
            ... )
            >>> logits = jnp.array([[0.2, 0.8], [0.9, 0.1]])
            >>> labels = jnp.array([1, 0])
            >>> loss = jnp.array([0.5, 0.3])
            >>> metrics.update(logits=logits, labels=labels, loss_value=loss)
        """
        for metric_name in self._metric_names:
            getattr(self, metric_name).update(**updates)

    def compute(self) -> dict[str, tp.Any]:
        """
        Compute and return all metric values.

        Returns
        -------
        dict[str, Any]
            Dictionary mapping metric names to their computed values, in the
            order the metrics were added.
            The value type depends on the specific metric implementation.

        Examples
        --------
        .. code-block:: python

            >>> import brainstate
            >>> metrics = brainstate.nn.MultiMetric(
            ...     loss=brainstate.nn.AverageMetric(),
            ...     stats=brainstate.nn.WelfordMetric(),
            ... )
            >>> # After updates...
            >>> results = metrics.compute()
            >>> results['loss']  # Returns a scalar
            >>> results['stats']  # Returns a Statistics object
        """
        return {
            metric_name: getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }
|