brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/nn/metrics.py
CHANGED
@@ -1,388 +1,388 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import typing as tp
|
18
|
-
from dataclasses import dataclass
|
19
|
-
from functools import partial
|
20
|
-
|
21
|
-
import jax
|
22
|
-
import jax.numpy as jnp
|
23
|
-
import numpy as np
|
24
|
-
|
25
|
-
from brainstate._state import State
|
26
|
-
|
27
|
-
# Public API of this module; these are the names bound by ``from ... import *``.
__all__ = ['Average', 'Statistics', 'Welford', 'Accuracy', 'MultiMetric']
|
34
|
-
|
35
|
-
|
36
|
-
class MetricState(State):
    """``State`` subclass that holds the mutable variables of a metric.

    It adds no behaviour of its own on top of ``State``; having a dedicated
    subclass simply makes metric variables distinguishable by type.
    """
|
39
|
-
|
40
|
-
|
41
|
-
class Metric:
    """Common base class of every metric in this module.

    Concrete metrics are expected to override :meth:`reset`, :meth:`update`
    and :meth:`compute`; the versions defined here only raise
    ``NotImplementedError``.
    """

    def reset(self) -> None:
        """Return the metric to its initial (empty) state, in place."""
        message = 'Must override `reset()` method.'
        raise NotImplementedError(message)

    def update(self, **kwargs) -> None:
        """Fold new observations (passed as keyword arguments) into the metric, in place."""
        message = 'Must override `update()` method.'
        raise NotImplementedError(message)

    def compute(self):
        """Return the current value of the metric."""
        message = 'Must override `compute()` method.'
        raise NotImplementedError(message)
|
56
|
-
|
57
|
-
|
58
|
-
class Average(Metric):
    """Running average of every value observed since the last reset.

    Each :meth:`update` adds a batch (or a single Python number) to a running
    float32 sum and an int32 element count; :meth:`compute` divides the two.
    With no observations the count is zero, so :meth:`compute` yields ``nan``.

    Example usage::

      >>> import jax.numpy as jnp
      >>> import brainstate as brainstate

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics = brainstate.nn.metrics.Average()
      >>> metrics.compute()
      Array(nan, dtype=float32)
      >>> metrics.update(values=batch_loss)
      >>> metrics.compute()
      Array(2.5, dtype=float32)
      >>> metrics.update(values=batch_loss2)
      >>> metrics.compute()
      Array(2., dtype=float32)
      >>> metrics.reset()
      >>> metrics.compute()
      Array(nan, dtype=float32)
    """

    def __init__(self, argname: str = 'values'):
        """Create an empty average.

        Args:
          argname: name of the keyword argument from which :meth:`update`
            reads new values. For instance ``avg = Average('test')`` is fed
            with ``avg.update(test=new_value)``. Defaults to ``'values'``.
        """
        self.argname = argname
        # Running sum of all observed elements.
        self.total = MetricState(jnp.array(0, dtype=jnp.float32))
        # Number of observed elements.
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self) -> None:
        """Forget every observation, restoring the zero sum and zero count."""
        self.total.value = jnp.array(0, dtype=jnp.float32)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def update(self, **kwargs) -> None:
        """Add the observations found in ``kwargs[self.argname]``.

        Args:
          **kwargs: keyword arguments; the entry named ``self.argname`` (set
            at construction) holds either a Python ``int``/``float`` (counted
            as one observation) or an array (all elements are counted).

        Raises:
          TypeError: if ``self.argname`` is not among the keyword arguments.
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
        if isinstance(values, (int, float)):
            batch_sum, batch_size = values, 1
        else:
            batch_sum, batch_size = values.sum(), values.size
        self.total.value += batch_sum
        self.count.value += batch_size

    def compute(self) -> jax.Array:
        """Return ``total / count`` (``nan`` before the first update)."""
        total, count = self.total.value, self.count.value
        return total / count
|
122
|
-
|
123
|
-
|
124
|
-
@partial(
    jax.tree_util.register_dataclass,
    data_fields=['mean', 'standard_error_of_mean', 'standard_deviation'],
    meta_fields=[],
)
@dataclass
class Statistics:
    """Summary statistics returned by :meth:`Welford.compute`.

    The class is registered as a JAX pytree in which all three fields are
    data leaves (there are no static/meta fields).
    """

    # Mean of all observed values.
    mean: jnp.float32
    # Standard error of the mean: standard_deviation / sqrt(count).
    standard_error_of_mean: jnp.float32
    # Population standard deviation of all observed values.
    standard_deviation: jnp.float32
|
132
|
-
|
133
|
-
|
134
|
-
class Welford(Metric):
    """Streaming mean and variance using Welford's algorithm.

    Every :meth:`update` merges a batch (or a single Python number) into a
    running count, mean and sum of squared deviations (``m2``) with the
    pairwise combination rule
    ``m2 += m2_batch + delta**2 * n_batch * n_old / n_new``,
    so the raw data never has to be stored. :meth:`compute` reports the
    population variance ``m2 / count`` through a :class:`Statistics` object.

    Example usage::

      >>> import jax.numpy as jnp
      >>> from brainstate import nn

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics = nn.metrics.Welford()
      >>> metrics.compute()
      Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
      >>> metrics.update(values=batch_loss)
      >>> metrics.compute()
      Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
      >>> metrics.update(values=batch_loss2)
      >>> metrics.compute()
      Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
      >>> metrics.reset()
      >>> metrics.compute()
      Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
    """

    def __init__(self, argname: str = 'values'):
        """Create an empty accumulator.

        Args:
          argname: name of the keyword argument from which :meth:`update`
            reads new values. For instance ``wf = Welford('test')`` is fed
            with ``wf.update(test=new_value)``. Defaults to ``'values'``.
        """
        self.argname = argname
        # Number of observed elements.
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))
        # Running mean of the observed elements.
        self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
        # Running sum of squared deviations from the mean.
        self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))

    def reset(self) -> None:
        """Forget every observation."""
        # Use int32, the same dtype as in ``__init__`` (was uint32 before,
        # which silently changed the state's dtype after a reset).
        self.count.value = jnp.array(0, dtype=jnp.int32)
        self.mean.value = jnp.array(0, dtype=jnp.float32)
        self.m2.value = jnp.array(0, dtype=jnp.float32)

    def update(self, **kwargs) -> None:
        """Merge the observations found in ``kwargs[self.argname]``.

        Args:
          **kwargs: keyword arguments; the entry named ``self.argname`` (set
            at construction) holds either a Python ``int``/``float`` (one
            observation) or an array (all elements are observations).

        Raises:
          TypeError: if ``self.argname`` is not among the keyword arguments.
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
        is_scalar = isinstance(values, (int, float))
        count = 1 if is_scalar else values.size
        batch_mean = values if is_scalar else values.mean()
        # Sum of squared deviations inside the batch (population variance * n).
        batch_m2 = 0.0 if is_scalar else values.var() * count

        original_count = self.count.value
        self.count.value += count
        new_count = self.count.value

        delta = batch_mean - self.mean.value
        self.mean.value += delta * count / new_count
        # Bug fix: divide by the count *value*, not by the MetricState object.
        self.m2.value += (
            batch_m2 + delta * delta * count * original_count / new_count
        )

    def compute(self) -> Statistics:
        """Return mean, standard error of the mean and standard deviation.

        Before any update the mean is 0 and the other two fields are ``nan``
        (0 / 0).
        """
        variance = self.m2.value / self.count.value
        standard_deviation = variance ** 0.5
        sem = standard_deviation / (self.count.value ** 0.5)
        return Statistics(
            mean=self.mean.value,
            standard_error_of_mean=sem,
            standard_deviation=standard_deviation,
        )
|
216
|
-
|
217
|
-
|
218
|
-
class Accuracy(Average):
    """Classification accuracy.

    This is an :class:`Average` over the per-sample correctness
    ``argmax(logits, axis=-1) == labels``, so ``reset`` and ``compute`` are
    inherited unchanged. Unlike :class:`Average`, no argument name needs to
    be given at construction: :meth:`update` takes ``logits`` and ``labels``.

    Example usage::

      >>> import brainstate as brainstate
      >>> import jax, jax.numpy as jnp

      >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
      >>> labels = jnp.array([1, 1, 0, 1, 0])
      >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
      >>> labels2 = jnp.array([0, 1, 1, 1, 1])

      >>> metrics = brainstate.nn.metrics.Accuracy()
      >>> metrics.compute()
      Array(nan, dtype=float32)
      >>> metrics.update(logits=logits, labels=labels)
      >>> metrics.compute()
      Array(0.6, dtype=float32)
      >>> metrics.update(logits=logits2, labels=labels2)
      >>> metrics.compute()
      Array(0.7, dtype=float32)
      >>> metrics.reset()
      >>> metrics.compute()
      Array(nan, dtype=float32)
    """

    def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None:  # type: ignore[override]
        """Add one batch of predictions.

        Args:
          logits: predicted activations; they are argmax-ed over the trailing
            axis before being compared with ``labels``. Must have exactly one
            more dimension than ``labels``.
          labels: ground-truth class indices of any integer dtype (they are
            cast to int32 before the comparison).
          **_: ignored, so the same keyword arguments can be shared with
            other metrics (e.g. inside a ``MultiMetric``).

        Raises:
          ValueError: if ``logits.ndim != labels.ndim + 1`` or ``labels`` is
            not of an integer dtype.
        """
        if logits.ndim != labels.ndim + 1:
            raise ValueError(
                f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
            )
        # Accept every integer dtype (int8/16/64, unsigned, ...), not only
        # int32/int64 as before.
        if not jnp.issubdtype(labels.dtype, jnp.integer):
            raise ValueError(f'Expected an integer labels.dtype, got {labels.dtype}')
        if labels.dtype != jnp.int32:
            labels = jnp.astype(labels, jnp.int32)

        super().update(values=(logits.argmax(axis=-1) == labels))
|
267
|
-
|
268
|
-
|
269
|
-
class MultiMetric(Metric):
    """Container that updates, resets and computes several metrics at once.

    Each metric is stored as an attribute named after its keyword argument.
    Every keyword given to :meth:`update` is forwarded to every contained
    metric, so each one picks out the arguments it needs.

    Example usage::

      >>> from brainstate import nn
      >>> import jax, jax.numpy as jnp

      >>> metrics = nn.metrics.MultiMetric(
      ...   accuracy=nn.metrics.Accuracy(),
      ...   loss=nn.metrics.Average(),
      ... )

      >>> metrics
      MultiMetric(
        accuracy=Accuracy(
          argname='values',
          total=MetricState(
            value=Array(0., dtype=float32)
          ),
          count=MetricState(
            value=Array(0, dtype=int32)
          )
        ),
        loss=Average(
          argname='values',
          total=MetricState(
            value=Array(0., dtype=float32)
          ),
          count=MetricState(
            value=Array(0, dtype=int32)
          )
        )
      )

      >>> metrics.accuracy
      Accuracy(
        argname='values',
        total=MetricState(
          value=Array(0., dtype=float32)
        ),
        count=MetricState(
          value=Array(0, dtype=int32)
        )
      )

      >>> metrics.loss
      Average(
        argname='values',
        total=MetricState(
          value=Array(0., dtype=float32)
        ),
        count=MetricState(
          value=Array(0, dtype=int32)
        )
      )

      >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
      >>> labels = jnp.array([1, 1, 0, 1, 0])
      >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
      >>> labels2 = jnp.array([0, 1, 1, 1, 1])

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics.compute()
      {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
      >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
      >>> metrics.compute()
      {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
      >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
      >>> metrics.compute()
      {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
      >>> metrics.reset()
      >>> metrics.compute()
      {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
    """

    def __init__(self, **metrics):
        """Store the given metrics, e.g.
        ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.

        Args:
          **metrics: the metrics to hold; each keyword becomes the attribute
            name (and the key in :meth:`compute`'s result) of its ``Metric``.

        Raises:
          ValueError: if a keyword would shadow one of this class's own
            methods or its internal ``_metric_names`` bookkeeping list.
        """
        # These names would overwrite MultiMetric's own attributes through the
        # instance ``__dict__`` and leave the object unusable.
        reserved_names = ('reset', 'update', 'compute', '_metric_names')
        clashes = [name for name in metrics if name in reserved_names]
        if clashes:
            raise ValueError(
                f'Metric name(s) {clashes} are reserved by MultiMetric '
                f'and cannot be used; reserved names are {list(reserved_names)}.'
            )
        self._metric_names = []
        for metric_name, metric in metrics.items():
            self._metric_names.append(metric_name)
            vars(self)[metric_name] = metric

    def reset(self) -> None:
        """Reset every contained ``Metric``."""
        for metric_name in self._metric_names:
            getattr(self, metric_name).reset()

    def update(self, **updates) -> None:
        """Update every contained ``Metric`` in place.

        Args:
          **updates: keyword arguments forwarded unchanged to the ``update``
            method of every contained ``Metric``.
        """
        # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
        # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
        for metric_name in self._metric_names:
            getattr(self, metric_name).update(**updates)

    def compute(self) -> dict[str, tp.Any]:
        """Return ``{name: metric.compute()}`` for every contained metric.

        The keys are the keyword names given to the constructor; the values
        are whatever each metric's ``compute`` returns (not the metrics
        themselves).
        """
        return {
            metric_name: getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import typing as tp
|
18
|
+
from dataclasses import dataclass
|
19
|
+
from functools import partial
|
20
|
+
|
21
|
+
import jax
|
22
|
+
import jax.numpy as jnp
|
23
|
+
import numpy as np
|
24
|
+
|
25
|
+
from brainstate._state import State
|
26
|
+
|
27
|
+
# Public API of this module; these are the names bound by ``from ... import *``.
__all__ = ['Average', 'Statistics', 'Welford', 'Accuracy', 'MultiMetric']
|
34
|
+
|
35
|
+
|
36
|
+
class MetricState(State):
    """``State`` subclass that holds the mutable variables of a metric.

    It adds no behaviour of its own on top of ``State``; having a dedicated
    subclass simply makes metric variables distinguishable by type.
    """
|
39
|
+
|
40
|
+
|
41
|
+
class Metric:
    """Common base class of every metric in this module.

    Concrete metrics are expected to override :meth:`reset`, :meth:`update`
    and :meth:`compute`; the versions defined here only raise
    ``NotImplementedError``.
    """

    def reset(self) -> None:
        """Return the metric to its initial (empty) state, in place."""
        message = 'Must override `reset()` method.'
        raise NotImplementedError(message)

    def update(self, **kwargs) -> None:
        """Fold new observations (passed as keyword arguments) into the metric, in place."""
        message = 'Must override `update()` method.'
        raise NotImplementedError(message)

    def compute(self):
        """Return the current value of the metric."""
        message = 'Must override `compute()` method.'
        raise NotImplementedError(message)
|
56
|
+
|
57
|
+
|
58
|
+
class Average(Metric):
    """Running average of every value observed since the last reset.

    Each :meth:`update` adds a batch (or a single Python number) to a running
    float32 sum and an int32 element count; :meth:`compute` divides the two.
    With no observations the count is zero, so :meth:`compute` yields ``nan``.

    Example usage::

      >>> import jax.numpy as jnp
      >>> import brainstate as brainstate

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics = brainstate.nn.metrics.Average()
      >>> metrics.compute()
      Array(nan, dtype=float32)
      >>> metrics.update(values=batch_loss)
      >>> metrics.compute()
      Array(2.5, dtype=float32)
      >>> metrics.update(values=batch_loss2)
      >>> metrics.compute()
      Array(2., dtype=float32)
      >>> metrics.reset()
      >>> metrics.compute()
      Array(nan, dtype=float32)
    """

    def __init__(self, argname: str = 'values'):
        """Create an empty average.

        Args:
          argname: name of the keyword argument from which :meth:`update`
            reads new values. For instance ``avg = Average('test')`` is fed
            with ``avg.update(test=new_value)``. Defaults to ``'values'``.
        """
        self.argname = argname
        # Running sum of all observed elements.
        self.total = MetricState(jnp.array(0, dtype=jnp.float32))
        # Number of observed elements.
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self) -> None:
        """Forget every observation, restoring the zero sum and zero count."""
        self.total.value = jnp.array(0, dtype=jnp.float32)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def update(self, **kwargs) -> None:
        """Add the observations found in ``kwargs[self.argname]``.

        Args:
          **kwargs: keyword arguments; the entry named ``self.argname`` (set
            at construction) holds either a Python ``int``/``float`` (counted
            as one observation) or an array (all elements are counted).

        Raises:
          TypeError: if ``self.argname`` is not among the keyword arguments.
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
        if isinstance(values, (int, float)):
            batch_sum, batch_size = values, 1
        else:
            batch_sum, batch_size = values.sum(), values.size
        self.total.value += batch_sum
        self.count.value += batch_size

    def compute(self) -> jax.Array:
        """Return ``total / count`` (``nan`` before the first update)."""
        total, count = self.total.value, self.count.value
        return total / count
|
122
|
+
|
123
|
+
|
124
|
+
@partial(
    jax.tree_util.register_dataclass,
    data_fields=['mean', 'standard_error_of_mean', 'standard_deviation'],
    meta_fields=[],
)
@dataclass
class Statistics:
    """Summary statistics returned by :meth:`Welford.compute`.

    The class is registered as a JAX pytree in which all three fields are
    data leaves (there are no static/meta fields).
    """

    # Mean of all observed values.
    mean: jnp.float32
    # Standard error of the mean: standard_deviation / sqrt(count).
    standard_error_of_mean: jnp.float32
    # Population standard deviation of all observed values.
    standard_deviation: jnp.float32
|
132
|
+
|
133
|
+
|
134
|
+
class Welford(Metric):
    """Streaming mean and variance using Welford's algorithm.

    Every :meth:`update` merges a batch (or a single Python number) into a
    running count, mean and sum of squared deviations (``m2``) with the
    pairwise combination rule
    ``m2 += m2_batch + delta**2 * n_batch * n_old / n_new``,
    so the raw data never has to be stored. :meth:`compute` reports the
    population variance ``m2 / count`` through a :class:`Statistics` object.

    Example usage::

      >>> import jax.numpy as jnp
      >>> from brainstate import nn

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics = nn.metrics.Welford()
      >>> metrics.compute()
      Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
      >>> metrics.update(values=batch_loss)
      >>> metrics.compute()
      Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
      >>> metrics.update(values=batch_loss2)
      >>> metrics.compute()
      Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
      >>> metrics.reset()
      >>> metrics.compute()
      Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
    """

    def __init__(self, argname: str = 'values'):
        """Create an empty accumulator.

        Args:
          argname: name of the keyword argument from which :meth:`update`
            reads new values. For instance ``wf = Welford('test')`` is fed
            with ``wf.update(test=new_value)``. Defaults to ``'values'``.
        """
        self.argname = argname
        # Number of observed elements.
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))
        # Running mean of the observed elements.
        self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
        # Running sum of squared deviations from the mean.
        self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))

    def reset(self) -> None:
        """Forget every observation."""
        # Use int32, the same dtype as in ``__init__`` (was uint32 before,
        # which silently changed the state's dtype after a reset).
        self.count.value = jnp.array(0, dtype=jnp.int32)
        self.mean.value = jnp.array(0, dtype=jnp.float32)
        self.m2.value = jnp.array(0, dtype=jnp.float32)

    def update(self, **kwargs) -> None:
        """Merge the observations found in ``kwargs[self.argname]``.

        Args:
          **kwargs: keyword arguments; the entry named ``self.argname`` (set
            at construction) holds either a Python ``int``/``float`` (one
            observation) or an array (all elements are observations).

        Raises:
          TypeError: if ``self.argname`` is not among the keyword arguments.
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
        is_scalar = isinstance(values, (int, float))
        count = 1 if is_scalar else values.size
        batch_mean = values if is_scalar else values.mean()
        # Sum of squared deviations inside the batch (population variance * n).
        batch_m2 = 0.0 if is_scalar else values.var() * count

        original_count = self.count.value
        self.count.value += count
        new_count = self.count.value

        delta = batch_mean - self.mean.value
        self.mean.value += delta * count / new_count
        # Bug fix: divide by the count *value*, not by the MetricState object.
        self.m2.value += (
            batch_m2 + delta * delta * count * original_count / new_count
        )

    def compute(self) -> Statistics:
        """Return mean, standard error of the mean and standard deviation.

        Before any update the mean is 0 and the other two fields are ``nan``
        (0 / 0).
        """
        variance = self.m2.value / self.count.value
        standard_deviation = variance ** 0.5
        sem = standard_deviation / (self.count.value ** 0.5)
        return Statistics(
            mean=self.mean.value,
            standard_error_of_mean=sem,
            standard_deviation=standard_deviation,
        )
|
216
|
+
|
217
|
+
|
218
|
+
class Accuracy(Average):
    """Classification accuracy.

    This is an :class:`Average` over the per-sample correctness
    ``argmax(logits, axis=-1) == labels``, so ``reset`` and ``compute`` are
    inherited unchanged. Unlike :class:`Average`, no argument name needs to
    be given at construction: :meth:`update` takes ``logits`` and ``labels``.

    Example usage::

      >>> import brainstate as brainstate
      >>> import jax, jax.numpy as jnp

      >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
      >>> labels = jnp.array([1, 1, 0, 1, 0])
      >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
      >>> labels2 = jnp.array([0, 1, 1, 1, 1])

      >>> metrics = brainstate.nn.metrics.Accuracy()
      >>> metrics.compute()
      Array(nan, dtype=float32)
      >>> metrics.update(logits=logits, labels=labels)
      >>> metrics.compute()
      Array(0.6, dtype=float32)
      >>> metrics.update(logits=logits2, labels=labels2)
      >>> metrics.compute()
      Array(0.7, dtype=float32)
      >>> metrics.reset()
      >>> metrics.compute()
      Array(nan, dtype=float32)
    """

    def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None:  # type: ignore[override]
        """Add one batch of predictions.

        Args:
          logits: predicted activations; they are argmax-ed over the trailing
            axis before being compared with ``labels``. Must have exactly one
            more dimension than ``labels``.
          labels: ground-truth class indices of any integer dtype (they are
            cast to int32 before the comparison).
          **_: ignored, so the same keyword arguments can be shared with
            other metrics (e.g. inside a ``MultiMetric``).

        Raises:
          ValueError: if ``logits.ndim != labels.ndim + 1`` or ``labels`` is
            not of an integer dtype.
        """
        if logits.ndim != labels.ndim + 1:
            raise ValueError(
                f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
            )
        # Accept every integer dtype (int8/16/64, unsigned, ...), not only
        # int32/int64 as before.
        if not jnp.issubdtype(labels.dtype, jnp.integer):
            raise ValueError(f'Expected an integer labels.dtype, got {labels.dtype}')
        if labels.dtype != jnp.int32:
            labels = jnp.astype(labels, jnp.int32)

        super().update(values=(logits.argmax(axis=-1) == labels))
|
267
|
+
|
268
|
+
|
269
|
+
class MultiMetric(Metric):
    """MultiMetric class to store multiple metrics and update them in a single call.

    Example usage::

      >>> from brainstate import nn
      >>> import jax, jax.numpy as jnp

      >>> metrics = nn.metrics.MultiMetric(
      ...   accuracy=nn.metrics.Accuracy(),
      ...   loss=nn.metrics.Average(),
      ... )

      >>> metrics
      MultiMetric(
        accuracy=Accuracy(
          argname='values',
          total=MetricState(
            value=Array(0., dtype=float32)
          ),
          count=MetricState(
            value=Array(0, dtype=int32)
          )
        ),
        loss=Average(
          argname='values',
          total=MetricState(
            value=Array(0., dtype=float32)
          ),
          count=MetricState(
            value=Array(0, dtype=int32)
          )
        )
      )

      >>> metrics.accuracy
      Accuracy(
        argname='values',
        total=MetricState(
          value=Array(0., dtype=float32)
        ),
        count=MetricState(
          value=Array(0, dtype=int32)
        )
      )

      >>> metrics.loss
      Average(
        argname='values',
        total=MetricState(
          value=Array(0., dtype=float32)
        ),
        count=MetricState(
          value=Array(0, dtype=int32)
        )
      )

      >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
      >>> labels = jnp.array([1, 1, 0, 1, 0])
      >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
      >>> labels2 = jnp.array([0, 1, 1, 1, 1])

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics.compute()
      {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
      >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
      >>> metrics.compute()
      {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
      >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
      >>> metrics.compute()
      {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
      >>> metrics.reset()
      >>> metrics.compute()
      {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
    """

    def __init__(self, **metrics):
        """Pass in key-word arguments to the constructor, e.g.
        ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.

        Args:
          **metrics: the key-word arguments that will be used to access
            the corresponding ``Metric`` as an attribute of this object.

        Raises:
          ValueError: if a keyword collides with a name this class relies on
            (``reset``, ``update``, ``compute`` or ``_metric_names``).
            Storing a metric under such a name would shadow the method or
            the bookkeeping list on the instance.
        """
        reserved = ('reset', 'update', 'compute', '_metric_names')
        clashes = [name for name in metrics if name in reserved]
        if clashes:
            raise ValueError(
                f'Metric names {clashes} are reserved by MultiMetric; '
                f'please choose different keyword names.'
            )
        # Keyword order is preserved, so reset/update/compute visit the
        # metrics in the order they were passed.
        self._metric_names = list(metrics)
        for metric_name, metric in metrics.items():
            # Written straight into the instance dict, as in the original
            # implementation.
            vars(self)[metric_name] = metric

    def reset(self) -> None:
        """Reset all underlying ``Metric``'s."""
        for metric_name in self._metric_names:
            getattr(self, metric_name).reset()

    def update(self, **updates) -> None:
        """In-place update all underlying ``Metric``'s in this ``MultiMetric``. All
        ``**updates`` will be passed to the ``update`` method of all underlying
        ``Metric``'s.

        Args:
          **updates: the key-word arguments that will be passed to the underlying ``Metric``'s
            ``update`` method.
        """
        # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
        # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
        for metric_name in self._metric_names:
            getattr(self, metric_name).update(**updates)

    def compute(self) -> dict[str, object]:
        """Compute and return the value of all underlying ``Metric``'s.

        Returns:
          A dictionary mapping each keyword passed to the constructor to the
          result of that metric's ``compute()`` call. The result is a
          computed value, not the ``Metric`` object itself.
        """
        return {
            metric_name: getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }
|