brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241009__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. brainstate/__init__.py +4 -2
  2. brainstate/_module.py +102 -67
  3. brainstate/_state.py +2 -2
  4. brainstate/_visualization.py +47 -0
  5. brainstate/environ.py +116 -9
  6. brainstate/environ_test.py +56 -0
  7. brainstate/functional/_activations.py +134 -56
  8. brainstate/functional/_activations_test.py +331 -0
  9. brainstate/functional/_normalization.py +21 -10
  10. brainstate/init/_generic.py +4 -2
  11. brainstate/mixin.py +1 -1
  12. brainstate/nn/__init__.py +7 -2
  13. brainstate/nn/_base.py +2 -2
  14. brainstate/nn/_connections.py +4 -4
  15. brainstate/nn/_dynamics.py +5 -5
  16. brainstate/nn/_elementwise.py +9 -9
  17. brainstate/nn/_embedding.py +3 -3
  18. brainstate/nn/_normalizations.py +3 -3
  19. brainstate/nn/_others.py +2 -2
  20. brainstate/nn/_poolings.py +6 -6
  21. brainstate/nn/_rate_rnns.py +1 -1
  22. brainstate/nn/_readout.py +1 -1
  23. brainstate/nn/_synouts.py +1 -1
  24. brainstate/nn/event/__init__.py +25 -0
  25. brainstate/nn/event/_misc.py +34 -0
  26. brainstate/nn/event/csr.py +312 -0
  27. brainstate/nn/event/csr_test.py +118 -0
  28. brainstate/nn/event/fixed_probability.py +276 -0
  29. brainstate/nn/event/fixed_probability_test.py +127 -0
  30. brainstate/nn/event/linear.py +220 -0
  31. brainstate/nn/event/linear_test.py +111 -0
  32. brainstate/nn/metrics.py +390 -0
  33. brainstate/optim/__init__.py +5 -1
  34. brainstate/optim/_optax_optimizer.py +208 -0
  35. brainstate/optim/_optax_optimizer_test.py +14 -0
  36. brainstate/random/__init__.py +24 -0
  37. brainstate/{random.py → random/_rand_funs.py} +7 -1596
  38. brainstate/random/_rand_seed.py +169 -0
  39. brainstate/random/_rand_state.py +1491 -0
  40. brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
  41. brainstate/{random_test.py → random/random_test.py} +208 -191
  42. brainstate/transform/_jit.py +1 -1
  43. brainstate/transform/_jit_test.py +19 -0
  44. brainstate/transform/_make_jaxpr.py +1 -1
  45. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/METADATA +1 -1
  46. brainstate-0.0.2.post20241009.dist-info/RECORD +87 -0
  47. brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
  48. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/LICENSE +0 -0
  49. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/WHEEL +0 -0
  50. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,111 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ from absl.testing import parameterized
20
+
21
+ import brainstate as bst
22
+ from brainstate.nn.event.linear import EventDense
23
+
24
+
25
class TestEventLinear(parameterized.TestCase):
    """Checks ``EventDense`` outputs and gradients against dense matmul references."""

    @parameterized.product(
        homo_w=[True, False],
        bool_x=[True, False],
    )
    def test1(self, homo_w, bool_x):
        x = bst.random.rand(20) < 0.1
        if not bool_x:
            x = jnp.asarray(x, dtype=float)
        m = EventDense(20, 40, 1.5 if homo_w else bst.init.KaimingUniform())
        y = m(x)

        # With a scalar (homogeneous) weight, every output equals sum(x) * w.
        expected = (x.sum() * m.weight) if homo_w else (x @ m.weight)
        self.assertTrue(jnp.allclose(y, expected))

    def test_grad_bool(self):
        n_in = 20
        n_out = 30
        x = bst.random.rand(n_in) < 0.3
        fn = EventDense(n_in, n_out, bst.init.KaimingUniform())

        # Differentiating with respect to a boolean input is expected to raise TypeError.
        with self.assertRaises(TypeError):
            jax.grad(lambda x: fn(x).sum())(x)

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_vjp(self, bool_x, homo_w):
        n_in = 20
        n_out = 30
        if bool_x:
            # Binary events, cast to float so that they can be differentiated.
            x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
        else:
            x = bst.random.rand(n_in)

        fn = EventDense(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform())
        w = fn.weight

        def f(x, w):
            fn.weight = w
            return fn(x).sum()

        r1 = jax.grad(f, argnums=(0, 1))(x, w)

        # -------------------
        # TRUE gradients: dense reference; a scalar weight is broadcast to a full matrix.

        def f2(x, w):
            y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
            return y.sum()

        r2 = jax.grad(f2, argnums=(0, 1))(x, w)
        self.assertTrue(jnp.allclose(r1[0], r2[0]))
        self.assertTrue(jnp.allclose(r1[1], r2[1]))

    @parameterized.product(
        bool_x=[True, False],
        homo_w=[True, False]
    )
    def test_jvp(self, bool_x, homo_w):
        n_in = 20
        n_out = 30
        if bool_x:
            # Binary events, cast to float so that they can carry tangents.
            x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
        else:
            x = bst.random.rand(n_in)

        fn = EventDense(n_in, n_out, 1.5 if homo_w else bst.init.KaimingUniform(), grad_mode='jvp')
        w = fn.weight

        def f(x, w):
            fn.weight = w
            return fn(x)

        o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))

        # -------------------
        # TRUE gradients: dense reference; a scalar weight is broadcast to a full matrix.

        def f2(x, w):
            y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
            return y

        # The reference must be computed with the dense ``f2``; evaluating ``f``
        # again would only compare the layer against itself.
        o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
        self.assertTrue(jnp.allclose(o1, o2))
        self.assertTrue(jnp.allclose(r1, r2))
@@ -0,0 +1,390 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from __future__ import annotations
18
+
19
+ import typing as tp
20
+ from dataclasses import dataclass
21
+ from functools import partial
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+ from brainstate._state import State
28
+
29
# Public names exported by this module. The MetricState wrapper and the
# Metric base class are not listed here.
__all__ = [
    'Average',
    'Statistics',
    'Welford',
    'Accuracy',
    'MultiMetric',
]
36
+
37
+
38
class MetricState(State):
    """``State`` subclass used to hold the running buffers of a metric.

    It adds no behavior of its own. Presumably the separate type exists so that
    metric buffers can be told apart from other states; confirm against callers.
    """
41
+
42
+
43
class Metric:
    """Interface shared by all streaming metrics.

    Concrete metrics accumulate data through :meth:`update`, report the result
    through :meth:`compute` and clear their state through :meth:`reset`. The
    base implementations below only raise ``NotImplementedError``.
    """

    def reset(self) -> None:
        """Clear the accumulated state in place."""
        raise NotImplementedError('Must override `reset()` method.')

    def update(self, **kwargs) -> None:
        """Fold new observations (given as keyword arguments) into the state in place."""
        raise NotImplementedError('Must override `update()` method.')

    def compute(self):
        """Return the metric value derived from the accumulated state."""
        raise NotImplementedError('Must override `compute()` method.')
58
+
59
+
60
class Average(Metric):
    """Running mean of every element seen so far.

    The mean is ``total / count``, where ``total`` is the sum of all observed
    elements and ``count`` is how many elements were observed. Before any update
    the result is ``0 / 0``, i.e. NaN.

    Example usage::

      >>> import jax.numpy as jnp
      >>> import brainstate as bst

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics = bst.nn.metrics.Average()
      >>> metrics.compute()
      Array(nan, dtype=float32)
      >>> metrics.update(values=batch_loss)
      >>> metrics.compute()
      Array(2.5, dtype=float32)
      >>> metrics.update(values=batch_loss2)
      >>> metrics.compute()
      Array(2., dtype=float32)
      >>> metrics.reset()
      >>> metrics.compute()
      Array(nan, dtype=float32)
    """

    def __init__(self, argname: str = 'values'):
        """Create an empty average.

        Args:
          argname: name of the keyword argument that :func:`update` reads the new
            observations from. With ``Average('test')``, updates are made as
            ``avg.update(test=new_value)``. Defaults to ``'values'``.
        """
        self.argname = argname
        self.total = MetricState(jnp.array(0, dtype=jnp.float32))
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))

    def reset(self) -> None:
        """Zero the running total and element count."""
        self.total.value = jnp.array(0, dtype=jnp.float32)
        self.count.value = jnp.array(0, dtype=jnp.int32)

    def update(self, **kwargs) -> None:
        """Accumulate the observations passed under ``kwargs[self.argname]``.

        A Python ``int``/``float`` counts as one element. Anything else is
        treated as an array, contributing its ``sum()`` to the total and its
        ``size`` to the count.

        Raises:
          TypeError: if the keyword named by ``self.argname`` is absent.
        """
        key = self.argname
        if key not in kwargs:
            raise TypeError(f"Expected keyword argument '{key}'")
        observed = kwargs[key]

        if isinstance(observed, (int, float)):
            batch_sum, batch_size = observed, 1
        else:
            batch_sum, batch_size = observed.sum(), observed.size

        self.total.value += batch_sum
        self.count.value += batch_size

    def compute(self) -> jax.Array:
        """Return ``total / count`` (NaN when nothing has been observed)."""
        total = self.total.value
        count = self.count.value
        return total / count
124
+
125
+
126
@partial(jax.tree_util.register_dataclass,
         data_fields=['mean', 'standard_error_of_mean', 'standard_deviation'],
         meta_fields=[])
@dataclass
class Statistics:
    """Summary statistics produced by ``Welford.compute``.

    Registered as a JAX pytree in which all three fields are data leaves (there
    are no static/meta fields), so instances can be flattened and mapped with
    ``jax.tree_util``.

    Attributes:
      mean: running mean of all observed values.
      standard_error_of_mean: ``standard_deviation / sqrt(count)``.
      standard_deviation: population standard deviation, ``sqrt(m2 / count)``.
    """
    # Values are arrays produced by jnp arithmetic, so annotate them as jax.Array
    # rather than the scalar type ``jnp.float32``.
    mean: jax.Array
    standard_error_of_mean: jax.Array
    standard_deviation: jax.Array
134
+
135
+
136
class Welford(Metric):
    """Uses Welford's algorithm to compute the mean and variance of a stream of data.

    Each batch is merged into the running state with the parallel (Chan et al.)
    combination rule: the count, the mean and ``m2`` (the sum of squared
    deviations from the mean) are updated together. The reported variance is the
    population variance, ``m2 / count``.

    Example usage::

      >>> import jax.numpy as jnp
      >>> from brainstate import nn

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics = nn.metrics.Welford()
      >>> metrics.compute()
      Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
      >>> metrics.update(values=batch_loss)
      >>> metrics.compute()
      Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
      >>> metrics.update(values=batch_loss2)
      >>> metrics.compute()
      Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
      >>> metrics.reset()
      >>> metrics.compute()
      Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
    """

    def __init__(self, argname: str = 'values'):
        """Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
        For example, constructing the metric as ``wf = Welford('test')`` would allow you to make updates with
        ``wf.update(test=new_value)``.

        Args:
          argname: an optional string denoting the key-word argument that
            :func:`update` will use to derive the new value. Defaults to
            ``'values'``.
        """
        self.argname = argname
        self.count = MetricState(jnp.array(0, dtype=jnp.int32))
        self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
        self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))

    def reset(self) -> None:
        """Reset this ``Metric`` to the state produced by ``__init__``."""
        # int32, matching __init__, so the count dtype does not depend on
        # whether reset() has been called (it was previously uint32 here).
        self.count.value = jnp.array(0, dtype=jnp.int32)
        self.mean.value = jnp.array(0, dtype=jnp.float32)
        self.m2.value = jnp.array(0, dtype=jnp.float32)

    def update(self, **kwargs) -> None:
        """In-place update this ``Metric`` with the value in ``kwargs[self.argname]``.

        A Python ``int``/``float`` counts as a single observation. Any other
        value is treated as an array whose elements are all observations. An
        empty array leaves the state unchanged.

        Args:
          **kwargs: the key-word arguments that contain a ``self.argname``
            entry mapping to the new observations.

        Raises:
          TypeError: if ``self.argname`` is not among the keyword arguments.
        """
        if self.argname not in kwargs:
            raise TypeError(f"Expected keyword argument '{self.argname}'")
        values: tp.Union[int, float, jax.Array] = kwargs[self.argname]

        if isinstance(values, (int, float)):
            count, batch_mean, batch_m2 = 1, values, 0.0
        else:
            count = values.size
            if count == 0:
                # values.mean() of an empty array is NaN and would poison the
                # running mean; an empty batch contributes nothing anyway.
                return
            batch_mean = values.mean()
            batch_m2 = values.var() * count

        original_count = self.count.value
        new_count = original_count + count
        delta = batch_mean - self.mean.value

        self.count.value = new_count
        self.mean.value += delta * count / new_count
        # Chan et al. merge of two partial M2 sums; divide by the count *value*
        # (previously the State object itself was used as the divisor).
        self.m2.value += batch_m2 + delta * delta * count * original_count / new_count

    def compute(self) -> Statistics:
        """Compute and return the mean and variance statistics in a
        ``Statistics`` dataclass object.

        With no observations the variance is ``0 / 0``, so the standard
        deviation and the standard error are NaN.
        """
        variance = self.m2.value / self.count.value
        standard_deviation = variance ** 0.5
        sem = standard_deviation / (self.count.value ** 0.5)
        return Statistics(
            mean=self.mean.value,
            standard_error_of_mean=sem,
            standard_deviation=standard_deviation,
        )
218
+
219
+
220
class Accuracy(Average):
    """Accuracy metric. This metric subclasses :class:`Average`,
    and so they share the same ``reset`` and ``compute`` method
    implementations. Unlike :class:`Average`, no string needs to
    be passed to ``Accuracy`` during construction.

    Example usage::

      >>> import brainstate as bst
      >>> import jax, jax.numpy as jnp

      >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
      >>> labels = jnp.array([1, 1, 0, 1, 0])
      >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
      >>> labels2 = jnp.array([0, 1, 1, 1, 1])

      >>> metrics = bst.nn.metrics.Accuracy()
      >>> metrics.compute()
      Array(nan, dtype=float32)
      >>> metrics.update(logits=logits, labels=labels)
      >>> metrics.compute()
      Array(0.6, dtype=float32)
      >>> metrics.update(logits=logits2, labels=labels2)
      >>> metrics.compute()
      Array(0.7, dtype=float32)
      >>> metrics.reset()
      >>> metrics.compute()
      Array(nan, dtype=float32)
    """

    def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None:  # type: ignore[override]
        """In-place update this ``Metric``.

        Args:
          logits: the outputted predicted activations. These values are
            argmax-ed (on the trailing dimension), before comparing them
            to the labels.
          labels: the ground truth integer labels, of any integer dtype; they
            are cast to ``int32`` before comparison.
          **_: ignored. This lets ``Accuracy`` share keyword arguments with other
            metrics, e.g. inside ``MultiMetric``.

        Raises:
          ValueError: if ``logits.ndim != labels.ndim + 1`` or the labels are not
            of an integer dtype.
        """
        # Accept any array-like (e.g. Python lists or NumPy arrays).
        logits = jnp.asarray(logits)
        labels = jnp.asarray(labels)
        if logits.ndim != labels.ndim + 1:
            raise ValueError(
                f'Expected logits.ndim==labels.ndim+1, got {logits.ndim} and {labels.ndim}'
            )
        # Any signed/unsigned integer dtype is a valid label type; bool and
        # floating dtypes are rejected.
        if not jnp.issubdtype(labels.dtype, jnp.integer):
            raise ValueError(f'Expected labels.dtype to be an integer type, got {labels.dtype}')
        labels = labels.astype(jnp.int32)

        super().update(values=(logits.argmax(axis=-1) == labels))
269
+
270
+
271
class MultiMetric(Metric):
    """MultiMetric class to store multiple metrics and update them in a single call.

    Each metric passed to the constructor becomes an attribute of the same name.
    ``update`` forwards all keyword arguments to every metric, and ``compute``
    returns a dict keyed by metric name.

    Example usage::

      >>> from brainstate import nn
      >>> import jax, jax.numpy as jnp

      >>> metrics = nn.metrics.MultiMetric(
      ...   accuracy=nn.metrics.Accuracy(),
      ...   loss=nn.metrics.Average(),
      ... )

      >>> isinstance(metrics.accuracy, nn.metrics.Accuracy)
      True
      >>> isinstance(metrics.loss, nn.metrics.Average)
      True

      >>> logits = jax.random.normal(jax.random.key(0), (5, 2))
      >>> labels = jnp.array([1, 1, 0, 1, 0])
      >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
      >>> labels2 = jnp.array([0, 1, 1, 1, 1])

      >>> batch_loss = jnp.array([1, 2, 3, 4])
      >>> batch_loss2 = jnp.array([3, 2, 1, 0])

      >>> metrics.compute()
      {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
      >>> metrics.update(logits=logits, labels=labels, values=batch_loss)
      >>> metrics.compute()
      {'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
      >>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
      >>> metrics.compute()
      {'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
      >>> metrics.reset()
      >>> metrics.compute()
      {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
    """

    def __init__(self, **metrics):
        """Pass in key-word arguments to the constructor, e.g.
        ``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.

        Args:
          **metrics: the key-word arguments that will be used to access
            the corresponding ``Metric``.

        Raises:
          ValueError: if a keyword collides with a method name (``reset``,
            ``update``, ``compute``) or with the internal ``_metric_names`` list.
            Such a metric would shadow that attribute on the instance.
        """
        reserved = {'reset', 'update', 'compute', '_metric_names'}
        clashes = sorted(reserved.intersection(metrics))
        if clashes:
            raise ValueError(
                f'Metric names {clashes} are reserved by MultiMetric; '
                f'please use different keywords.'
            )
        self._metric_names = []
        for metric_name, metric in metrics.items():
            self._metric_names.append(metric_name)
            vars(self)[metric_name] = metric

    def reset(self) -> None:
        """Reset all underlying ``Metric``'s."""
        for metric_name in self._metric_names:
            getattr(self, metric_name).reset()

    def update(self, **updates) -> None:
        """In-place update all underlying ``Metric``'s in this ``MultiMetric``. All
        ``**updates`` will be passed to the ``update`` method of all underlying
        ``Metric``'s.

        Args:
          **updates: the key-word arguments that will be passed to the underlying ``Metric``'s
            ``update`` method.
        """
        # TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
        # TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
        for metric_name in self._metric_names:
            getattr(self, metric_name).update(**updates)

    def compute(self) -> dict[str, tp.Any]:
        """Compute and return the value of all underlying ``Metric``'s. This method
        will return a dictionary, mapping strings (defined by the key-word arguments
        ``**metrics`` passed to the constructor) to the value returned by the
        corresponding metric's ``compute``.
        """
        return {
            metric_name: getattr(self, metric_name).compute()
            for metric_name in self._metric_names
        }
@@ -16,7 +16,11 @@
16
16
 
17
17
  from ._lr_scheduler import *
18
18
  from ._lr_scheduler import __all__ as scheduler_all
19
+ from ._optax_optimizer import *
20
+ from ._optax_optimizer import __all__ as optax_all
19
21
  from ._sgd_optimizer import *
20
22
  from ._sgd_optimizer import __all__ as optimizer_all
21
23
 
22
- __all__ = scheduler_all + optimizer_all
24
# Public API of the package: the union of the names exported by the LR
# scheduler, SGD-style optimizer and optax-wrapper submodules.
__all__ = scheduler_all + optimizer_all + optax_all

# Remove the temporary aggregation lists so they are not left behind as
# attributes of the package module.
del optax_all, scheduler_all, optimizer_all