brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
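
The renames in the list above amount to a reorganization of the package: brainstate/transform becomes brainstate/compile (with autograd and mapping transforms split out into a new brainstate/augment package), the event-driven operators move from brainstate/nn/event to a top-level brainstate/event package, and the single-module brainstate/util.py becomes a brainstate/util/ package. A minimal orientation sketch follows; the module names are taken only from the file list above, and whether each module re-exports the same public names as before is not verified by this diff:

    # Hypothetical smoke test of the 0.1.0 layout, based solely on the file listing above.
    from importlib.metadata import version
    import importlib

    print(version("brainstate"))  # expected to report 0.1.0.post20241122 for the new wheel

    # Sub-packages whose source files appear in this release:
    for mod in ("brainstate.compile", "brainstate.augment", "brainstate.event",
                "brainstate.graph", "brainstate.util"):
        importlib.import_module(mod)  # should succeed if the installed wheel matches the listing
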
brainstate/nn/metrics.py CHANGED
@@ -27,100 +27,100 @@ import numpy as np
  from brainstate._state import State

  __all__ = [
- 'Average',
- 'Statistics',
- 'Welford',
- 'Accuracy',
- 'MultiMetric',
+ 'Average',
+ 'Statistics',
+ 'Welford',
+ 'Accuracy',
+ 'MultiMetric',
  ]


  class MetricState(State):
- """Wrapper class for Metric Variables."""
- pass
+ """Wrapper class for Metric Variables."""
+ pass


  class Metric(object):
- """Base class for metrics. Any class that subclasses ``Metric`` should
- implement a ``compute``, ``reset`` and ``update`` method."""
+ """Base class for metrics. Any class that subclasses ``Metric`` should
+ implement a ``compute``, ``reset`` and ``update`` method."""

- def reset(self) -> None:
- """In-place reset the ``Metric``."""
- raise NotImplementedError('Must override `reset()` method.')
+ def reset(self) -> None:
+ """In-place reset the ``Metric``."""
+ raise NotImplementedError('Must override `reset()` method.')

- def update(self, **kwargs) -> None:
- """In-place update the ``Metric``."""
- raise NotImplementedError('Must override `update()` method.')
+ def update(self, **kwargs) -> None:
+ """In-place update the ``Metric``."""
+ raise NotImplementedError('Must override `update()` method.')

- def compute(self):
- """Compute and return the value of the ``Metric``."""
- raise NotImplementedError('Must override `compute()` method.')
+ def compute(self):
+ """Compute and return the value of the ``Metric``."""
+ raise NotImplementedError('Must override `compute()` method.')


  class Average(Metric):
- """Average metric.
-
- Example usage::
-
- >>> import jax.numpy as jnp
- >>> import brainstate as bst
-
- >>> batch_loss = jnp.array([1, 2, 3, 4])
- >>> batch_loss2 = jnp.array([3, 2, 1, 0])
-
- >>> metrics = bst.nn.metrics.Average()
- >>> metrics.compute()
- Array(nan, dtype=float32)
- >>> metrics.update(values=batch_loss)
- >>> metrics.compute()
- Array(2.5, dtype=float32)
- >>> metrics.update(values=batch_loss2)
- >>> metrics.compute()
- Array(2., dtype=float32)
- >>> metrics.reset()
- >>> metrics.compute()
- Array(nan, dtype=float32)
- """
-
- def __init__(self, argname: str = 'values'):
- """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
- For example, constructing the metric as ``avg = Average('test')`` would allow you to make updates with
- ``avg.update(test=new_value)``.
-
- Args:
- argname: an optional string denoting the key-word argument that
- :func:`update` will use to derive the new value. Defaults to
- ``'values'``.
+ """Average metric.
+
+ Example usage::
+
+ >>> import jax.numpy as jnp
+ >>> import brainstate as bst
+
+ >>> batch_loss = jnp.array([1, 2, 3, 4])
+ >>> batch_loss2 = jnp.array([3, 2, 1, 0])
+
+ >>> metrics = bst.nn.metrics.Average()
+ >>> metrics.compute()
+ Array(nan, dtype=float32)
+ >>> metrics.update(values=batch_loss)
+ >>> metrics.compute()
+ Array(2.5, dtype=float32)
+ >>> metrics.update(values=batch_loss2)
+ >>> metrics.compute()
+ Array(2., dtype=float32)
+ >>> metrics.reset()
+ >>> metrics.compute()
+ Array(nan, dtype=float32)
  """
- self.argname = argname
- self.total = MetricState(jnp.array(0, dtype=jnp.float32))
- self.count = MetricState(jnp.array(0, dtype=jnp.int32))
-
- def reset(self) -> None:
- """Reset this ``Metric``."""
- self.total.value = jnp.array(0, dtype=jnp.float32)
- self.count.value = jnp.array(0, dtype=jnp.int32)
-
- def update(self, **kwargs) -> None:
- """In-place update this ``Metric``. This method will use the value from
- ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
- defined on construction.
-
- Args:
- **kwargs: the key-word arguments that contains a ``self.argname``
- entry that maps to the value we want to use to update this metric.
- """
- if self.argname not in kwargs:
- raise TypeError(f"Expected keyword argument '{self.argname}'")
- values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
- self.total.value += (
- values if isinstance(values, (int, float)) else values.sum()
- )
- self.count.value += 1 if isinstance(values, (int, float)) else values.size

- def compute(self) -> jax.Array:
- """Compute and return the average."""
- return self.total.value / self.count.value
+ def __init__(self, argname: str = 'values'):
+ """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
+ For example, constructing the metric as ``avg = Average('test')`` would allow you to make updates with
+ ``avg.update(test=new_value)``.
+
+ Args:
+ argname: an optional string denoting the key-word argument that
+ :func:`update` will use to derive the new value. Defaults to
+ ``'values'``.
+ """
+ self.argname = argname
+ self.total = MetricState(jnp.array(0, dtype=jnp.float32))
+ self.count = MetricState(jnp.array(0, dtype=jnp.int32))
+
+ def reset(self) -> None:
+ """Reset this ``Metric``."""
+ self.total.value = jnp.array(0, dtype=jnp.float32)
+ self.count.value = jnp.array(0, dtype=jnp.int32)
+
+ def update(self, **kwargs) -> None:
+ """In-place update this ``Metric``. This method will use the value from
+ ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
+ defined on construction.
+
+ Args:
+ **kwargs: the key-word arguments that contains a ``self.argname``
+ entry that maps to the value we want to use to update this metric.
+ """
+ if self.argname not in kwargs:
+ raise TypeError(f"Expected keyword argument '{self.argname}'")
+ values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
+ self.total.value += (
+ values if isinstance(values, (int, float)) else values.sum()
+ )
+ self.count.value += 1 if isinstance(values, (int, float)) else values.size
+
+ def compute(self) -> jax.Array:
+ """Compute and return the average."""
+ return self.total.value / self.count.value


  @partial(jax.tree_util.register_dataclass,
@@ -128,162 +128,183 @@ class Average(Metric):
  meta_fields=[])
  @dataclass
  class Statistics:
- mean: jnp.float32
- standard_error_of_mean: jnp.float32
- standard_deviation: jnp.float32
+ mean: jnp.float32
+ standard_error_of_mean: jnp.float32
+ standard_deviation: jnp.float32


  class Welford(Metric):
- """Uses Welford's algorithm to compute the mean and variance of a stream of data.
-
- Example usage::
-
- >>> import jax.numpy as jnp
- >>> from brainstate import nn
-
- >>> batch_loss = jnp.array([1, 2, 3, 4])
- >>> batch_loss2 = jnp.array([3, 2, 1, 0])
-
- >>> metrics = nn.metrics.Welford()
- >>> metrics.compute()
- Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
- >>> metrics.update(values=batch_loss)
- >>> metrics.compute()
- Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
- >>> metrics.update(values=batch_loss2)
- >>> metrics.compute()
- Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
- >>> metrics.reset()
- >>> metrics.compute()
- Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
- """
-
- def __init__(self, argname: str = 'values'):
- """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
- For example, constructing the metric as ``wf = Welford('test')`` would allow you to make updates with
- ``wf.update(test=new_value)``.
-
- Args:
- argname: an optional string denoting the key-word argument that
- :func:`update` will use to derive the new value. Defaults to
- ``'values'``.
+ """Uses Welford's algorithm to compute the mean and variance of a stream of data.
+
+ Example usage::
+
+ >>> import jax.numpy as jnp
+ >>> from brainstate import nn
+
+ >>> batch_loss = jnp.array([1, 2, 3, 4])
+ >>> batch_loss2 = jnp.array([3, 2, 1, 0])
+
+ >>> metrics = nn.metrics.Welford()
+ >>> metrics.compute()
+ Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
+ >>> metrics.update(values=batch_loss)
+ >>> metrics.compute()
+ Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
+ >>> metrics.update(values=batch_loss2)
+ >>> metrics.compute()
+ Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
+ >>> metrics.reset()
+ >>> metrics.compute()
+ Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
  """
- self.argname = argname
- self.count = MetricState(jnp.array(0, dtype=jnp.int32))
- self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
- self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))
-
- def reset(self) -> None:
- """Reset this ``Metric``."""
- self.count.value = jnp.array(0, dtype=jnp.uint32)
- self.mean.value = jnp.array(0, dtype=jnp.float32)
- self.m2.value = jnp.array(0, dtype=jnp.float32)
-
- def update(self, **kwargs) -> None:
- """In-place update this ``Metric``. This method will use the value from
- ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
- defined on construction.
-
- Args:
- **kwargs: the key-word arguments that contains a ``self.argname``
- entry that maps to the value we want to use to update this metric.
- """
- if self.argname not in kwargs:
- raise TypeError(f"Expected keyword argument '{self.argname}'")
- values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
- count = 1 if isinstance(values, (int, float)) else values.size
- original_count = self.count.value
- self.count.value += count
- delta = (
- values if isinstance(values, (int, float)) else values.mean()
- ) - self.mean.value
- self.mean.value += delta * count / self.count.value
- m2 = 0.0 if isinstance(values, (int, float)) else values.var() * count
- self.m2.value += (
- m2 + delta * delta * count * original_count / self.count
- )
-
- def compute(self) -> Statistics:
- """Compute and return the mean and variance statistics in a
- ``Statistics`` dataclass object.
- """
- variance = self.m2.value / self.count.value
- standard_deviation = variance ** 0.5
- sem = standard_deviation / (self.count.value ** 0.5)
- return Statistics(
- mean=self.mean.value,
- standard_error_of_mean=sem,
- standard_deviation=standard_deviation,
- )
+
+ def __init__(self, argname: str = 'values'):
+ """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
+ For example, constructing the metric as ``wf = Welford('test')`` would allow you to make updates with
+ ``wf.update(test=new_value)``.
+
+ Args:
+ argname: an optional string denoting the key-word argument that
+ :func:`update` will use to derive the new value. Defaults to
+ ``'values'``.
+ """
+ self.argname = argname
+ self.count = MetricState(jnp.array(0, dtype=jnp.int32))
+ self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
+ self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))
+
+ def reset(self) -> None:
+ """Reset this ``Metric``."""
+ self.count.value = jnp.array(0, dtype=jnp.uint32)
+ self.mean.value = jnp.array(0, dtype=jnp.float32)
+ self.m2.value = jnp.array(0, dtype=jnp.float32)
+
+ def update(self, **kwargs) -> None:
+ """In-place update this ``Metric``. This method will use the value from
+ ``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
+ defined on construction.
+
+ Args:
+ **kwargs: the key-word arguments that contains a ``self.argname``
+ entry that maps to the value we want to use to update this metric.
+ """
+ if self.argname not in kwargs:
+ raise TypeError(f"Expected keyword argument '{self.argname}'")
+ values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
+ count = 1 if isinstance(values, (int, float)) else values.size
+ original_count = self.count.value
+ self.count.value += count
+ delta = (
+ values if isinstance(values, (int, float)) else values.mean()
+ ) - self.mean.value
+ self.mean.value += delta * count / self.count.value
+ m2 = 0.0 if isinstance(values, (int, float)) else values.var() * count
+ self.m2.value += (
+ m2 + delta * delta * count * original_count / self.count
+ )
+
+ def compute(self) -> Statistics:
+ """Compute and return the mean and variance statistics in a
+ ``Statistics`` dataclass object.
+ """
+ variance = self.m2.value / self.count.value
+ standard_deviation = variance ** 0.5
+ sem = standard_deviation / (self.count.value ** 0.5)
+ return Statistics(
+ mean=self.mean.value,
+ standard_error_of_mean=sem,
+ standard_deviation=standard_deviation,
+ )


  class Accuracy(Average):
- """Accuracy metric. This metric subclasses :class:`Average`,
- and so they share the same ``reset`` and ``compute`` method
- implementations. Unlike :class:`Average`, no string needs to
- be passed to ``Accuracy`` during construction.
-
- Example usage::
-
- >>> import brainstate as bst
- >>> import jax, jax.numpy as jnp
-
- >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
- >>> labels = jnp.array([1, 1, 0, 1, 0])
- >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
- >>> labels2 = jnp.array([0, 1, 1, 1, 1])
-
- >>> metrics = bst.nn.metrics.Accuracy()
- >>> metrics.compute()
- Array(nan, dtype=float32)
- >>> metrics.update(logits=logits, labels=labels)
- >>> metrics.compute()
- Array(0.6, dtype=float32)
- >>> metrics.update(logits=logits2, labels=labels2)
- >>> metrics.compute()
- Array(0.7, dtype=float32)
- >>> metrics.reset()
- >>> metrics.compute()
- Array(nan, dtype=float32)
- """
-
- def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None: # type: ignore[override]
- """In-place update this ``Metric``.
-
- Args:
- logits: the outputted predicted activations. These values are
- argmax-ed (on the trailing dimension), before comparing them
- to the labels.
- labels: the ground truth integer labels.
+ """Accuracy metric. This metric subclasses :class:`Average`,
+ and so they share the same ``reset`` and ``compute`` method
+ implementations. Unlike :class:`Average`, no string needs to
+ be passed to ``Accuracy`` during construction.
+
+ Example usage::
+
+ >>> import brainstate as bst
+ >>> import jax, jax.numpy as jnp
+
+ >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
+ >>> labels = jnp.array([1, 1, 0, 1, 0])
+ >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
+ >>> labels2 = jnp.array([0, 1, 1, 1, 1])
+
+ >>> metrics = bst.nn.metrics.Accuracy()
+ >>> metrics.compute()
+ Array(nan, dtype=float32)
+ >>> metrics.update(logits=logits, labels=labels)
+ >>> metrics.compute()
+ Array(0.6, dtype=float32)
+ >>> metrics.update(logits=logits2, labels=labels2)
+ >>> metrics.compute()
+ Array(0.7, dtype=float32)
+ >>> metrics.reset()
+ >>> metrics.compute()
+ Array(nan, dtype=float32)
  """
- if logits.ndim != labels.ndim + 1:
- raise ValueError(
- f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
- )
- elif labels.dtype in (jnp.int64, np.int32, np.int64):
- labels = jnp.astype(labels, jnp.int32)
- elif labels.dtype != jnp.int32:
- raise ValueError(f'Expected labels.dtype==jnp.int32, got {labels.dtype}')
-
- super().update(values=(logits.argmax(axis=-1) == labels))

+ def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None: # type: ignore[override]
+ """In-place update this ``Metric``.

- class MultiMetric(Metric):
- """MultiMetric class to store multiple metrics and update them in a single call.
+ Args:
+ logits: the outputted predicted activations. These values are
+ argmax-ed (on the trailing dimension), before comparing them
+ to the labels.
+ labels: the ground truth integer labels.
+ """
+ if logits.ndim != labels.ndim + 1:
+ raise ValueError(
+ f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
+ )
+ elif labels.dtype in (jnp.int64, np.int32, np.int64):
+ labels = jnp.astype(labels, jnp.int32)
+ elif labels.dtype != jnp.int32:
+ raise ValueError(f'Expected labels.dtype==jnp.int32, got {labels.dtype}')

- Example usage::
+ super().update(values=(logits.argmax(axis=-1) == labels))

- >>> from brainstate import nn
- >>> import jax, jax.numpy as jnp

- >>> metrics = nn.metrics.MultiMetric(
- ... accuracy=nn.metrics.Accuracy(),
- ... loss=nn.metrics.Average(),
- ... )
+ class MultiMetric(Metric):
+ """MultiMetric class to store multiple metrics and update them in a single call.
+
+ Example usage::
+
+ >>> from brainstate import nn
+ >>> import jax, jax.numpy as jnp
+
+ >>> metrics = nn.metrics.MultiMetric(
+ ... accuracy=nn.metrics.Accuracy(),
+ ... loss=nn.metrics.Average(),
+ ... )
+
+ >>> metrics
+ MultiMetric(
+ accuracy=Accuracy(
+ argname='values',
+ total=MetricState(
+ value=Array(0., dtype=float32)
+ ),
+ count=MetricState(
+ value=Array(0, dtype=int32)
+ )
+ ),
+ loss=Average(
+ argname='values',
+ total=MetricState(
+ value=Array(0., dtype=float32)
+ ),
+ count=MetricState(
+ value=Array(0, dtype=int32)
+ )
+ )
+ )

- >>> metrics
- MultiMetric(
- accuracy=Accuracy(
+ >>> metrics.accuracy
+ Accuracy(
  argname='values',
  total=MetricState(
  value=Array(0., dtype=float32)
@@ -291,8 +312,10 @@ class MultiMetric(Metric):
  count=MetricState(
  value=Array(0, dtype=int32)
  )
- ),
- loss=Average(
+ )
+
+ >>> metrics.loss
+ Average(
  argname='values',
  total=MetricState(
  value=Array(0., dtype=float32)
@@ -301,90 +324,67 @@ class MultiMetric(Metric):
  value=Array(0, dtype=int32)
  )
  )
- )
-
- >>> metrics.accuracy
- Accuracy(
- argname='values',
- total=MetricState(
- value=Array(0., dtype=float32)
- ),
- count=MetricState(
- value=Array(0, dtype=int32)
- )
- )
-
- >>> metrics.loss
- Average(
- argname='values',
- total=MetricState(
- value=Array(0., dtype=float32)
- ),
- count=MetricState(
- value=Array(0, dtype=int32)
- )
- )
-
- >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
- >>> labels = jnp.array([1, 1, 0, 1, 0])
- >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
- >>> labels2 = jnp.array([0, 1, 1, 1, 1])
-
- >>> batch_loss = jnp.array([1, 2, 3, 4])
- >>> batch_loss2 = jnp.array([3, 2, 1, 0])
-
- >>> metrics.compute()
- {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
- >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
- >>> metrics.compute()
- {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
- >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
- >>> metrics.compute()
- {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
- >>> metrics.reset()
- >>> metrics.compute()
- {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
- """
-
- def __init__(self, **metrics):
- """Pass in key-word arguments to the constructor, e.g.
- ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.
-
- Args:
- **metrics: the key-word arguments that will be used to access
- the corresponding ``Metric``.
- """
- # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods
- self._metric_names = []
- for metric_name, metric in metrics.items():
- self._metric_names.append(metric_name)
- vars(self)[metric_name] = metric
-
- def reset(self) -> None:
- """Reset all underlying ``Metric``'s."""
- for metric_name in self._metric_names:
- getattr(self, metric_name).reset()
-
- def update(self, **updates) -> None:
- """In-place update all underlying ``Metric``'s in this ``MultiMetric``. All
- ``**updates`` will be passed to the ``update`` method of all underlying
- ``Metric``'s.
-
- Args:
- **updates: the key-word arguments that will be passed to the underlying ``Metric``'s
- ``update`` method.
- """
- # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
- # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
- for metric_name in self._metric_names:
- getattr(self, metric_name).update(**updates)
-
- def compute(self) -> dict[str, Metric]:
- """Compute and return the value of all underlying ``Metric``'s. This method
- will return a dictionary, mapping strings (defined by the key-word arguments
- ``**metrics`` passed to the constructor) to the corresponding metric value.
+
+ >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
+ >>> labels = jnp.array([1, 1, 0, 1, 0])
+ >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
+ >>> labels2 = jnp.array([0, 1, 1, 1, 1])
+
+ >>> batch_loss = jnp.array([1, 2, 3, 4])
+ >>> batch_loss2 = jnp.array([3, 2, 1, 0])
+
+ >>> metrics.compute()
+ {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
+ >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
+ >>> metrics.compute()
+ {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
+ >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
+ >>> metrics.compute()
+ {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
+ >>> metrics.reset()
+ >>> metrics.compute()
+ {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
  """
- return {
- f'{metric_name}': getattr(self, metric_name).compute()
- for metric_name in self._metric_names
- }
+
+ def __init__(self, **metrics):
+ """Pass in key-word arguments to the constructor, e.g.
+ ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.
+
+ Args:
+ **metrics: the key-word arguments that will be used to access
+ the corresponding ``Metric``.
+ """
+ # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods
+ self._metric_names = []
+ for metric_name, metric in metrics.items():
+ self._metric_names.append(metric_name)
+ vars(self)[metric_name] = metric
+
+ def reset(self) -> None:
+ """Reset all underlying ``Metric``'s."""
+ for metric_name in self._metric_names:
+ getattr(self, metric_name).reset()
+
+ def update(self, **updates) -> None:
+ """In-place update all underlying ``Metric``'s in this ``MultiMetric``. All
+ ``**updates`` will be passed to the ``update`` method of all underlying
+ ``Metric``'s.
+
+ Args:
+ **updates: the key-word arguments that will be passed to the underlying ``Metric``'s
+ ``update`` method.
+ """
+ # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
+ # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
+ for metric_name in self._metric_names:
+ getattr(self, metric_name).update(**updates)
+
+ def compute(self) -> dict[str, Metric]:
+ """Compute and return the value of all underlying ``Metric``'s. This method
+ will return a dictionary, mapping strings (defined by the key-word arguments
+ ``**metrics`` passed to the constructor) to the corresponding metric value.
+ """
+ return {
+ f'{metric_name}': getattr(self, metric_name).compute()
+ for metric_name in self._metric_names
+ }
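
The batch update in Welford.update above follows the standard Welford/Chan merge of a running (count, mean, M2) triple with one incoming batch, where M2 accumulates the sum of squared deviations. A minimal NumPy sketch of that merge (illustrative only, not part of the package) reproduces the numbers shown in the Welford docstring:

    import numpy as np

    def welford_merge(count, mean, m2, batch):
        # Fold one batch into the running (count, mean, m2) triple;
        # m2 is the running sum of squared deviations from the mean.
        batch = np.asarray(batch, dtype=np.float64)
        n_b = batch.size
        delta = batch.mean() - mean
        new_count = count + n_b
        new_mean = mean + delta * n_b / new_count
        # batch.var() * n_b is the batch's own sum of squared deviations
        new_m2 = m2 + batch.var() * n_b + delta ** 2 * n_b * count / new_count
        return new_count, new_mean, new_m2

    count, mean, m2 = 0, 0.0, 0.0
    for batch in ([1, 2, 3, 4], [3, 2, 1, 0]):
        count, mean, m2 = welford_merge(count, mean, m2, batch)

    std = (m2 / count) ** 0.5
    print(mean, std, std / count ** 0.5)  # 2.0, 1.2247..., 0.4330..., matching the docstring values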
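
Similarly, Accuracy.update above reduces classification accuracy to an Average over boolean matches: the logits are argmax-ed along the trailing axis and compared with the integer labels. A small check with made-up numbers (illustrative only):

    import jax.numpy as jnp

    logits = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
    labels = jnp.array([1, 0, 0], dtype=jnp.int32)
    matches = logits.argmax(axis=-1) == labels  # the values fed to Average.update
    print(matches.mean())  # 0.6666667: two of the three predictions match the labels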