brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff shows the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
--- a/brainstate/optim/_lr_scheduler_test.py
+++ b/brainstate/optim/_lr_scheduler_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+from __future__ import annotations
 
 import unittest
 
@@ -22,28 +23,28 @@ import brainstate as bst
 
 
 class TestMultiStepLR(unittest.TestCase):
-  def test1(self):
-    lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
-    for i in range(40):
-      r = lr(i)
-      if i < 10:
-        self.assertEqual(r, 0.1)
-      elif i < 20:
-        self.assertTrue(jnp.allclose(r, 0.01))
-      elif i < 30:
-        self.assertTrue(jnp.allclose(r, 0.001))
-      else:
-        self.assertTrue(jnp.allclose(r, 0.0001))
+    def test1(self):
+        lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
+        for i in range(40):
+            r = lr(i)
+            if i < 10:
+                self.assertEqual(r, 0.1)
+            elif i < 20:
+                self.assertTrue(jnp.allclose(r, 0.01))
+            elif i < 30:
+                self.assertTrue(jnp.allclose(r, 0.001))
+            else:
+                self.assertTrue(jnp.allclose(r, 0.0001))
 
-  def test2(self):
-    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
-    for i in range(40):
-      r = lr(i)
-      if i < 10:
-        self.assertEqual(r, 0.1)
-      elif i < 20:
-        self.assertTrue(jnp.allclose(r, 0.01))
-      elif i < 30:
-        self.assertTrue(jnp.allclose(r, 0.001))
-      else:
-        self.assertTrue(jnp.allclose(r, 0.0001))
+    def test2(self):
+        lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+        for i in range(40):
+            r = lr(i)
+            if i < 10:
+                self.assertEqual(r, 0.1)
+            elif i < 20:
+                self.assertTrue(jnp.allclose(r, 0.01))
+            elif i < 30:
+                self.assertTrue(jnp.allclose(r, 0.001))
+            else:
+                self.assertTrue(jnp.allclose(r, 0.0001))
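
The re-indented tests above encode MultiStepLR's decay rule: the rate is the base value times gamma raised to the number of milestones already passed. A minimal reference sketch of that rule (not brainstate's implementation, just what the assertions check; note test2 still calls bst.transform.jit, which per the file list appears to survive as a small compatibility shim in transform.py):

    def multistep_lr(base_lr, milestones, gamma, step):
        # decay once for every milestone the step has reached
        passed = sum(step >= m for m in milestones)
        return base_lr * gamma ** passed

    assert multistep_lr(0.1, [10, 20, 30], 0.1, 5) == 0.1   # before the first milestone
    assert abs(multistep_lr(0.1, [10, 20, 30], 0.1, 25) - 0.001) < 1e-12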
--- a/brainstate/optim/_optax_optimizer.py
+++ b/brainstate/optim/_optax_optimizer.py
@@ -17,192 +17,137 @@
 from __future__ import annotations
 
 import importlib.util
-from typing import Any
+from typing import Hashable, Dict, Optional
 
-import jax.numpy as jnp
+from brainstate._state import ShortTermState, State, StateDictManager
+from brainstate.typing import PyTree
+from ._base import Optimizer
 
-from brainstate._module import Module
-from brainstate._state import ShortTermState, ParamState
+optax_installed = importlib.util.find_spec('optax') is not None
 
 __all__ = [
-  'OptaxOptimizer',
+    'OptaxOptimizer',
 ]
 
-optax_installed = importlib.util.find_spec('optax') is not None
 
+class OptaxOptimizer(Optimizer):
+    """Simple train state for the common case with a single Optax optimizer.
 
-class OptaxState(ShortTermState):
-  """Wrapper class for Optimizer Variables."""
-  pass
-
-
-class OptaxOptimizer(Module):
-  """Simple train state for the common case with a single Optax optimizer.
-
-  Example usage::
-
-    >>> import jax, jax.numpy as jnp
-    >>> import brainstate as bst
-    >>> from brainstate import nn
-    >>> import optax
-    ...
-    >>> class Model(bst.Module):
-    ...   def __init__(self):
-    ...     super().__init__()
-    ...     self.linear1 = nn.Linear(2, 3)
-    ...     self.linear2 = nn.Linear(3, 4)
-    ...   def __call__(self, x):
-    ...     return self.linear2(self.linear1(x))
-    ...
-    >>> x = jax.random.normal(jax.random.key(0), (1, 2))
-    >>> y = jnp.ones((1, 4))
-    ...
-    >>> model = Model()
-    >>> tx = optax.adam(1e-3)
-    >>> state = bst.optim.OptaxOptimizer(model, tx)
-    ...
-    >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
-    >>> loss_fn(model)
-    Array(1.7055722, dtype=float32)
-    >>> grads = bst.transform.grad(loss_fn)(state.model)
-    >>> state.update(grads)
-    >>> loss_fn(model)
-    Array(1.6925814, dtype=float32)
-
-  Note that you can easily extend this class by subclassing it for storing
-  additional data (e.g. adding metrics).
-
-  Example usage::
-
-    >>> class TrainState(nnx.Optimizer):
-    ...   def __init__(self, model, tx, metrics):
-    ...     self.metrics = metrics
-    ...     super().__init__(model, tx)
-    ...   def update(self, *, grads, **updates):
-    ...     self.metrics.update(**updates)
-    ...     super().update(grads)
-    ...
-    >>> metrics = nnx.metrics.Average()
-    >>> state = TrainState(model, tx, metrics)
-    ...
-    >>> grads = nnx.grad(loss_fn)(state.model)
-    >>> state.update(grads=grads, values=loss_fn(state.model))
-    >>> state.metrics.compute()
-    Array(1.6925814, dtype=float32)
-    >>> state.update(grads=grads, values=loss_fn(state.model))
-    >>> state.metrics.compute()
-    Array(1.68612, dtype=float32)
-
-  For more exotic usecases (e.g. multiple optimizers) it's probably best to
-  fork the class and modify it.
-
-  Attributes:
-    step: An ``OptaxState`` :class:`Variable` that tracks the step count.
-    model: The wrapped :class:`Module`.
-    tx: An Optax gradient transformation.
-    opt_state: The Optax optimizer state.
-  """
-
-  def __init__(
-    self,
-    model: Module,
-    tx: 'optax.GradientTransformation',
-    wrt: Any = ParamState,
-  ):
-    """
-    Instantiate the class and wrap the :class:`Module` and Optax gradient
-    transformation. Instantiate the optimizer state to keep track of
-    :class:`Variable` types specified in ``wrt``. Set the step count to 0.
-
-    Args:
-      model: An NNX Module.
-      tx: An Optax gradient transformation.
-      wrt: optional argument to filter for which :class:`Variable`'s to keep
-        track of in the optimizer state. These should be the :class:`Variable`'s
-        that you plan on updating; i.e. this argument value should match the
-        ``wrt`` argument passed to the ``nnx.grad`` call that will generate the
-        gradients that will be passed into the ``grads`` argument of the
-        :func:`update` method.
-    """
+    Example usage::
 
-    # tx must be an instance of optax.GradientTransformation
-    import optax  # type: ignore[import-not-found,import-untyped]
-    if not isinstance(tx, optax.GradientTransformation):
-      raise TypeError(f"tx must be an instance of optax.GradientTransformation, got {tx}")
-    self.tx = tx
-
-    # model
-    if not callable(model):
-      raise TypeError(f"model must be a callable, got {model}")
-    self.model = model
-
-    # wrt
-    self.opt_state = tx.init(nnx.state(model, wrt))
-    self.wrt = wrt
-
-  def update(self, grads):
-    """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
-    The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
-    gradients are with respect to the same :class:`Variable` types as defined in
-    ``self.wrt`` during instantiation of this ``Optimizer``. For example::
-
-      >>> from flax import nnx
-      >>> import jax, jax.numpy as jnp
+      >>> import jax
+      >>> import jax.numpy as jnp
+      >>> import brainstate as bst
       >>> import optax
-
-      >>> class CustomVariable(nnx.Variable):
-      ...   pass
-
-      >>> class Model(nnx.Module):
-      ...   def __init__(self, rngs):
-      ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
-      ...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
+      ...
+      >>> class Model(bst.nn.Module):
+      ...   def __init__(self):
+      ...     super().__init__()
+      ...     self.linear1 = bst.nn.Linear(2, 3)
+      ...     self.linear2 = bst.nn.Linear(3, 4)
       ...   def __call__(self, x):
-      ...     return self.linear(x) + self.custom_variable
-      >>> model = Model(rngs=nnx.Rngs(0))
-      >>> jax.tree.map(jnp.shape, nnx.state(model))
-      State({
-        'custom_variable': VariableState(
-          type=CustomVariable,
-          value=(1, 3)
-        ),
-        'linear': {
-          'bias': VariableState(
-            type=Param,
-            value=(3,)
-          ),
-          'kernel': VariableState(
-            type=Param,
-            value=(2, 3)
-          )
-        }
-      })
-
-      >>> # update:
-      >>> # - only Linear layer parameters
-      >>> # - only CustomVariable parameters
-      >>> # - both Linear layer and CustomVariable parameters
-      >>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
-      >>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)):
-      ...   # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad`
-      ...   state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
-      ...   grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))(
-      ...     state.model, jnp.ones((1, 2)), jnp.ones((1, 3))
-      ...   )
-      ...   state.update(grads=grads)
-
-    Note that internally this function calls ``.tx.update()`` followed by a call
-    to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
-
-    Args:
-      grads: the gradients derived from ``nnx.grad``.
+      ...     return self.linear2(self.linear1(x))
+      ...
+      >>> x = bst.random.randn(1, 2)
+      >>> y = jnp.ones((1, 4))
+      ...
+      >>> model = Model()
+      >>> tx = optax.adam(1e-3)
+      >>> optimizer = bst.optim.OptaxOptimizer(tx)
+      >>> optimizer.register_trainable_weights(model.states(bst.ParamState))
+      ...
+      >>> loss_fn = lambda: ((model(x) - y) ** 2).mean()
+      >>> loss_fn()
+      Array(1.7055722, dtype=float32)
+      >>> grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
+      >>> optimizer.update(grads)
+      >>> loss_fn()
+      Array(1.6925814, dtype=float32)
+
+    For more exotic usecases (e.g. multiple optimizers) it's probably best to
+    fork the class and modify it.
+
+    Attributes:
+      param_states: The parameter states to update.
+      tx: An Optax gradient transformation.
     """
-    import optax  # type: ignore[import-not-found,import-untyped]
-    state = nnx.state(self.model, self.wrt)
-
-    updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
-    new_params = optax.apply_updates(state, updates)
-    assert isinstance(new_params, nnx.State)
 
-    nnx.update(self.model, new_params)
-    self.opt_state = new_opt_state
+    param_states: StateDictManager
+    opt_state: Optional[ShortTermState]
+
+    def __init__(
+        self,
+        tx: 'optax.GradientTransformation',
+    ):
+        """
+        Instantiate the class and wrap the :class:`FlattedDict` and Optax gradient
+        transformation. Instantiate the optimizer state to keep track of
+        :class:`State`.
+
+        Args:
+          tx: An Optax gradient transformation.
+        """
+        super().__init__()
+
+        # tx must be an instance of optax.GradientTransformation
+        import optax  # type: ignore[import-not-found,import-untyped]
+        if not isinstance(tx, optax.GradientTransformation):
+            raise TypeError(f"tx must be an instance of optax.GradientTransformation, got {tx}")
+        self.tx = tx
+
+        # optimizer state
+        self.opt_state = None
+
+    def register_trainable_weights(self, param_states: Dict[Hashable, State]):
+        # model
+        if not isinstance(param_states, dict):
+            raise TypeError(f"states must be a dict, got {param_states}")
+        for k, v in param_states.items():
+            if not isinstance(v, State):
+                raise TypeError(f"states values must be ParamState, got {v}")
+        self.param_states.update(param_states)
+        self.param_states.unique_()
+
+        # wrt
+        self.opt_state = ShortTermState(self.tx.init({k: v.value for k, v in self.param_states.items()}))
+        return self
+
+    def update(self, grads: Dict[Hashable, PyTree]):
+        """Update the model states with the gradients.
+
+        Args:
+          grads: the gradients derived from ``brainstate.augment.grad``.
+        """
+        if self.opt_state is None:
+            raise ValueError("register_trainable_weights must be called before update.")
+
+        import optax  # type: ignore[import-not-found,import-untyped]
+        grads = {k: grads[k] for k in self.param_states.keys()}
+        states = {k: v.value for k, v in self.param_states.items()}
+
+        # compute updates
+        updates, new_opt_state = self.tx.update(grads, self.opt_state.value, states)
+        new_params = optax.apply_updates(states, updates)
+
+        # update model states and optimizer states
+        for k, v in self.param_states.items():
+            v.value = new_params[k]
+        self.opt_state.value = new_opt_state
+
+
+class LBFGS(OptaxOptimizer):
+    def __init__(
+        self,
+        lr: float,
+        memory_size: int = 10,
+        scale_init_precond: bool = True,
+    ):
+        import optax  # type: ignore[import-not-found,import-untyped]
+        super().__init__(
+            optax.lbfgs(
+                lr,
+                memory_size=memory_size,
+                scale_init_precond=scale_init_precond,
+                linesearch=None,
+            )
+        )
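
The hunk above inverts the ownership model: in 0.0.2 the optimizer wrapped the module and filtered trainable variables via wrt; in 0.1.0 it holds only the optax transformation and the caller registers an explicit state dict. A minimal before/after sketch compressed from the two docstrings, assuming Model, bst.ParamState, and Module.states behave as shown there:

    import jax.numpy as jnp
    import optax
    import brainstate as bst

    class Model(bst.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = bst.nn.Linear(2, 3)
            self.linear2 = bst.nn.Linear(3, 4)

        def __call__(self, x):
            return self.linear2(self.linear1(x))

    model = Model()
    x = bst.random.randn(1, 2)
    y = jnp.ones((1, 4))
    loss_fn = lambda: ((model(x) - y) ** 2).mean()

    # 0.0.2 (removed API): the optimizer wrapped the module and filtered
    # trainable variables via `wrt`; gradients were taken against state.model:
    #   state = bst.optim.OptaxOptimizer(model, tx, wrt=ParamState)
    #   grads = bst.transform.grad(loss_fn)(state.model)
    #   state.update(grads)

    # 0.1.0 (added API): the optimizer holds only the optax transformation;
    # weights are registered explicitly, and gradients are taken with respect
    # to the same state dict passed to register_trainable_weights.
    optimizer = bst.optim.OptaxOptimizer(optax.adam(1e-3))
    optimizer.register_trainable_weights(model.states(bst.ParamState))
    grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
    optimizer.update(grads)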
--- a/brainstate/optim/_optax_optimizer_test.py
+++ b/brainstate/optim/_optax_optimizer_test.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Flax Authors.
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,4 +11,44 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# ==============================================================================
 
+from __future__ import annotations
+
+import unittest
+
+import jax
+import optax
+
+import brainstate as bst
+
+
+class TestOptaxOptimizer(unittest.TestCase):
+    def test1(self):
+        class Model(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = bst.nn.Linear(2, 3)
+                self.linear2 = bst.nn.Linear(3, 4)
+
+            def __call__(self, x):
+                return self.linear2(self.linear1(x))
+
+        x = bst.random.randn(1, 2)
+        y = jax.numpy.ones((1, 4))
+
+        model = Model()
+        tx = optax.adam(1e-3)
+        optimizer = bst.optim.OptaxOptimizer(tx)
+        optimizer.register_trainable_weights(model.states(bst.ParamState))
+
+        loss_fn = lambda: ((model(x) - y) ** 2).mean()
+        prev_loss = loss_fn()
+
+        grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
+        optimizer.update(grads)
+
+        new_loss = loss_fn()
+
+        print(new_loss, prev_loss)
+        self.assertLess(new_loss, prev_loss)
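
The optimizer diff also adds an LBFGS wrapper around optax.lbfgs, with no docstring or test of its own in this release. A hypothetical usage sketch, assuming LBFGS is reachable from brainstate.optim (the hunk shows only OptaxOptimizer in __all__) and that it inherits the register/update flow from its OptaxOptimizer base:

    import jax.numpy as jnp
    import brainstate as bst

    # least-squares fit of a single linear layer (hypothetical example)
    model = bst.nn.Linear(2, 1)
    x = bst.random.randn(16, 2)
    y = jnp.ones((16, 1))
    loss_fn = lambda: ((model(x) - y) ** 2).mean()

    # the wrapper passes linesearch=None to optax.lbfgs, so a fixed lr is required
    opt = bst.optim.LBFGS(lr=1e-1, memory_size=10)
    opt.register_trainable_weights(model.states(bst.ParamState))

    for _ in range(20):
        grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()
        opt.update(grads)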