brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
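
Taken together, the file list shows the main restructuring in 0.1.0: the old brainstate.transform package is split into brainstate.compile (jit, conditionals, loops, jaxpr machinery; files 13-30) and a new brainstate.augment package (autograd, eval_shape, mapping; files 5-12); brainstate.nn.event is promoted to a top-level brainstate.event package (files 33-44); new graph and util packages appear; and the monolithic brainstate/_module.py is replaced by the brainstate/nn module hierarchy. A rough migration sketch for user code follows; bst.compile.jit is inferred from the new module names, and bst.augment.grad from the hunks below, neither verified against the released API:

    import jax.numpy as jnp
    import brainstate as bst

    model = bst.nn.Linear(2, 3)  # any module holding ParamState weights
    loss_fn = lambda: (model(jnp.ones((1, 2))) ** 2).mean()

    # 0.0.2: compilation and autodiff both lived under brainstate.transform:
    #   jitted = bst.transform.jit(loss_fn)
    #   grads = bst.transform.grad(loss_fn)(model)

    # 0.1.0: compilation utilities move to brainstate.compile ...
    jitted = bst.compile.jit(loss_fn)
    # ... and autodiff moves to brainstate.augment, with the differentiated
    # states now passed explicitly:
    grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()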
brainstate/optim/_lr_scheduler_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+from __future__ import annotations
 
 import unittest
 
@@ -22,28 +23,28 @@ import brainstate as bst
 
 
 class TestMultiStepLR(unittest.TestCase):
-  def test1(self):
-    lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
-    for i in range(40):
-      r = lr(i)
-      if i < 10:
-        self.assertEqual(r, 0.1)
-      elif i < 20:
-        self.assertTrue(jnp.allclose(r, 0.01))
-      elif i < 30:
-        self.assertTrue(jnp.allclose(r, 0.001))
-      else:
-        self.assertTrue(jnp.allclose(r, 0.0001))
+    def test1(self):
+        lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
+        for i in range(40):
+            r = lr(i)
+            if i < 10:
+                self.assertEqual(r, 0.1)
+            elif i < 20:
+                self.assertTrue(jnp.allclose(r, 0.01))
+            elif i < 30:
+                self.assertTrue(jnp.allclose(r, 0.001))
+            else:
+                self.assertTrue(jnp.allclose(r, 0.0001))
 
-  def test2(self):
-    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
-    for i in range(40):
-      r = lr(i)
-      if i < 10:
-        self.assertEqual(r, 0.1)
-      elif i < 20:
-        self.assertTrue(jnp.allclose(r, 0.01))
-      elif i < 30:
-        self.assertTrue(jnp.allclose(r, 0.001))
-      else:
-        self.assertTrue(jnp.allclose(r, 0.0001))
+    def test2(self):
+        lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+        for i in range(40):
+            r = lr(i)
+            if i < 10:
+                self.assertEqual(r, 0.1)
+            elif i < 20:
+                self.assertTrue(jnp.allclose(r, 0.01))
+            elif i < 30:
+                self.assertTrue(jnp.allclose(r, 0.001))
+            else:
+                self.assertTrue(jnp.allclose(r, 0.0001))
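
For reference, the schedule these tests assert is the standard MultiStepLR rule: the rate starts at the base value and is multiplied by gamma each time a milestone is crossed. A minimal standalone sketch of that rule (the expected_lr helper is hypothetical, for illustration only, not part of brainstate):

    def expected_lr(base_lr, milestones, gamma, step):
        # each milestone already crossed scales the rate by another factor of gamma
        crossed = sum(step >= m for m in milestones)
        return base_lr * gamma ** crossed

    assert expected_lr(0.1, [10, 20, 30], 0.1, 5) == 0.1                 # no milestone crossed
    assert abs(expected_lr(0.1, [10, 20, 30], 0.1, 25) - 0.001) < 1e-12  # two crossed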
brainstate/optim/_optax_optimizer.py
@@ -17,192 +17,119 @@
 from __future__ import annotations
 
 import importlib.util
-from typing import Any
+from typing import Hashable, Dict, Optional
 
-import jax.numpy as jnp
+from brainstate._state import ShortTermState, State, StateDictManager
+from brainstate.typing import PyTree
+from ._base import Optimizer
 
-from brainstate._module import Module
-from brainstate._state import ShortTermState, ParamState
+optax_installed = importlib.util.find_spec('optax') is not None
 
 __all__ = [
-  'OptaxOptimizer',
+    'OptaxOptimizer',
 ]
 
-optax_installed = importlib.util.find_spec('optax') is not None
 
+class OptaxOptimizer(Optimizer):
+    """Simple train state for the common case with a single Optax optimizer.
 
-class OptaxState(ShortTermState):
-  """Wrapper class for Optimizer Variables."""
-  pass
-
-
-class OptaxOptimizer(Module):
-  """Simple train state for the common case with a single Optax optimizer.
-
-  Example usage::
-
-    >>> import jax, jax.numpy as jnp
-    >>> import brainstate as bst
-    >>> from brainstate import nn
-    >>> import optax
-    ...
-    >>> class Model(bst.Module):
-    ...   def __init__(self):
-    ...     super().__init__()
-    ...     self.linear1 = nn.Linear(2, 3)
-    ...     self.linear2 = nn.Linear(3, 4)
-    ...   def __call__(self, x):
-    ...     return self.linear2(self.linear1(x))
-    ...
-    >>> x = jax.random.normal(jax.random.key(0), (1, 2))
-    >>> y = jnp.ones((1, 4))
-    ...
-    >>> model = Model()
-    >>> tx = optax.adam(1e-3)
-    >>> state = bst.optim.OptaxOptimizer(model, tx)
-    ...
-    >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
-    >>> loss_fn(model)
-    Array(1.7055722, dtype=float32)
-    >>> grads = bst.transform.grad(loss_fn)(state.model)
-    >>> state.update(grads)
-    >>> loss_fn(model)
-    Array(1.6925814, dtype=float32)
-
-  Note that you can easily extend this class by subclassing it for storing
-  additional data (e.g. adding metrics).
-
-  Example usage::
-
-    >>> class TrainState(nnx.Optimizer):
-    ...   def __init__(self, model, tx, metrics):
-    ...     self.metrics = metrics
-    ...     super().__init__(model, tx)
-    ...   def update(self, *, grads, **updates):
-    ...     self.metrics.update(**updates)
-    ...     super().update(grads)
-    ...
-    >>> metrics = nnx.metrics.Average()
-    >>> state = TrainState(model, tx, metrics)
-    ...
-    >>> grads = nnx.grad(loss_fn)(state.model)
-    >>> state.update(grads=grads, values=loss_fn(state.model))
-    >>> state.metrics.compute()
-    Array(1.6925814, dtype=float32)
-    >>> state.update(grads=grads, values=loss_fn(state.model))
-    >>> state.metrics.compute()
-    Array(1.68612, dtype=float32)
-
-  For more exotic usecases (e.g. multiple optimizers) it's probably best to
-  fork the class and modify it.
-
-  Attributes:
-    step: An ``OptaxState`` :class:`Variable` that tracks the step count.
-    model: The wrapped :class:`Module`.
-    tx: An Optax gradient transformation.
-    opt_state: The Optax optimizer state.
-  """
-
-  def __init__(
-    self,
-    model: Module,
-    tx: 'optax.GradientTransformation',
-    wrt: Any = ParamState,
-  ):
-    """
-    Instantiate the class and wrap the :class:`Module` and Optax gradient
-    transformation. Instantiate the optimizer state to keep track of
-    :class:`Variable` types specified in ``wrt``. Set the step count to 0.
-
-    Args:
-      model: An NNX Module.
-      tx: An Optax gradient transformation.
-      wrt: optional argument to filter for which :class:`Variable`'s to keep
-        track of in the optimizer state. These should be the :class:`Variable`'s
-        that you plan on updating; i.e. this argument value should match the
-        ``wrt`` argument passed to the ``nnx.grad`` call that will generate the
-        gradients that will be passed into the ``grads`` argument of the
-        :func:`update` method.
-    """
+    Example usage::
 
-    # tx must be an instance of optax.GradientTransformation
-    import optax  # type: ignore[import-not-found,import-untyped]
-    if not isinstance(tx, optax.GradientTransformation):
-      raise TypeError(f"tx must be an instance of optax.GradientTransformation, got {tx}")
-    self.tx = tx
-
-    # model
-    if not callable(model):
-      raise TypeError(f"model must be a callable, got {model}")
-    self.model = model
-
-    # wrt
-    self.opt_state = tx.init(nnx.state(model, wrt))
-    self.wrt = wrt
-
-  def update(self, grads):
-    """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
-    The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
-    gradients are with respect to the same :class:`Variable` types as defined in
-    ``self.wrt`` during instantiation of this ``Optimizer``. For example::
-
-      >>> from flax import nnx
-      >>> import jax, jax.numpy as jnp
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> import brainstate as bst
         >>> import optax
-
-      >>> class CustomVariable(nnx.Variable):
-      ...   pass
-
-      >>> class Model(nnx.Module):
-      ...   def __init__(self, rngs):
-      ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
-      ...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
+        ...
+        >>> class Model(bst.nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.linear1 = bst.nn.Linear(2, 3)
+        ...         self.linear2 = bst.nn.Linear(3, 4)
         ...     def __call__(self, x):
-      ...     return self.linear(x) + self.custom_variable
-      >>> model = Model(rngs=nnx.Rngs(0))
-      >>> jax.tree.map(jnp.shape, nnx.state(model))
-      State({
-        'custom_variable': VariableState(
-          type=CustomVariable,
-          value=(1, 3)
-        ),
-        'linear': {
-          'bias': VariableState(
-            type=Param,
-            value=(3,)
-          ),
-          'kernel': VariableState(
-            type=Param,
-            value=(2, 3)
-          )
-        }
-      })
-
-      >>> # update:
-      >>> # - only Linear layer parameters
-      >>> # - only CustomVariable parameters
-      >>> # - both Linear layer and CustomVariable parameters
-      >>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
-      >>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)):
-      ...   # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad`
-      ...   state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
-      ...   grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))(
-      ...     state.model, jnp.ones((1, 2)), jnp.ones((1, 3))
-      ...   )
-      ...   state.update(grads=grads)
-
-    Note that internally this function calls ``.tx.update()`` followed by a call
-    to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
-
-    Args:
-      grads: the gradients derived from ``nnx.grad``.
+        ...         return self.linear2(self.linear1(x))
+        ...
+        >>> x = bst.random.randn(1, 2)
+        >>> y = jnp.ones((1, 4))
+        ...
+        >>> model = Model()
+        >>> tx = optax.adam(1e-3)
+        >>> optimizer = bst.optim.OptaxOptimizer(tx)
+        >>> optimizer.register_trainable_weights(model.states(bst.ParamState))
+        ...
+        >>> loss_fn = lambda: ((model(x) - y) ** 2).mean()
+        >>> loss_fn()
+        Array(1.7055722, dtype=float32)
+        >>> grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
+        >>> optimizer.update(grads)
+        >>> loss_fn()
+        Array(1.6925814, dtype=float32)
+
+    For more exotic usecases (e.g. multiple optimizers) it's probably best to
+    fork the class and modify it.
+
+    Attributes:
+        param_states: The parameter states to update.
+        tx: An Optax gradient transformation.
     """
-    import optax  # type: ignore[import-not-found,import-untyped]
-    state = nnx.state(self.model, self.wrt)
-
-    updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
-    new_params = optax.apply_updates(state, updates)
-    assert isinstance(new_params, nnx.State)
 
-    nnx.update(self.model, new_params)
-    self.opt_state = new_opt_state
+    param_states: StateDictManager
+    opt_state: Optional[ShortTermState]
+
+    def __init__(
+        self,
+        tx: 'optax.GradientTransformation',
+    ):
+        """
+        Instantiate the class and wrap the :class:`FlattedDict` and Optax gradient
+        transformation. Instantiate the optimizer state to keep track of
+        :class:`State`.
+
+        Args:
+            tx: An Optax gradient transformation.
+        """
+        super().__init__()
+
+        # tx must be an instance of optax.GradientTransformation
+        import optax  # type: ignore[import-not-found,import-untyped]
+        if not isinstance(tx, optax.GradientTransformation):
+            raise TypeError(f"tx must be an instance of optax.GradientTransformation, got {tx}")
+        self.tx = tx
+
+        # optimizer state
+        self.opt_state = None
+
+    def register_trainable_weights(self, param_states: Dict[Hashable, State]):
+        # model
+        if not isinstance(param_states, dict):
+            raise TypeError(f"states must be a dict, got {param_states}")
+        for k, v in param_states.items():
+            if not isinstance(v, State):
+                raise TypeError(f"states values must be ParamState, got {v}")
+        self.param_states.update(param_states)
+        self.param_states.unique_()
+
+        # wrt
+        self.opt_state = ShortTermState(self.tx.init({k: v.value for k, v in self.param_states.items()}))
+        return self
+
+    def update(self, grads: Dict[Hashable, PyTree]):
+        """Update the model states with the gradients.
+
+        Args:
+            grads: the gradients derived from ``brainstate.augment.grad``.
+        """
+        if self.opt_state is None:
+            raise ValueError("register_trainable_weights must be called before update.")
+
+        import optax  # type: ignore[import-not-found,import-untyped]
+        grads = {k: grads[k] for k in self.param_states.keys()}
+        states = {k: v.value for k, v in self.param_states.items()}
+
+        # compute updates
+        updates, new_opt_state = self.tx.update(grads, self.opt_state.value, states)
+        new_params = optax.apply_updates(states, updates)
+
+        # update model states and optimizer states
+        for k, v in self.param_states.items():
+            v.value = new_params[k]
+        self.opt_state.value = new_opt_state
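
The practical upshot of this rewrite is that the optimizer is decoupled from the module tree: the 0.0.2 class wrapped a model at construction time (NNX-style OptaxOptimizer(model, tx, wrt=...)), while the 0.1.0 class takes only the optax transformation and is bound to parameters afterwards via register_trainable_weights, which builds the optax state and returns self for chaining. A condensed before/after sketch of the call sites, using only names that appear in this hunk:

    import jax.numpy as jnp
    import optax
    import brainstate as bst

    model = bst.nn.Linear(2, 3)
    loss_fn = lambda: (model(jnp.ones((1, 2))) ** 2).mean()
    tx = optax.adam(1e-3)

    # 0.0.2: the optimizer wrapped the model and filtered trainable states itself:
    #   opt = bst.optim.OptaxOptimizer(model, tx, wrt=bst.ParamState)

    # 0.1.0: construct from tx alone, then register the states to train:
    opt = bst.optim.OptaxOptimizer(tx).register_trainable_weights(
        model.states(bst.ParamState))

    grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
    opt.update(grads)  # writes the updated values back into each ParamState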
brainstate/optim/_optax_optimizer_test.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Flax Authors.
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,4 +11,44 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# ==============================================================================
 
+from __future__ import annotations
+
+import unittest
+
+import jax
+import optax
+
+import brainstate as bst
+
+
+class TestOptaxOptimizer(unittest.TestCase):
+    def test1(self):
+        class Model(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = bst.nn.Linear(2, 3)
+                self.linear2 = bst.nn.Linear(3, 4)
+
+            def __call__(self, x):
+                return self.linear2(self.linear1(x))
+
+        x = bst.random.randn(1, 2)
+        y = jax.numpy.ones((1, 4))
+
+        model = Model()
+        tx = optax.adam(1e-3)
+        optimizer = bst.optim.OptaxOptimizer(tx)
+        optimizer.register_trainable_weights(model.states(bst.ParamState))
+
+        loss_fn = lambda: ((model(x) - y) ** 2).mean()
+        prev_loss = loss_fn()
+
+        grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
+        optimizer.update(grads)
+
+        new_loss = loss_fn()
+
+        print(new_loss, prev_loss)
+        self.assertLess(new_loss, prev_loss)
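
The test checks only that a single optimization step reduces the loss; extending it to a short training loop under the same API is straightforward. A sketch reusing the model, loss_fn, and optimizer names from the test above (not part of the test file):

    losses = []
    for _ in range(100):
        grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
        optimizer.update(grads)
        losses.append(float(loss_fn()))
    assert losses[-1] < losses[0]  # the loss should trend downward over training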