brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/nn/metrics.py CHANGED
@@ -27,100 +27,100 @@ import numpy as np
27
27
  from brainstate._state import State
28
28
 
29
29
  __all__ = [
30
- 'Average',
31
- 'Statistics',
32
- 'Welford',
33
- 'Accuracy',
34
- 'MultiMetric',
30
+ 'Average',
31
+ 'Statistics',
32
+ 'Welford',
33
+ 'Accuracy',
34
+ 'MultiMetric',
35
35
  ]
36
36
 
37
37
 
38
38
  class MetricState(State):
39
- """Wrapper class for Metric Variables."""
40
- pass
39
+ """Wrapper class for Metric Variables."""
40
+ pass
41
41
 
42
42
 
43
43
  class Metric(object):
44
- """Base class for metrics. Any class that subclasses ``Metric`` should
45
- implement a ``compute``, ``reset`` and ``update`` method."""
44
+ """Base class for metrics. Any class that subclasses ``Metric`` should
45
+ implement a ``compute``, ``reset`` and ``update`` method."""
46
46
 
47
- def reset(self) -> None:
48
- """In-place reset the ``Metric``."""
49
- raise NotImplementedError('Must override `reset()` method.')
47
+ def reset(self) -> None:
48
+ """In-place reset the ``Metric``."""
49
+ raise NotImplementedError('Must override `reset()` method.')
50
50
 
51
- def update(self, **kwargs) -> None:
52
- """In-place update the ``Metric``."""
53
- raise NotImplementedError('Must override `update()` method.')
51
+ def update(self, **kwargs) -> None:
52
+ """In-place update the ``Metric``."""
53
+ raise NotImplementedError('Must override `update()` method.')
54
54
 
55
- def compute(self):
56
- """Compute and return the value of the ``Metric``."""
57
- raise NotImplementedError('Must override `compute()` method.')
55
+ def compute(self):
56
+ """Compute and return the value of the ``Metric``."""
57
+ raise NotImplementedError('Must override `compute()` method.')
58
58
 
59
59
 
60
60
  class Average(Metric):
61
- """Average metric.
62
-
63
- Example usage::
64
-
65
- >>> import jax.numpy as jnp
66
- >>> import brainstate as bst
67
-
68
- >>> batch_loss = jnp.array([1, 2, 3, 4])
69
- >>> batch_loss2 = jnp.array([3, 2, 1, 0])
70
-
71
- >>> metrics = bst.nn.metrics.Average()
72
- >>> metrics.compute()
73
- Array(nan, dtype=float32)
74
- >>> metrics.update(values=batch_loss)
75
- >>> metrics.compute()
76
- Array(2.5, dtype=float32)
77
- >>> metrics.update(values=batch_loss2)
78
- >>> metrics.compute()
79
- Array(2., dtype=float32)
80
- >>> metrics.reset()
81
- >>> metrics.compute()
82
- Array(nan, dtype=float32)
83
- """
84
-
85
- def __init__(self, argname: str = 'values'):
86
- """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
87
- For example, constructing the metric as ``avg = Average('test')`` would allow you to make updates with
88
- ``avg.update(test=new_value)``.
89
-
90
- Args:
91
- argname: an optional string denoting the key-word argument that
92
- :func:`update` will use to derive the new value. Defaults to
93
- ``'values'``.
61
+ """Average metric.
62
+
63
+ Example usage::
64
+
65
+ >>> import jax.numpy as jnp
66
+ >>> import brainstate as bst
67
+
68
+ >>> batch_loss = jnp.array([1, 2, 3, 4])
69
+ >>> batch_loss2 = jnp.array([3, 2, 1, 0])
70
+
71
+ >>> metrics = bst.nn.metrics.Average()
72
+ >>> metrics.compute()
73
+ Array(nan, dtype=float32)
74
+ >>> metrics.update(values=batch_loss)
75
+ >>> metrics.compute()
76
+ Array(2.5, dtype=float32)
77
+ >>> metrics.update(values=batch_loss2)
78
+ >>> metrics.compute()
79
+ Array(2., dtype=float32)
80
+ >>> metrics.reset()
81
+ >>> metrics.compute()
82
+ Array(nan, dtype=float32)
94
83
  """
95
- self.argname = argname
96
- self.total = MetricState(jnp.array(0, dtype=jnp.float32))
97
- self.count = MetricState(jnp.array(0, dtype=jnp.int32))
98
-
99
- def reset(self) -> None:
100
- """Reset this ``Metric``."""
101
- self.total.value = jnp.array(0, dtype=jnp.float32)
102
- self.count.value = jnp.array(0, dtype=jnp.int32)
103
-
104
- def update(self, **kwargs) -> None:
105
- """In-place update this ``Metric``. This method will use the value from
106
- ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
107
- defined on construction.
108
-
109
- Args:
110
- **kwargs: the key-word arguments that contains a ``self.argname``
111
- entry that maps to the value we want to use to update this metric.
112
- """
113
- if self.argname not in kwargs:
114
- raise TypeError(f"Expected keyword argument '{self.argname}'")
115
- values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
116
- self.total.value += (
117
- values if isinstance(values, (int, float)) else values.sum()
118
- )
119
- self.count.value += 1 if isinstance(values, (int, float)) else values.size
120
84
 
121
- def compute(self) -> jax.Array:
122
- """Compute and return the average."""
123
- return self.total.value / self.count.value
85
+ def __init__(self, argname: str = 'values'):
86
+ """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
87
+ For example, constructing the metric as ``avg = Average('test')`` would allow you to make updates with
88
+ ``avg.update(test=new_value)``.
89
+
90
+ Args:
91
+ argname: an optional string denoting the key-word argument that
92
+ :func:`update` will use to derive the new value. Defaults to
93
+ ``'values'``.
94
+ """
95
+ self.argname = argname
96
+ self.total = MetricState(jnp.array(0, dtype=jnp.float32))
97
+ self.count = MetricState(jnp.array(0, dtype=jnp.int32))
98
+
99
+ def reset(self) -> None:
100
+ """Reset this ``Metric``."""
101
+ self.total.value = jnp.array(0, dtype=jnp.float32)
102
+ self.count.value = jnp.array(0, dtype=jnp.int32)
103
+
104
+ def update(self, **kwargs) -> None:
105
+ """In-place update this ``Metric``. This method will use the value from
106
+ ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
107
+ defined on construction.
108
+
109
+ Args:
110
+ **kwargs: the key-word arguments that contains a ``self.argname``
111
+ entry that maps to the value we want to use to update this metric.
112
+ """
113
+ if self.argname not in kwargs:
114
+ raise TypeError(f"Expected keyword argument '{self.argname}'")
115
+ values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
116
+ self.total.value += (
117
+ values if isinstance(values, (int, float)) else values.sum()
118
+ )
119
+ self.count.value += 1 if isinstance(values, (int, float)) else values.size
120
+
121
+ def compute(self) -> jax.Array:
122
+ """Compute and return the average."""
123
+ return self.total.value / self.count.value
124
124
 
125
125
 
126
126
  @partial(jax.tree_util.register_dataclass,
@@ -128,162 +128,183 @@ class Average(Metric):
128
128
  meta_fields=[])
129
129
  @dataclass
130
130
  class Statistics:
131
- mean: jnp.float32
132
- standard_error_of_mean: jnp.float32
133
- standard_deviation: jnp.float32
131
+ mean: jnp.float32
132
+ standard_error_of_mean: jnp.float32
133
+ standard_deviation: jnp.float32
134
134
 
135
135
 
136
136
  class Welford(Metric):
137
- """Uses Welford's algorithm to compute the mean and variance of a stream of data.
138
-
139
- Example usage::
140
-
141
- >>> import jax.numpy as jnp
142
- >>> from brainstate import nn
143
-
144
- >>> batch_loss = jnp.array([1, 2, 3, 4])
145
- >>> batch_loss2 = jnp.array([3, 2, 1, 0])
146
-
147
- >>> metrics = nn.metrics.Welford()
148
- >>> metrics.compute()
149
- Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
150
- >>> metrics.update(values=batch_loss)
151
- >>> metrics.compute()
152
- Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
153
- >>> metrics.update(values=batch_loss2)
154
- >>> metrics.compute()
155
- Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
156
- >>> metrics.reset()
157
- >>> metrics.compute()
158
- Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
159
- """
160
-
161
- def __init__(self, argname: str = 'values'):
162
- """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
163
- For example, constructing the metric as ``wf = Welford('test')`` would allow you to make updates with
164
- ``wf.update(test=new_value)``.
165
-
166
- Args:
167
- argname: an optional string denoting the key-word argument that
168
- :func:`update` will use to derive the new value. Defaults to
169
- ``'values'``.
137
+ """Uses Welford's algorithm to compute the mean and variance of a stream of data.
138
+
139
+ Example usage::
140
+
141
+ >>> import jax.numpy as jnp
142
+ >>> from brainstate import nn
143
+
144
+ >>> batch_loss = jnp.array([1, 2, 3, 4])
145
+ >>> batch_loss2 = jnp.array([3, 2, 1, 0])
146
+
147
+ >>> metrics = nn.metrics.Welford()
148
+ >>> metrics.compute()
149
+ Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
150
+ >>> metrics.update(values=batch_loss)
151
+ >>> metrics.compute()
152
+ Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
153
+ >>> metrics.update(values=batch_loss2)
154
+ >>> metrics.compute()
155
+ Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
156
+ >>> metrics.reset()
157
+ >>> metrics.compute()
158
+ Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
170
159
  """
171
- self.argname = argname
172
- self.count = MetricState(jnp.array(0, dtype=jnp.int32))
173
- self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
174
- self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))
175
-
176
- def reset(self) -> None:
177
- """Reset this ``Metric``."""
178
- self.count.value = jnp.array(0, dtype=jnp.uint32)
179
- self.mean.value = jnp.array(0, dtype=jnp.float32)
180
- self.m2.value = jnp.array(0, dtype=jnp.float32)
181
-
182
- def update(self, **kwargs) -> None:
183
- """In-place update this ``Metric``. This method will use the value from
184
- ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
185
- defined on construction.
186
-
187
- Args:
188
- **kwargs: the key-word arguments that contains a ``self.argname``
189
- entry that maps to the value we want to use to update this metric.
190
- """
191
- if self.argname not in kwargs:
192
- raise TypeError(f"Expected keyword argument '{self.argname}'")
193
- values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
194
- count = 1 if isinstance(values, (int, float)) else values.size
195
- original_count = self.count.value
196
- self.count.value += count
197
- delta = (
198
- values if isinstance(values, (int, float)) else values.mean()
199
- ) - self.mean.value
200
- self.mean.value += delta * count / self.count.value
201
- m2 = 0.0 if isinstance(values, (int, float)) else values.var() * count
202
- self.m2.value += (
203
- m2 + delta * delta * count * original_count / self.count
204
- )
205
-
206
- def compute(self) -> Statistics:
207
- """Compute and return the mean and variance statistics in a
208
- ``Statistics`` dataclass object.
209
- """
210
- variance = self.m2.value / self.count.value
211
- standard_deviation = variance ** 0.5
212
- sem = standard_deviation / (self.count.value ** 0.5)
213
- return Statistics(
214
- mean=self.mean.value,
215
- standard_error_of_mean=sem,
216
- standard_deviation=standard_deviation,
217
- )
160
+
161
+ def __init__(self, argname: str = 'values'):
162
+ """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
163
+ For example, constructing the metric as ``wf = Welford('test')`` would allow you to make updates with
164
+ ``wf.update(test=new_value)``.
165
+
166
+ Args:
167
+ argname: an optional string denoting the key-word argument that
168
+ :func:`update` will use to derive the new value. Defaults to
169
+ ``'values'``.
170
+ """
171
+ self.argname = argname
172
+ self.count = MetricState(jnp.array(0, dtype=jnp.int32))
173
+ self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
174
+ self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))
175
+
176
+ def reset(self) -> None:
177
+ """Reset this ``Metric``."""
178
+ self.count.value = jnp.array(0, dtype=jnp.uint32)
179
+ self.mean.value = jnp.array(0, dtype=jnp.float32)
180
+ self.m2.value = jnp.array(0, dtype=jnp.float32)
181
+
182
+ def update(self, **kwargs) -> None:
183
+ """In-place update this ``Metric``. This method will use the value from
184
+ ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
185
+ defined on construction.
186
+
187
+ Args:
188
+ **kwargs: the key-word arguments that contains a ``self.argname``
189
+ entry that maps to the value we want to use to update this metric.
190
+ """
191
+ if self.argname not in kwargs:
192
+ raise TypeError(f"Expected keyword argument '{self.argname}'")
193
+ values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
194
+ count = 1 if isinstance(values, (int, float)) else values.size
195
+ original_count = self.count.value
196
+ self.count.value += count
197
+ delta = (
198
+ values if isinstance(values, (int, float)) else values.mean()
199
+ ) - self.mean.value
200
+ self.mean.value += delta * count / self.count.value
201
+ m2 = 0.0 if isinstance(values, (int, float)) else values.var() * count
202
+ self.m2.value += (
203
+ m2 + delta * delta * count * original_count / self.count
204
+ )
205
+
206
+ def compute(self) -> Statistics:
207
+ """Compute and return the mean and variance statistics in a
208
+ ``Statistics`` dataclass object.
209
+ """
210
+ variance = self.m2.value / self.count.value
211
+ standard_deviation = variance ** 0.5
212
+ sem = standard_deviation / (self.count.value ** 0.5)
213
+ return Statistics(
214
+ mean=self.mean.value,
215
+ standard_error_of_mean=sem,
216
+ standard_deviation=standard_deviation,
217
+ )
218
218
 
219
219
 
220
220
  class Accuracy(Average):
221
- """Accuracy metric. This metric subclasses :class:`Average`,
222
- and so they share the same ``reset`` and ``compute`` method
223
- implementations. Unlike :class:`Average`, no string needs to
224
- be passed to ``Accuracy`` during construction.
225
-
226
- Example usage::
227
-
228
- >>> import brainstate as bst
229
- >>> import jax, jax.numpy as jnp
230
-
231
- >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
232
- >>> labels = jnp.array([1, 1, 0, 1, 0])
233
- >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
234
- >>> labels2 = jnp.array([0, 1, 1, 1, 1])
235
-
236
- >>> metrics = bst.nn.metrics.Accuracy()
237
- >>> metrics.compute()
238
- Array(nan, dtype=float32)
239
- >>> metrics.update(logits=logits, labels=labels)
240
- >>> metrics.compute()
241
- Array(0.6, dtype=float32)
242
- >>> metrics.update(logits=logits2, labels=labels2)
243
- >>> metrics.compute()
244
- Array(0.7, dtype=float32)
245
- >>> metrics.reset()
246
- >>> metrics.compute()
247
- Array(nan, dtype=float32)
248
- """
249
-
250
- def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None: # type: ignore[override]
251
- """In-place update this ``Metric``.
252
-
253
- Args:
254
- logits: the outputted predicted activations. These values are
255
- argmax-ed (on the trailing dimension), before comparing them
256
- to the labels.
257
- labels: the ground truth integer labels.
221
+ """Accuracy metric. This metric subclasses :class:`Average`,
222
+ and so they share the same ``reset`` and ``compute`` method
223
+ implementations. Unlike :class:`Average`, no string needs to
224
+ be passed to ``Accuracy`` during construction.
225
+
226
+ Example usage::
227
+
228
+ >>> import brainstate as bst
229
+ >>> import jax, jax.numpy as jnp
230
+
231
+ >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
232
+ >>> labels = jnp.array([1, 1, 0, 1, 0])
233
+ >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
234
+ >>> labels2 = jnp.array([0, 1, 1, 1, 1])
235
+
236
+ >>> metrics = bst.nn.metrics.Accuracy()
237
+ >>> metrics.compute()
238
+ Array(nan, dtype=float32)
239
+ >>> metrics.update(logits=logits, labels=labels)
240
+ >>> metrics.compute()
241
+ Array(0.6, dtype=float32)
242
+ >>> metrics.update(logits=logits2, labels=labels2)
243
+ >>> metrics.compute()
244
+ Array(0.7, dtype=float32)
245
+ >>> metrics.reset()
246
+ >>> metrics.compute()
247
+ Array(nan, dtype=float32)
258
248
  """
259
- if logits.ndim != labels.ndim + 1:
260
- raise ValueError(
261
- f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
262
- )
263
- elif labels.dtype in (jnp.int64, np.int32, np.int64):
264
- labels = jnp.astype(labels, jnp.int32)
265
- elif labels.dtype != jnp.int32:
266
- raise ValueError(f'Expected labels.dtype==jnp.int32, got {labels.dtype}')
267
-
268
- super().update(values=(logits.argmax(axis=-1) == labels))
269
249
 
250
+ def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None: # type: ignore[override]
251
+ """In-place update this ``Metric``.
270
252
 
271
- class MultiMetric(Metric):
272
- """MultiMetric class to store multiple metrics and update them in a single call.
253
+ Args:
254
+ logits: the outputted predicted activations. These values are
255
+ argmax-ed (on the trailing dimension), before comparing them
256
+ to the labels.
257
+ labels: the ground truth integer labels.
258
+ """
259
+ if logits.ndim != labels.ndim + 1:
260
+ raise ValueError(
261
+ f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
262
+ )
263
+ elif labels.dtype in (jnp.int64, np.int32, np.int64):
264
+ labels = jnp.astype(labels, jnp.int32)
265
+ elif labels.dtype != jnp.int32:
266
+ raise ValueError(f'Expected labels.dtype==jnp.int32, got {labels.dtype}')
273
267
 
274
- Example usage::
268
+ super().update(values=(logits.argmax(axis=-1) == labels))
275
269
 
276
- >>> from brainstate import nn
277
- >>> import jax, jax.numpy as jnp
278
270
 
279
- >>> metrics = nn.metrics.MultiMetric(
280
- ... accuracy=nn.metrics.Accuracy(),
281
- ... loss=nn.metrics.Average(),
282
- ... )
271
+ class MultiMetric(Metric):
272
+ """MultiMetric class to store multiple metrics and update them in a single call.
273
+
274
+ Example usage::
275
+
276
+ >>> from brainstate import nn
277
+ >>> import jax, jax.numpy as jnp
278
+
279
+ >>> metrics = nn.metrics.MultiMetric(
280
+ ... accuracy=nn.metrics.Accuracy(),
281
+ ... loss=nn.metrics.Average(),
282
+ ... )
283
+
284
+ >>> metrics
285
+ MultiMetric(
286
+ accuracy=Accuracy(
287
+ argname='values',
288
+ total=MetricState(
289
+ value=Array(0., dtype=float32)
290
+ ),
291
+ count=MetricState(
292
+ value=Array(0, dtype=int32)
293
+ )
294
+ ),
295
+ loss=Average(
296
+ argname='values',
297
+ total=MetricState(
298
+ value=Array(0., dtype=float32)
299
+ ),
300
+ count=MetricState(
301
+ value=Array(0, dtype=int32)
302
+ )
303
+ )
304
+ )
283
305
 
284
- >>> metrics
285
- MultiMetric(
286
- accuracy=Accuracy(
306
+ >>> metrics.accuracy
307
+ Accuracy(
287
308
  argname='values',
288
309
  total=MetricState(
289
310
  value=Array(0., dtype=float32)
@@ -291,8 +312,10 @@ class MultiMetric(Metric):
291
312
  count=MetricState(
292
313
  value=Array(0, dtype=int32)
293
314
  )
294
- ),
295
- loss=Average(
315
+ )
316
+
317
+ >>> metrics.loss
318
+ Average(
296
319
  argname='values',
297
320
  total=MetricState(
298
321
  value=Array(0., dtype=float32)
@@ -301,90 +324,67 @@ class MultiMetric(Metric):
301
324
  value=Array(0, dtype=int32)
302
325
  )
303
326
  )
304
- )
305
-
306
- >>> metrics.accuracy
307
- Accuracy(
308
- argname='values',
309
- total=MetricState(
310
- value=Array(0., dtype=float32)
311
- ),
312
- count=MetricState(
313
- value=Array(0, dtype=int32)
314
- )
315
- )
316
-
317
- >>> metrics.loss
318
- Average(
319
- argname='values',
320
- total=MetricState(
321
- value=Array(0., dtype=float32)
322
- ),
323
- count=MetricState(
324
- value=Array(0, dtype=int32)
325
- )
326
- )
327
-
328
- >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
329
- >>> labels = jnp.array([1, 1, 0, 1, 0])
330
- >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
331
- >>> labels2 = jnp.array([0, 1, 1, 1, 1])
332
-
333
- >>> batch_loss = jnp.array([1, 2, 3, 4])
334
- >>> batch_loss2 = jnp.array([3, 2, 1, 0])
335
-
336
- >>> metrics.compute()
337
- {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
338
- >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
339
- >>> metrics.compute()
340
- {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
341
- >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
342
- >>> metrics.compute()
343
- {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
344
- >>> metrics.reset()
345
- >>> metrics.compute()
346
- {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
347
- """
348
-
349
- def __init__(self, **metrics):
350
- """Pass in key-word arguments to the constructor, e.g.
351
- ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.
352
-
353
- Args:
354
- **metrics: the key-word arguments that will be used to access
355
- the corresponding ``Metric``.
356
- """
357
- # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods
358
- self._metric_names = []
359
- for metric_name, metric in metrics.items():
360
- self._metric_names.append(metric_name)
361
- vars(self)[metric_name] = metric
362
-
363
- def reset(self) -> None:
364
- """Reset all underlying ``Metric``'s."""
365
- for metric_name in self._metric_names:
366
- getattr(self, metric_name).reset()
367
-
368
- def update(self, **updates) -> None:
369
- """In-place update all underlying ``Metric``'s in this ``MultiMetric``. All
370
- ``**updates`` will be passed to the ``update`` method of all underlying
371
- ``Metric``'s.
372
-
373
- Args:
374
- **updates: the key-word arguments that will be passed to the underlying ``Metric``'s
375
- ``update`` method.
376
- """
377
- # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
378
- # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
379
- for metric_name in self._metric_names:
380
- getattr(self, metric_name).update(**updates)
381
-
382
- def compute(self) -> dict[str, Metric]:
383
- """Compute and return the value of all underlying ``Metric``'s. This method
384
- will return a dictionary, mapping strings (defined by the key-word arguments
385
- ``**metrics`` passed to the constructor) to the corresponding metric value.
327
+
328
+ >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
329
+ >>> labels = jnp.array([1, 1, 0, 1, 0])
330
+ >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
331
+ >>> labels2 = jnp.array([0, 1, 1, 1, 1])
332
+
333
+ >>> batch_loss = jnp.array([1, 2, 3, 4])
334
+ >>> batch_loss2 = jnp.array([3, 2, 1, 0])
335
+
336
+ >>> metrics.compute()
337
+ {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
338
+ >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
339
+ >>> metrics.compute()
340
+ {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
341
+ >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
342
+ >>> metrics.compute()
343
+ {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
344
+ >>> metrics.reset()
345
+ >>> metrics.compute()
346
+ {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
386
347
  """
387
- return {
388
- f'{metric_name}': getattr(self, metric_name).compute()
389
- for metric_name in self._metric_names
390
- }
348
+
349
+ def __init__(self, **metrics):
350
+ """Pass in key-word arguments to the constructor, e.g.
351
+ ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.
352
+
353
+ Args:
354
+ **metrics: the key-word arguments that will be used to access
355
+ the corresponding ``Metric``.
356
+ """
357
+ # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods
358
+ self._metric_names = []
359
+ for metric_name, metric in metrics.items():
360
+ self._metric_names.append(metric_name)
361
+ vars(self)[metric_name] = metric
362
+
363
+ def reset(self) -> None:
364
+ """Reset all underlying ``Metric``'s."""
365
+ for metric_name in self._metric_names:
366
+ getattr(self, metric_name).reset()
367
+
368
+ def update(self, **updates) -> None:
369
+ """In-place update all underlying ``Metric``'s in this ``MultiMetric``. All
370
+ ``**updates`` will be passed to the ``update`` method of all underlying
371
+ ``Metric``'s.
372
+
373
+ Args:
374
+ **updates: the key-word arguments that will be passed to the underlying ``Metric``'s
375
+ ``update`` method.
376
+ """
377
+ # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
378
+ # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
379
+ for metric_name in self._metric_names:
380
+ getattr(self, metric_name).update(**updates)
381
+
382
+ def compute(self) -> dict[str, Metric]:
383
+ """Compute and return the value of all underlying ``Metric``'s. This method
384
+ will return a dictionary, mapping strings (defined by the key-word arguments
385
+ ``**metrics`` passed to the constructor) to the corresponding metric value.
386
+ """
387
+ return {
388
+ f'{metric_name}': getattr(self, metric_name).compute()
389
+ for metric_name in self._metric_names
390
+ }