brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/nn/metrics.py
CHANGED
@@ -27,100 +27,100 @@ import numpy as np
|
|
27
27
|
from brainstate._state import State
|
28
28
|
|
29
29
|
# Public names exported by ``from brainstate.nn.metrics import *``.
__all__ = ['Average', 'Statistics', 'Welford', 'Accuracy', 'MultiMetric']
|
36
36
|
|
37
37
|
|
38
38
|
class MetricState(State):
    """``State`` subclass used to hold the running values accumulated by the metrics in this module."""
|
41
41
|
|
42
42
|
|
43
43
|
class Metric:
    """Common interface for every metric in this module.

    A concrete metric accumulates observations through ``update``, reports
    its current value through ``compute`` and clears its accumulated state
    through ``reset``. All three must be overridden: the implementations
    here only raise ``NotImplementedError``.
    """

    def reset(self) -> None:
        """Clear the accumulated state in place."""
        raise NotImplementedError('Must override `reset()` method.')

    def update(self, **kwargs) -> None:
        """Fold new observations, given as keyword arguments, into the state in place."""
        raise NotImplementedError('Must override `update()` method.')

    def compute(self):
        """Return the metric value derived from the accumulated state."""
        raise NotImplementedError('Must override `compute()` method.')
|
58
58
|
|
59
59
|
|
60
60
|
class Average(Metric):
    """Running arithmetic mean of every value passed to ``update``.

    The metric keeps a float32 running sum and an int32 element count;
    ``compute`` divides one by the other, so it yields ``nan`` before the
    first update.

    Example usage::

      >>> import jax.numpy as jnp
      >>> import brainstate as bst

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics = bst.nn.metrics.Average()
      >>> metrics.compute()
      Array(nan, dtype=float32)
      >>> metrics.update(values=batch_loss)
      >>> metrics.compute()
      Array(2.5, dtype=float32)
      >>> metrics.update(values=batch_loss2)
      >>> metrics.compute()
      Array(2., dtype=float32)
      >>> metrics.reset()
      >>> metrics.compute()
      Array(nan, dtype=float32)
    """

    def __init__(self, argname: str = 'values'):
        """Create the metric.

        Args:
          argname: name of the keyword argument that ``update`` reads the new
            values from. With ``avg = Average('test')`` updates are made as
            ``avg.update(test=new_value)``. Defaults to ``'values'``.
        """
        self.argname = argname
        # Running sum of all observed elements.
        self.total = MetricState(jnp.array(0, dtype=jnp.float32))
        # Number of elements folded into ``total``.
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self) -> None:
        """Zero the running sum and the element count."""
        zero_total = jnp.array(0, dtype=jnp.float32)
        zero_count = jnp.array(0, dtype=jnp.int32)
        self.total.value = zero_total
        self.count.value = zero_count

    def update(self, **kwargs) -> None:
        """Add ``kwargs[self.argname]`` to the running statistics in place.

        A Python ``int``/``float`` counts as a single element; anything else is
        treated as an array whose elements are all summed and counted.

        Args:
          **kwargs: must contain an entry named ``self.argname``; other
            entries are ignored.

        Raises:
          TypeError: if ``self.argname`` is not among the keyword arguments.
        """
        key = self.argname
        if key not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[key]
        if isinstance(values, (int, float)):
            batch_sum, batch_size = values, 1
        else:
            batch_sum, batch_size = values.sum(), values.size
        self.total.value += batch_sum
        self.count.value += batch_size

    def compute(self) -> jax.Array:
        """Return ``total / count`` (``nan`` when nothing has been observed)."""
        return self.total.value / self.count.value
|
124
124
|
|
125
125
|
|
126
126
|
@partial(jax.tree_util.register_dataclass,
|
@@ -128,162 +128,183 @@ class Average(Metric):
|
|
128
128
|
meta_fields=[])
|
129
129
|
@dataclass
|
130
130
|
class Statistics:
|
131
|
-
|
132
|
-
|
133
|
-
|
131
|
+
mean: jnp.float32
|
132
|
+
standard_error_of_mean: jnp.float32
|
133
|
+
standard_deviation: jnp.float32
|
134
134
|
|
135
135
|
|
136
136
|
class Welford(Metric):
    """Uses Welford's algorithm to compute the mean and variance of a stream of data.

    Each ``update`` call may pass a whole array. The batch's mean and sum of
    squared deviations are merged into the running ``mean`` / ``m2`` with the
    pairwise combination formula, so the result equals the statistics of all
    values seen so far.

    Example usage::

      >>> import jax.numpy as jnp
      >>> from brainstate import nn

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics = nn.metrics.Welford()
      >>> metrics.compute()
      Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
      >>> metrics.update(values=batch_loss)
      >>> metrics.compute()
      Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
      >>> metrics.update(values=batch_loss2)
      >>> metrics.compute()
      Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
      >>> metrics.reset()
      >>> metrics.compute()
      Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
    """

    def __init__(self, argname: str = 'values'):
        """Create the metric.

        Args:
          argname: name of the keyword argument that ``update`` reads the new
            values from. With ``wf = Welford('test')`` updates are made as
            ``wf.update(test=new_value)``. Defaults to ``'values'``.
        """
        self.argname = argname
        # Number of elements observed so far.
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))
        # Running mean of all observed elements.
        self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
        # Running sum of squared deviations from the mean.
        self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))

    def reset(self) -> None:
        """Restore ``count``, ``mean`` and ``m2`` to their initial zero values."""
        # Same dtype as in ``__init__`` so the state keeps a stable dtype
        # across resets.
        self.count.value = jnp.array(0, dtype=jnp.int32)
        self.mean.value = jnp.array(0, dtype=jnp.float32)
        self.m2.value = jnp.array(0, dtype=jnp.float32)

    def update(self, **kwargs) -> None:
        """Merge ``kwargs[self.argname]`` into the running statistics in place.

        A Python ``int``/``float`` counts as a single element; anything else is
        treated as an array whose elements are all merged.

        Args:
          **kwargs: must contain an entry named ``self.argname``; other
            entries are ignored.

        Raises:
          TypeError: if ``self.argname`` is not among the keyword arguments.
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
        is_scalar = isinstance(values, (int, float))
        count = 1 if is_scalar else values.size
        original_count = self.count.value
        self.count.value += count
        # Divide by the count *value*: the ``MetricState`` wrapper itself is
        # not an array.
        total_count = self.count.value
        delta = (values if is_scalar else values.mean()) - self.mean.value
        self.mean.value += delta * count / total_count
        # Sum of squared deviations inside the incoming batch (zero for a scalar).
        m2 = 0.0 if is_scalar else values.var() * count
        self.m2.value += m2 + delta * delta * count * original_count / total_count

    def compute(self) -> Statistics:
        """Return the mean, standard error of the mean and standard deviation.

        The variance is the population variance ``m2 / count``. All fields
        other than ``mean`` are ``nan`` before the first update.
        """
        variance = self.m2.value / self.count.value
        standard_deviation = variance ** 0.5
        sem = standard_deviation / (self.count.value ** 0.5)
        return Statistics(
            mean=self.mean.value,
            standard_error_of_mean=sem,
            standard_deviation=standard_deviation,
        )
|
218
218
|
|
219
219
|
|
220
220
|
class Accuracy(Average):
    """Fraction of samples whose arg-max prediction matches the integer label.

    Built on :class:`Average`, so ``reset`` and ``compute`` are inherited
    unchanged. Unlike :class:`Average`, nothing needs to be passed to the
    constructor.

    Example usage::

      >>> import brainstate as bst
      >>> import jax, jax.numpy as jnp

      >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
      >>> labels = jnp.array([1, 1, 0, 1, 0])
      >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
      >>> labels2 = jnp.array([0, 1, 1, 1, 1])

      >>> metrics = bst.nn.metrics.Accuracy()
      >>> metrics.compute()
      Array(nan, dtype=float32)
      >>> metrics.update(logits=logits, labels=labels)
      >>> metrics.compute()
      Array(0.6, dtype=float32)
      >>> metrics.update(logits=logits2, labels=labels2)
      >>> metrics.compute()
      Array(0.7, dtype=float32)
      >>> metrics.reset()
      >>> metrics.compute()
      Array(nan, dtype=float32)
    """

    def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None:  # type: ignore[override]
        """Score one batch of predictions in place.

        Args:
          logits: predicted activations with exactly one more dimension than
            ``labels``; the arg-max over the trailing axis is the prediction.
          labels: ground-truth integer labels (int32 or int64).
          **_: any other keyword arguments are ignored.

        Raises:
          ValueError: if the ranks of ``logits`` and ``labels`` do not line
            up, or if ``labels`` is not an int32/int64 array.
        """
        if logits.ndim != labels.ndim + 1:
            raise ValueError(
                f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
            )
        label_dtype = labels.dtype
        if label_dtype in (jnp.int64, np.int32, np.int64):
            # Normalise accepted integer labels to int32 before comparing.
            labels = jnp.astype(labels, jnp.int32)
        elif label_dtype != jnp.int32:
            raise ValueError(f'Expected labels.dtype==jnp.int32, got {labels.dtype}')
        predicted = logits.argmax(axis=-1)
        super().update(values=(predicted == labels))
|
class MultiMetric(Metric):
    """Container that stores several metrics and updates them in a single call.

    Every metric passed to the constructor becomes an attribute of the same
    name, and ``reset``/``update``/``compute`` are forwarded to all of them.

    Example usage::

      >>> from brainstate import nn
      >>> import jax, jax.numpy as jnp

      >>> metrics = nn.metrics.MultiMetric(
      ...   accuracy=nn.metrics.Accuracy(),
      ...   loss=nn.metrics.Average(),
      ... )
      >>> isinstance(metrics.accuracy, nn.metrics.Accuracy)
      True
      >>> isinstance(metrics.loss, nn.metrics.Average)
      True

      >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
      >>> labels = jnp.array([1, 1, 0, 1, 0])
      >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
      >>> labels2 = jnp.array([0, 1, 1, 1, 1])

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics.compute()
      {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
      >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
      >>> metrics.compute()
      {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
      >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
      >>> metrics.compute()
      {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
      >>> metrics.reset()
      >>> metrics.compute()
      {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
    """

    def __init__(self, **metrics):
        """Register the given metrics, e.g.
        ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.

        Args:
          **metrics: each keyword becomes the attribute name (and the key in
            ``compute``'s result) under which the corresponding ``Metric`` is
            stored.

        Raises:
          ValueError: if a keyword would shadow one of this class's methods
            (``reset``, ``update``, ``compute``) or its internal
            ``_metric_names`` list.
        """
        # Storing a metric under one of these names via ``vars(self)`` would
        # hide the method (or replace the name list) on this instance.
        reserved = ('reset', 'update', 'compute', '_metric_names')
        for metric_name in metrics:
            if metric_name in reserved:
                raise ValueError(
                    f"Metric name '{metric_name}' is reserved by MultiMetric; "
                    f"choose a name other than {reserved}."
                )
        self._metric_names = []
        for metric_name, metric in metrics.items():
            self._metric_names.append(metric_name)
            vars(self)[metric_name] = metric

    def reset(self) -> None:
        """Reset every underlying ``Metric``."""
        for metric_name in self._metric_names:
            getattr(self, metric_name).reset()

    def update(self, **updates) -> None:
        """Update every underlying ``Metric`` in place.

        The same ``**updates`` are passed to each metric's ``update``; each
        metric picks out the keyword arguments it needs.

        Args:
          **updates: keyword arguments forwarded unchanged to every
            underlying ``update`` method.
        """
        # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
        # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
        for metric_name in self._metric_names:
            getattr(self, metric_name).update(**updates)

    def compute(self) -> dict[str, object]:
        """Return each underlying metric's computed value, keyed by the name it
        was registered under in the constructor."""
        return {
            metric_name: getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }
|