brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1070 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import typing as tp
18
+ from dataclasses import dataclass
19
+ from functools import partial
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+
25
+ from brainstate._state import LongTermState
26
+
27
# Public API of this module.
# NOTE(review): the docstring examples in this file reference
# ``brainstate.nn.Statistics``, but ``Statistics`` is not listed here —
# confirm whether it should be exported as well.
__all__ = [
    'MetricState',
    'Metric',
    'AverageMetric',
    'WelfordMetric',
    'AccuracyMetric',
    'MultiMetric',
    'PrecisionMetric',
    'RecallMetric',
    'F1ScoreMetric',
    'ConfusionMatrix',
]
39
+
40
+
41
class MetricState(LongTermState):
    """
    Container state for metric variables.

    A thin subclass of ``LongTermState`` that holds the internal variables
    of a metric (counts, running sums, ...) so they can be tracked and
    updated during training or evaluation.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> state = brainstate.nn.MetricState(jnp.array(0.0))
        >>> state.value
        Array(0., dtype=float32)
        >>> state.value = jnp.array(1.5)
        >>> state.value
        Array(1.5, dtype=float32)
    """
    __module__ = "brainstate.nn"
62
+
63
+
64
class Metric(object):
    """
    Abstract base class for evaluation metrics.

    Concrete metrics subclass ``Metric`` and provide three operations:
    ``reset`` to restore the initial state, ``update`` to fold new data
    into the state, and ``compute`` to read out the current value.

    Notes
    -----
    This class is not meant to be instantiated directly: each of the
    three methods raises ``NotImplementedError`` until overridden.
    """
    __module__ = "brainstate.nn"

    def reset(self) -> None:
        """
        In-place reset of the metric state to its initial values.

        Subclasses restore all internal state variables to the values
        they had at construction time.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError('Must override `reset()` method.')

    def update(self, **kwargs) -> None:
        """
        In-place update of the metric with new data.

        Parameters
        ----------
        **kwargs
            Metric-specific keyword arguments carrying the new data.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError('Must override `update()` method.')

    def compute(self):
        """
        Return the current value of the metric.

        Returns
        -------
        metric_value
            The computed metric value; the type depends on the subclass.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError('Must override `compute()` method.')
133
+
134
+
135
class AverageMetric(Metric):
    """
    Running average over all values fed to :meth:`update`.

    Maintains a cumulative sum (``total``) and an element count
    (``count``); :meth:`compute` returns their ratio.

    Parameters
    ----------
    argname : str, optional
        Keyword argument name that ``update`` reads its values from.
        Defaults to ``'values'``.

    Attributes
    ----------
    argname : str
        The keyword argument name for updates.
    total : MetricState
        Cumulative sum of all values.
    count : MetricState
        Total number of elements processed.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> metrics = brainstate.nn.AverageMetric()
        >>> metrics.compute()
        Array(nan, dtype=float32)
        >>> metrics.update(values=jnp.array([1, 2, 3, 4]))
        >>> metrics.compute()
        Array(2.5, dtype=float32)
        >>> metrics.update(values=jnp.array([3, 2, 1, 0]))
        >>> metrics.compute()
        Array(2., dtype=float32)
        >>> metrics.reset()
        >>> metrics.compute()
        Array(nan, dtype=float32)

    Notes
    -----
    With no data added yet (count = 0), ``compute`` returns NaN.
    Scalars (``int``/``float``) and arrays of any shape are accepted.
    """
    __module__ = "brainstate.nn"

    def __init__(self, argname: str = 'values'):
        self.argname = argname
        self.total = MetricState(jnp.array(0, dtype=jnp.float32))
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self) -> None:
        """Zero out both the running sum and the element count."""
        self.total.value = jnp.array(0, dtype=jnp.float32)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def update(self, **kwargs) -> None:
        """
        Fold new values into the running sum and count.

        Parameters
        ----------
        **kwargs
            Must contain ``self.argname`` as a key, mapping to a scalar
            or an array of values to be averaged.

        Raises
        ------
        TypeError
            If the expected keyword argument is not provided.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate
            >>> metric = brainstate.nn.AverageMetric('loss')
            >>> metric.update(loss=jnp.array([1.0, 2.0, 3.0]))
            >>> metric.compute()
            Array(2., dtype=float32)
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
        # A Python scalar counts as one element; an array contributes
        # its sum and its total element count.
        if isinstance(values, (int, float)):
            increment, n = values, 1
        else:
            increment, n = values.sum(), values.size
        self.total.value += increment
        self.count.value += n

    def compute(self) -> jax.Array:
        """
        Return the average of everything passed to ``update`` so far.

        Returns
        -------
        jax.Array
            ``total / count``; NaN when no values have been added.
        """
        return self.total.value / self.count.value
244
+
245
+
246
@partial(
    jax.tree_util.register_dataclass,
    data_fields=['mean', 'standard_error_of_mean', 'standard_deviation'],
    meta_fields=[]
)
@dataclass
class Statistics:
    """
    Dataclass bundling summary statistics (as produced by ``WelfordMetric.compute``).

    Registered as a JAX pytree with all three fields as data fields, so
    instances can flow through ``jax.tree_util`` / transformation machinery.

    Attributes
    ----------
    mean : jnp.float32
        The mean value.
    standard_error_of_mean : jnp.float32
        The standard error of the mean (SEM).
    standard_deviation : jnp.float32
        The standard deviation.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> stats = brainstate.nn.Statistics(
        ...     mean=jnp.float32(2.5),
        ...     standard_error_of_mean=jnp.float32(0.5),
        ...     standard_deviation=jnp.float32(1.0)
        ... )
        >>> stats.mean
        Array(2.5, dtype=float32)
    """
    __module__ = "brainstate.nn"
    # Field order defines both the positional __init__ signature and the
    # pytree flattening order — do not reorder.
    mean: jnp.float32
    standard_error_of_mean: jnp.float32
    standard_deviation: jnp.float32
283
+
284
+
285
class WelfordMetric(Metric):
    """
    Welford's algorithm for computing mean and variance of streaming data.

    Uses Welford's online algorithm to compute running statistics (mean,
    variance, standard deviation) in a numerically stable, single-pass way.

    Parameters
    ----------
    argname : str, optional
        The keyword argument name that ``update`` will use to derive the new
        value. Defaults to ``'values'``.

    Attributes
    ----------
    argname : str
        The keyword argument name for updates.
    count : MetricState
        Total number of elements processed.
    mean : MetricState
        Running mean estimate.
    m2 : MetricState
        Running sum of squared deviations from the mean.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> metrics = brainstate.nn.WelfordMetric()
        >>> metrics.update(values=jnp.array([1, 2, 3, 4]))
        >>> metrics.compute()
        Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
        >>> metrics.update(values=jnp.array([3, 2, 1, 0]))
        >>> metrics.compute()
        Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))

    Notes
    -----
    Welford's algorithm updates the mean and variance incrementally as new
    data arrives. ``compute`` returns NaN for the error statistics while
    ``count`` is 0.

    References
    ----------
    .. [1] Welford, B. P. (1962). "Note on a method for calculating corrected sums
        of squares and products". Technometrics. 4 (3): 419-420.
    """
    __module__ = "brainstate.nn"

    def __init__(self, argname: str = 'values'):
        self.argname = argname
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))
        self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
        self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))

    def reset(self) -> None:
        """
        Reset count, mean, and the sum of squared deviations (m2) to zero.
        """
        # Fix: reset count with int32 to match __init__. Previously this used
        # uint32, silently changing the state dtype after the first reset —
        # a hazard for type promotion and re-tracing under jit.
        self.count.value = jnp.array(0, dtype=jnp.int32)
        self.mean.value = jnp.array(0, dtype=jnp.float32)
        self.m2.value = jnp.array(0, dtype=jnp.float32)

    def update(self, **kwargs) -> None:
        """
        Fold a new batch of values into the running statistics.

        Parameters
        ----------
        **kwargs
            Must contain ``self.argname`` as a key, mapping to a scalar
            (``int``/``float``) or an array of values.

        Raises
        ------
        TypeError
            If the expected keyword argument is not provided.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate
            >>> metric = brainstate.nn.WelfordMetric('data')
            >>> metric.update(data=jnp.array([1.0, 2.0, 3.0]))
            >>> stats = metric.compute()
            >>> stats.mean
            Array(2., dtype=float32)
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
        # Batch merge form of Welford's update: combine the batch's
        # (count, mean, m2) into the running aggregates.
        count = 1 if isinstance(values, (int, float)) else values.size
        original_count = self.count.value
        self.count.value += count
        delta = (
            values if isinstance(values, (int, float)) else values.mean()
        ) - self.mean.value
        self.mean.value += delta * count / self.count.value
        m2 = 0.0 if isinstance(values, (int, float)) else values.var() * count
        self.m2.value += (
            m2 + delta * delta * count * original_count / self.count.value
        )

    def compute(self) -> Statistics:
        """
        Compute and return the statistics accumulated so far.

        Returns
        -------
        Statistics
            Mean, standard error of the mean, and standard deviation.
            The error statistics are NaN when ``count`` is 0.
        """
        variance = self.m2.value / self.count.value
        standard_deviation = variance ** 0.5
        sem = standard_deviation / (self.count.value ** 0.5)
        return Statistics(
            mean=self.mean.value,
            standard_error_of_mean=sem,
            standard_deviation=standard_deviation,
        )
418
+
419
+
420
class AccuracyMetric(AverageMetric):
    """
    Classification accuracy computed as a running average.

    Predictions are derived from ``logits`` with ``argmax`` over the last
    axis and compared against integer ``labels``; the fraction of matches
    is accumulated via the inherited ``AverageMetric`` machinery, so
    ``reset`` and ``compute`` behave exactly as in the parent class.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax, jax.numpy as jnp
        >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
        >>> labels = jnp.array([1, 1, 0, 1, 0])
        >>> metrics = brainstate.nn.AccuracyMetric()
        >>> metrics.compute()
        Array(nan, dtype=float32)
        >>> metrics.update(logits=logits, labels=labels)
        >>> metrics.compute()
        Array(0.6, dtype=float32)
        >>> metrics.reset()
        >>> metrics.compute()
        Array(nan, dtype=float32)

    Notes
    -----
    accuracy = (number of correct predictions) / (total predictions).
    NaN is returned before any update (count = 0).
    """
    __module__ = "brainstate.nn"

    def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None:
        """
        Accumulate the correctness of a batch of predictions.

        Parameters
        ----------
        logits : jax.Array
            Class scores with shape (..., num_classes); exactly one more
            dimension than ``labels``.
        labels : jax.Array
            Ground-truth integer labels with shape (...,). 64-bit (and
            ``np.int32``) labels are cast to ``jnp.int32``; any other
            dtype besides ``jnp.int32`` is rejected.
        **_
            Additional keyword arguments are ignored.

        Raises
        ------
        ValueError
            If the shapes are incompatible or the label dtype is not an
            accepted integer type.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate
            >>> logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
            >>> labels = jnp.array([1, 0, 1])
            >>> metric = brainstate.nn.AccuracyMetric()
            >>> metric.update(logits=logits, labels=labels)
            >>> metric.compute()
            Array(1., dtype=float32)
        """
        if logits.ndim != labels.ndim + 1:
            raise ValueError(
                f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
            )
        if labels.dtype in (jnp.int64, np.int32, np.int64):
            labels = jnp.astype(labels, jnp.int32)
        elif labels.dtype != jnp.int32:
            raise ValueError(f'Expected labels.dtype==jnp.int32, got {labels.dtype}')

        predicted = logits.argmax(axis=-1)
        super().update(values=(predicted == labels))
504
+
505
+
506
class PrecisionMetric(Metric):
    """
    Precision metric for binary and multi-class classification.

    Precision is the ratio of true positives to all positive predictions:
    precision = TP / (TP + FP)

    Parameters
    ----------
    num_classes : int, optional
        Number of classes. If None, assumes binary classification. Default is None.
    average : str, optional
        Type of averaging for multi-class: 'micro', 'macro', or 'weighted'.
        Default is 'macro'. Ignored for binary classification.

    Attributes
    ----------
    num_classes : int or None
        Number of classes.
    average : str
        Averaging method for multi-class.
    true_positives : MetricState
        Count of true positive predictions (per class when ``num_classes`` is set).
    false_positives : MetricState
        Count of false positive predictions (per class when ``num_classes`` is set).

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> predictions = jnp.array([1, 0, 1, 1, 0])
        >>> labels = jnp.array([1, 0, 0, 1, 0])
        >>> metric = brainstate.nn.PrecisionMetric()
        >>> metric.update(predictions=predictions, labels=labels)
        >>> metric.compute()
        Array(0.6666667, dtype=float32)

    Notes
    -----
    Multi-class averaging strategies:
    - 'micro': calculate globally from total TP and FP counts
    - 'macro': per-class precision, unweighted mean
    - 'weighted': accepted by the constructor but has no branch in
      ``compute`` — it currently falls back to returning the per-class
      precision vector (NOTE(review): confirm intended behavior).
    """
    __module__ = "brainstate.nn"

    def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro'):
        self.num_classes = num_classes
        self.average = average
        if num_classes is None:
            # Binary: scalar TP/FP counters.
            self.true_positives = MetricState(jnp.array(0, dtype=jnp.int32))
            self.false_positives = MetricState(jnp.array(0, dtype=jnp.int32))
        else:
            # Multi-class: one TP/FP counter per class.
            self.true_positives = MetricState(jnp.zeros(num_classes, dtype=jnp.int32))
            self.false_positives = MetricState(jnp.zeros(num_classes, dtype=jnp.int32))

    def reset(self) -> None:
        """Reset the TP/FP counters to zero."""
        if self.num_classes is None:
            self.true_positives.value = jnp.array(0, dtype=jnp.int32)
            self.false_positives.value = jnp.array(0, dtype=jnp.int32)
        else:
            self.true_positives.value = jnp.zeros(self.num_classes, dtype=jnp.int32)
            self.false_positives.value = jnp.zeros(self.num_classes, dtype=jnp.int32)

    def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the precision metric.

        Parameters
        ----------
        predictions : jax.Array
            Predicted class labels (integers).
        labels : jax.Array
            Ground truth class labels (integers).
        **_
            Additional keyword arguments are ignored.
        """
        if self.num_classes is None:
            # Binary classification: the positive class is 1.
            self.true_positives.value += jnp.sum((predictions == 1) & (labels == 1))
            self.false_positives.value += jnp.sum((predictions == 1) & (labels == 0))
        else:
            # Multi-class, vectorized over classes: one (n, C) boolean
            # comparison replaces the previous Python loop, which unrolled
            # into `num_classes` separate scatter-adds when traced/jitted.
            classes = jnp.arange(self.num_classes)
            pred_is_c = predictions.reshape(-1, 1) == classes   # (n, C)
            label_is_c = labels.reshape(-1, 1) == classes       # (n, C)
            self.true_positives.value += jnp.sum(pred_is_c & label_is_c, axis=0)
            self.false_positives.value += jnp.sum(pred_is_c & ~label_is_c, axis=0)

    def compute(self) -> jax.Array:
        """
        Compute and return the precision.

        Returns
        -------
        jax.Array
            A scalar for binary classification and for 'macro'/'micro'
            averaging; otherwise the per-class precision vector.
        """
        denominator = self.true_positives.value + self.false_positives.value
        # Guard against 0/0: entries with no positive predictions yield 0.0.
        precision = jnp.where(
            denominator > 0,
            self.true_positives.value / denominator,
            jnp.zeros_like(denominator, dtype=jnp.float32)
        )

        if self.num_classes is not None and self.average == 'macro':
            return jnp.mean(precision)
        elif self.num_classes is not None and self.average == 'micro':
            total_tp = jnp.sum(self.true_positives.value)
            total_fp = jnp.sum(self.false_positives.value)
            return jnp.where(
                total_tp + total_fp > 0,
                total_tp / (total_tp + total_fp),
                jnp.float32(0.0)
            )
        # NOTE(review): 'weighted' has no branch here and falls through to the
        # per-class vector.
        return precision
629
+
630
+
631
class RecallMetric(Metric):
    """
    Recall (sensitivity) metric for binary and multi-class classification.

    Recall is the ratio of true positives to all actual positives:
    recall = TP / (TP + FN)

    Parameters
    ----------
    num_classes : int, optional
        Number of classes. If None, assumes binary classification. Default is None.
    average : str, optional
        Type of averaging for multi-class: 'micro', 'macro', or 'weighted'.
        Default is 'macro'. Ignored for binary classification.

    Attributes
    ----------
    num_classes : int or None
        Number of classes.
    average : str
        Averaging method for multi-class.
    true_positives : MetricState
        Count of true positive predictions (per class when ``num_classes`` is set).
    false_negatives : MetricState
        Count of false negative predictions (per class when ``num_classes`` is set).

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> predictions = jnp.array([1, 0, 1, 1, 0])
        >>> labels = jnp.array([1, 0, 0, 1, 0])
        >>> metric = brainstate.nn.RecallMetric()
        >>> metric.update(predictions=predictions, labels=labels)
        >>> metric.compute()
        Array(1., dtype=float32)

    Notes
    -----
    Recall measures the fraction of actual positive cases that were correctly
    identified; it is also known as sensitivity or true positive rate (TPR).
    As with :class:`PrecisionMetric`, 'weighted' averaging is accepted by the
    constructor but has no branch in ``compute`` — it currently falls back to
    the per-class recall vector (NOTE(review): confirm intended behavior).
    """
    __module__ = "brainstate.nn"

    def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro'):
        self.num_classes = num_classes
        self.average = average
        if num_classes is None:
            # Binary: scalar TP/FN counters.
            self.true_positives = MetricState(jnp.array(0, dtype=jnp.int32))
            self.false_negatives = MetricState(jnp.array(0, dtype=jnp.int32))
        else:
            # Multi-class: one TP/FN counter per class.
            self.true_positives = MetricState(jnp.zeros(num_classes, dtype=jnp.int32))
            self.false_negatives = MetricState(jnp.zeros(num_classes, dtype=jnp.int32))

    def reset(self) -> None:
        """Reset the TP/FN counters to zero."""
        if self.num_classes is None:
            self.true_positives.value = jnp.array(0, dtype=jnp.int32)
            self.false_negatives.value = jnp.array(0, dtype=jnp.int32)
        else:
            self.true_positives.value = jnp.zeros(self.num_classes, dtype=jnp.int32)
            self.false_negatives.value = jnp.zeros(self.num_classes, dtype=jnp.int32)

    def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the recall metric.

        Parameters
        ----------
        predictions : jax.Array
            Predicted class labels (integers).
        labels : jax.Array
            Ground truth class labels (integers).
        **_
            Additional keyword arguments are ignored.
        """
        if self.num_classes is None:
            # Binary classification: the positive class is 1.
            self.true_positives.value += jnp.sum((predictions == 1) & (labels == 1))
            self.false_negatives.value += jnp.sum((predictions == 0) & (labels == 1))
        else:
            # Multi-class, vectorized over classes: one (n, C) boolean
            # comparison replaces the previous Python loop, which unrolled
            # into `num_classes` separate scatter-adds when traced/jitted.
            classes = jnp.arange(self.num_classes)
            pred_is_c = predictions.reshape(-1, 1) == classes   # (n, C)
            label_is_c = labels.reshape(-1, 1) == classes       # (n, C)
            self.true_positives.value += jnp.sum(pred_is_c & label_is_c, axis=0)
            self.false_negatives.value += jnp.sum(~pred_is_c & label_is_c, axis=0)

    def compute(self) -> jax.Array:
        """
        Compute and return the recall.

        Returns
        -------
        jax.Array
            A scalar for binary classification and for 'macro'/'micro'
            averaging; otherwise the per-class recall vector.
        """
        denominator = self.true_positives.value + self.false_negatives.value
        # Guard against 0/0: entries with no actual positives yield 0.0.
        recall = jnp.where(
            denominator > 0,
            self.true_positives.value / denominator,
            jnp.zeros_like(denominator, dtype=jnp.float32)
        )

        if self.num_classes is not None and self.average == 'macro':
            return jnp.mean(recall)
        elif self.num_classes is not None and self.average == 'micro':
            total_tp = jnp.sum(self.true_positives.value)
            total_fn = jnp.sum(self.false_negatives.value)
            return jnp.where(
                total_tp + total_fn > 0,
                total_tp / (total_tp + total_fn),
                jnp.float32(0.0)
            )
        # NOTE(review): 'weighted' has no branch here and falls through to the
        # per-class vector.
        return recall
752
+
753
+
754
class F1ScoreMetric(Metric):
    """
    F1 score metric for binary and multi-class classification.

    The F1 score is the harmonic mean of precision and recall::

        F1 = 2 * (precision * recall) / (precision + recall)

    Parameters
    ----------
    num_classes : int, optional
        Number of classes. If None, assumes binary classification. Default is None.
    average : str, optional
        Type of averaging for multi-class: 'micro', 'macro', or 'weighted'.
        Default is 'macro'. Ignored for binary classification.

    Attributes
    ----------
    precision_metric : PrecisionMetric
        Internal precision metric.
    recall_metric : RecallMetric
        Internal recall metric.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> predictions = jnp.array([1, 0, 1, 1, 0])
        >>> labels = jnp.array([1, 0, 0, 1, 0])
        >>> metric = brainstate.nn.F1ScoreMetric()
        >>> metric.update(predictions=predictions, labels=labels)
        >>> metric.compute()
        Array(0.8, dtype=float32)

    Notes
    -----
    The F1 score balances precision and recall, providing a single number
    that accounts for both false positives and false negatives.
    """
    __module__ = "brainstate.nn"

    def __init__(self, num_classes: tp.Optional[int] = None, average: str = 'macro'):
        # All bookkeeping is delegated to the two constituent metrics.
        self.precision_metric = PrecisionMetric(num_classes, average)
        self.recall_metric = RecallMetric(num_classes, average)

    def reset(self) -> None:
        """Reset the metric state to zero."""
        self.precision_metric.reset()
        self.recall_metric.reset()

    def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the F1 score metric.

        Parameters
        ----------
        predictions : jax.Array
            Predicted class labels (integers).
        labels : jax.Array
            Ground truth class labels (integers).
        **_
            Additional keyword arguments are ignored.
        """
        for component in (self.precision_metric, self.recall_metric):
            component.update(predictions=predictions, labels=labels)

    def compute(self) -> jax.Array:
        """
        Compute and return the F1 score.

        Returns
        -------
        jax.Array
            The F1 score value(s). Returns 0 when both precision and recall are 0.
        """
        p = self.precision_metric.compute()
        r = self.recall_metric.compute()
        total = p + r
        # Guard the harmonic mean against 0/0 when both components vanish.
        return jnp.where(total > 0, 2 * p * r / total, jnp.float32(0.0))
838
+
839
+
840
class ConfusionMatrix(Metric):
    """
    Confusion matrix metric for multi-class classification.

    A confusion matrix shows the counts of predicted vs. actual class labels,
    where rows represent true labels and columns represent predicted labels.

    Parameters
    ----------
    num_classes : int
        Number of classes in the classification task.

    Attributes
    ----------
    num_classes : int
        Number of classes.
    matrix : MetricState
        The confusion matrix of shape (num_classes, num_classes).

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> predictions = jnp.array([0, 1, 2, 1, 0])
        >>> labels = jnp.array([0, 1, 1, 1, 2])
        >>> metric = brainstate.nn.ConfusionMatrix(num_classes=3)
        >>> metric.update(predictions=predictions, labels=labels)
        >>> metric.compute()
        Array([[1, 0, 1],
               [0, 2, 0],
               [1, 0, 0]], dtype=int32)

    Notes
    -----
    The confusion matrix is useful for understanding which classes are being
    confused with each other and for computing class-specific metrics.
    """
    __module__ = "brainstate.nn"

    def __init__(self, num_classes: int):
        self.num_classes = num_classes
        self.matrix = MetricState(jnp.zeros((num_classes, num_classes), dtype=jnp.int32))

    def reset(self) -> None:
        """Reset the confusion matrix to zeros."""
        self.matrix.value = jnp.zeros((self.num_classes, self.num_classes), dtype=jnp.int32)

    def update(self, *, predictions: jax.Array, labels: jax.Array, **_) -> None:
        """
        Update the confusion matrix.

        Parameters
        ----------
        predictions : jax.Array
            Predicted class labels (integers) with shape (batch_size,).
        labels : jax.Array
            Ground truth class labels (integers) with shape (batch_size,).
        **_
            Additional keyword arguments are ignored.

        Raises
        ------
        ValueError
            If predictions or labels contain values outside [0, num_classes).
        """
        predictions = jnp.asarray(predictions, dtype=jnp.int32).flatten()
        labels = jnp.asarray(labels, dtype=jnp.int32).flatten()

        # NOTE(review): these eager range checks force a host sync and are
        # incompatible with jax.jit tracing; kept for parity with the documented API.
        if jnp.any((predictions < 0) | (predictions >= self.num_classes)):
            raise ValueError(f"Predictions contain values outside [0, {self.num_classes})")
        if jnp.any((labels < 0) | (labels >= self.num_classes)):
            raise ValueError(f"Labels contain values outside [0, {self.num_classes})")

        # One scatter-add replaces the previous O(num_classes**2) Python loop:
        # each (label, prediction) pair increments exactly one matrix cell,
        # and duplicate pairs accumulate.
        self.matrix.value = self.matrix.value.at[labels, predictions].add(1)

    def compute(self) -> jax.Array:
        """
        Compute and return the confusion matrix.

        Returns
        -------
        jax.Array
            The confusion matrix of shape (num_classes, num_classes).
            Element [i, j] represents the count of samples with true label i
            that were predicted as label j.
        """
        return self.matrix.value
932
+
933
+
934
class MultiMetric(Metric):
    """
    Container that groups several metrics and updates them together.

    Grouping metrics lets a training or evaluation loop feed one batch of
    keyword arguments to every metric at once (e.g. accuracy, loss, F1 score)
    with a single ``update`` call.

    Parameters
    ----------
    **metrics
        Keyword arguments where keys are metric names (strings) and values
        are Metric instances.

    Attributes
    ----------
    _metric_names : list of str
        List of metric names in the order they were added.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax, jax.numpy as jnp
        >>> metrics = brainstate.nn.MultiMetric(
        ...     accuracy=brainstate.nn.AccuracyMetric(),
        ...     loss=brainstate.nn.AverageMetric(),
        ... )
        >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
        >>> labels = jnp.array([1, 1, 0, 1, 0])
        >>> batch_loss = jnp.array([1, 2, 3, 4])
        >>> metrics.compute()
        {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
        >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
        >>> metrics.compute()
        {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
        >>> metrics.reset()
        >>> metrics.compute()
        {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}

    Notes
    -----
    Every keyword argument passed to ``update`` is forwarded to all contained
    metrics; each metric extracts the arguments it understands.

    Reserved method names ('reset', 'update', 'compute') cannot be used as
    metric names.
    """
    __module__ = "brainstate.nn"

    def __init__(self, **metrics):
        # Metric names become instance attributes, so they must not shadow
        # the public API of this class.
        reserved_names = {'reset', 'update', 'compute'}
        for metric_name in metrics:
            if metric_name in reserved_names:
                raise ValueError(
                    f"Metric name '{metric_name}' is reserved for class methods. "
                    f"Please use a different name. Reserved names: {reserved_names}"
                )

        self._metric_names = []
        for metric_name, metric in metrics.items():
            if not isinstance(metric, Metric):
                raise TypeError(
                    f"All metrics must be instances of Metric, got {type(metric)} "
                    f"for metric '{metric_name}'"
                )
            self._metric_names.append(metric_name)
            # Write straight into the instance __dict__, preserving the
            # original attribute-registration behaviour.
            vars(self)[metric_name] = metric

    def _members(self):
        """Yield the contained metrics in registration order."""
        for metric_name in self._metric_names:
            yield getattr(self, metric_name)

    def reset(self) -> None:
        """
        Reset all underlying metrics.

        This calls the ``reset`` method on each metric in the collection.
        """
        for metric in self._members():
            metric.reset()

    def update(self, **updates) -> None:
        """
        Update all underlying metrics.

        All keyword arguments are passed to the ``update`` method of each
        metric; individual metrics extract the arguments they need.

        Parameters
        ----------
        **updates
            Keyword arguments forwarded to every underlying metric.

        Examples
        --------
        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate
            >>> metrics = brainstate.nn.MultiMetric(
            ...     accuracy=brainstate.nn.AccuracyMetric(),
            ...     loss=brainstate.nn.AverageMetric('loss_value'),
            ... )
            >>> logits = jnp.array([[0.2, 0.8], [0.9, 0.1]])
            >>> labels = jnp.array([1, 0])
            >>> loss = jnp.array([0.5, 0.3])
            >>> metrics.update(logits=logits, labels=labels, loss_value=loss)
        """
        for metric in self._members():
            metric.update(**updates)

    def compute(self) -> dict[str, tp.Any]:
        """
        Compute and return all metric values.

        Returns
        -------
        dict[str, Any]
            Dictionary mapping metric names to their computed values.
            The value type depends on the specific metric implementation.
        """
        results = {}
        for metric_name in self._metric_names:
            results[metric_name] = getattr(self, metric_name).compute()
        return results