brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
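Taken together, the renamed paths above suggest that the former brainstate.compile and brainstate.augment namespaces were consolidated into a single brainstate.transform package, while brainstate.optim, brainstate.init, brainstate.functional, and brainstate.surrogate were dropped from the wheel. A minimal migration sketch under that assumption (the exact public names re-exported by brainstate.transform in 0.2.0 are not confirmed by this listing, so treat the new paths as hypothetical):

import brainstate

def loss_fn():
    ...  # some scalar loss closed over model states

# 0.1.x entry points, as used by the removed test files shown below:
#     jitted = brainstate.compile.jit(loss_fn)
#     grads = brainstate.augment.grad(loss_fn, params)()
# Assumed 0.2.0 equivalents, following the {compile -> transform} and
# {augment -> transform} path renames in this diff:
jitted = brainstate.transform.jit(loss_fn)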
brainstate/optim/_lr_scheduler.py
@@ -1,448 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- from typing import Sequence, Union
-
- import jax
- import jax.numpy as jnp
- import numpy as np
-
- from brainstate import environ
- from brainstate._state import State, LongTermState
- from brainstate.graph import Node
-
- __all__ = [
-     'LearningRateScheduler',
-     'ConstantLR',
-     'StepLR',
-     'MultiStepLR',
-     'CosineAnnealingLR',
-     'CosineAnnealingWarmRestarts',
-     'ExponentialLR',
-     'ExponentialDecayLR',
-     'InverseTimeDecayLR',
-     'PolynomialDecayLR',
-     'PiecewiseConstantLR',
- ]
-
-
- # learning rate schedules #
- # ----------------------- #
-
-
- def make_schedule(scalar_or_schedule):
-     if isinstance(scalar_or_schedule, LearningRateScheduler):
-         return scalar_or_schedule
-     elif isinstance(scalar_or_schedule, (int, float, State)):
-         return ConstantLR(scalar_or_schedule)
-     else:
-         raise TypeError(type(scalar_or_schedule))
-
-
- class LearningRateScheduler(Node):
-     """
-     The learning rate scheduler.
-
-     Parameters
-     ----------
-     lr: float, State
-         The learning rate.
-     last_epoch: int
-         The index of last epoch.
-
-     """
-
-     def __init__(self, lr: Union[float, State], last_epoch: int = -1):
-         super().__init__()
-         if isinstance(lr, State):
-             lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
-         else:
-             lr = jnp.asarray(lr, dtype=environ.dftype())
-         self._lr = lr
-         assert last_epoch >= -1, 'last_epoch should be greater than -1.'
-         self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))
-
-     @property
-     def lr(self):
-         return self._lr.value if isinstance(self._lr, State) else self._lr
-
-     @lr.setter
-     def lr(self, value):
-         if isinstance(value, State):
-             value = value.value
-         assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
-         if isinstance(self._lr, State):
-             self._lr.value = value
-         else:
-             self._lr = value
-
-     def step_epoch(self):
-         """
-         Update the epoch count.
-         """
-         self.last_epoch.value += 1
-
-     def step_call(self):
-         """
-         Update the call count.
-         """
-         pass
-
-     def __call__(self, i=None):
-         raise NotImplementedError
-
-
- class ConstantLR(LearningRateScheduler):
-     """
-     Constant learning rate scheduler.
-     """
-
-     def __call__(self, i=None):
-         return self.lr
-
-
- class CallBasedLRScheduler(LearningRateScheduler):
-     """
-     The learning rate scheduler based on the call count.
-
-     Parameters
-     ----------
-     lr: float
-         The learning rate.
-     last_epoch: int
-         The index of last epoch.
-     last_call: int
-         The index of last call.
-
-     """
-
-     def __init__(self, lr: Union[float, State], last_epoch: int = -1, last_call: int = -1):
-         super().__init__(lr=lr, last_epoch=last_epoch)
-
-         assert last_call >= -1, 'last_call should be greater than -1.'
-         self.last_call = LongTermState(jnp.asarray(last_call, dtype=environ.ditype()))
-
-     def step_call(self):
-         """
-         Update the call count.
-         """
-         self.last_call.value += 1
-
-
- class StepLR(LearningRateScheduler):
-     """Decays the learning rate of each parameter group by gamma every
-     `step_size` epochs.
-
-     Parameters
-     ----------
-     lr: float
-         Initial learning rate.
-     step_size: int
-         Period of learning rate decay.
-     gamma: float
-         Multiplicative factor of learning rate decay.
-         Default: 0.1.
-     last_epoch: int
-         The index of last epoch. Default: -1.
-     """
-
-     def __init__(
-         self,
-         lr: float,
-         step_size: int,
-         gamma: float = 0.1,
-         last_epoch: int = -1
-     ):
-         super().__init__(lr=lr, last_epoch=last_epoch)
-
-         assert step_size >= 1, 'step_size should be greater than or equal to 1.'
-         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
-         self.step_size = step_size
-         self.gamma = gamma
-
-     def __call__(self, i=None):
-         i = (self.last_epoch.value + 1) if i is None else i
-         return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))
-
-
- class MultiStepLR(LearningRateScheduler):
-     """Decays the learning rate of each parameter group by gamma once the
-     number of epoch reaches one of the milestones. Notice that such decay can
-     happen simultaneously with other changes to the learning rate from outside
-     this scheduler. When last_epoch=-1, sets initial lr as lr.
-
-     Parameters
-     ----------
-     lr: float
-         Initial learning rate.
-     milestones: sequence of int
-         List of epoch indices. Must be increasing.
-     gamma: float
-         Multiplicative factor of learning rate decay.
-         Default: 0.1.
-     last_epoch: int
-         The index of last epoch. Default: -1.
-     """
-
-     def __init__(
-         self,
-         lr: float,
-         milestones: Sequence[int],
-         gamma: float = 0.1,
-         last_epoch: int = -1
-     ):
-         super().__init__(lr=lr, last_epoch=last_epoch)
-
-         assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
-         assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
-             'milestones should be a sequence of increasing integers.'
-         )
-         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
-         self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=environ.ditype())
-         self.gamma = gamma
-
-     def __call__(self, i=None):
-         i = (self.last_epoch.value + 1) if i is None else i
-         conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
-         p = jnp.argmax(conditions)
-         return self.lr * self.gamma ** p
-
-
- class CosineAnnealingLR(LearningRateScheduler):
-     r"""Set the learning rate of each parameter group using a cosine annealing
-     schedule, where :math:`\eta_{max}` is set to the initial lr and
-     :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
-
-     .. math::
-         \begin{aligned}
-             \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
-             + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
-             & T_{cur} \neq (2k+1)T_{max}; \\
-             \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
-             \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
-             & T_{cur} = (2k+1)T_{max}.
-         \end{aligned}
-
-     When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
-     is defined recursively, the learning rate can be simultaneously modified
-     outside this scheduler by other operators. If the learning rate is set
-     solely by this scheduler, the learning rate at each step becomes:
-
-     .. math::
-         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-         \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
-
-     It has been proposed in
-     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
-     implements the cosine annealing part of SGDR, and not the restarts.
-
-     Parameters
-     ----------
-     lr: float
-         Initial learning rate.
-     T_max: int
-         Maximum number of iterations.
-     eta_min: float
-         Minimum learning rate. Default: 0.
-     last_epoch: int
-         The index of last epoch. Default: -1.
-
-     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-         https://arxiv.org/abs/1608.03983
-     """
-
-     def __init__(
-         self,
-         lr: float,
-         T_max: int,
-         eta_min: float = 0.,
-         last_epoch: int = -1,
-     ):
-         super().__init__(lr=lr, last_epoch=last_epoch)
-
-         assert T_max >= 1, 'T_max should be greater than or equal to 1.'
-         self._init_epoch = last_epoch
-         self.T_max = T_max
-         self.eta_min = eta_min
-
-     def __call__(self, i=None):
-         i = (self.last_epoch.value + 1) if i is None else i
-         return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2
-
-
- class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
-     r"""Set the learning rate of each parameter group using a cosine annealing
-     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
-     is the number of epochs since the last restart and :math:`T_{i}` is the number
-     of epochs between two warm restarts in SGDR:
-
-     .. math::
-         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-         \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
-
-     When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
-     When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
-
-     It has been proposed in
-     `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
-
-     Parameters
-     ----------
-     lr: float
-         Initial learning rate.
-     num_call_per_epoch: int
-         The number the scheduler to call in each epoch.
-         This usually means the number of batch in each epoch training.
-     T_0: int
-         Number of iterations for the first restart.
-     T_mult: int
-         A factor increases :math:`T_{i}` after a restart. Default: 1.
-     eta_min: float
-         Minimum learning rate. Default: 0.
-     last_call: int
-         The index of last call. Default: -1.
-
-     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-         https://arxiv.org/abs/1608.03983
-     """
-
-     def __init__(
-         self,
-         lr: float,
-         num_call_per_epoch: int,
-         T_0: int,
-         T_mult: int = 1,
-         eta_min: float = 0.,
-         last_epoch: int = -1,
-         last_call: int = -1
-     ):
-         super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
-         if T_0 <= 0 or not isinstance(T_0, int):
-             raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
-         if T_mult < 1 or not isinstance(T_mult, int):
-             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
-
-         self.T_mult = T_mult
-         self.eta_min = eta_min
-         self.T_0 = T_0
-         self.num_call_per_epoch = num_call_per_epoch
-
-     def _cond1(self, epoch):
-         if self.T_mult == 1:
-             T_cur = epoch % self.T_0
-             T_i = self.T_0
-         else:
-             n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
-             T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
-             T_i = self.T_0 * self.T_mult ** n
-         return T_cur, T_i
-
-     def _cond2(self, epoch):
-         return epoch, self.T_0
-
-     def __call__(self, i=None):
-         epoch = self.current_epoch(i)
-         T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
-         return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2
-
-     def current_epoch(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         return jnp.floor(i / self.num_call_per_epoch)
-
-
- class ExponentialLR(LearningRateScheduler):
-     """Decays the learning rate of each parameter group by gamma every epoch.
-     When last_epoch=-1, sets initial lr as lr.
-
-     Parameters
-     ----------
-     lr: float
-         Initial learning rate.
-     gamma: float
-         Multiplicative factor of learning rate decay.
-     last_epoch: int
-         The index of last epoch. Default: -1.
-     """
-
-     def __init__(self,
-                  lr: float,
-                  gamma: float,
-                  last_epoch: int = -1):
-         super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
-         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
-         self.gamma = gamma
-
-     def __call__(self, i: int = None):
-         i = (self.last_epoch.value + 1) if i is None else i
-         return self.lr * self.gamma ** i
-
-
- class ExponentialDecayLR(CallBasedLRScheduler):
-     def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
-         super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
-         self.decay_steps = decay_steps
-         self.decay_rate = decay_rate
-
-     def __call__(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         return self.lr * self.decay_rate ** (i / self.decay_steps)
-
-
- class InverseTimeDecayLR(ExponentialDecayLR):
-     def __init__(self, lr, decay_steps, decay_rate, staircase=False,
-                  last_epoch: int = -1, last_call: int = -1):
-         super().__init__(lr, decay_steps, decay_rate, last_epoch=last_epoch, last_call=last_call)
-         self.staircase = staircase
-
-     def __call__(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         if self.staircase:
-             return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
-         else:
-             return self.lr / (1 + self.decay_rate * i / self.decay_steps)
-
-
- class PolynomialDecayLR(CallBasedLRScheduler):
-     def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
-         super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
-         self.decay_steps = decay_steps
-         self.final_lr = final_lr
-         self.power = power
-
-     def __call__(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         i = jnp.minimum(i, self.decay_steps)
-         step_mult = (1 - i / self.decay_steps) ** self.power
-         return step_mult * (self.lr - self.final_lr) + self.final_lr
-
-
- class PiecewiseConstantLR(CallBasedLRScheduler):
-     def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
-         super().__init__(0., last_epoch=last_epoch, last_call=last_call)
-
-         boundaries = jnp.array(boundaries)
-         values = jnp.array(values)
-         if not boundaries.ndim == values.ndim == 1:
-             raise ValueError("boundaries and values must be sequences")
-         if not boundaries.shape[0] == values.shape[0] - 1:
-             raise ValueError("boundaries length must be one shorter than values length")
-         self.boundaries = boundaries
-         self.values = values
-
-     def __call__(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         return self.values[jnp.sum(i > self.boundaries)]
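The milestone lookup in the removed MultiStepLR above encodes the piecewise decay as an argmax over half-open intervals bounded by (-1, milestones..., int32 max). The same rule can be stated more directly as "count the milestones already passed"; a small standalone sketch assuming only jax.numpy (names are illustrative, not part of the package):

import jax.numpy as jnp

def multi_step_lr(i, base_lr=0.1, milestones=(10, 20, 30), gamma=0.1):
    # Number of milestones the epoch index i has reached or passed;
    # the rate decays by a factor of gamma at each milestone.
    n_passed = jnp.sum(i >= jnp.asarray(milestones))
    return base_lr * gamma ** n_passed

# multi_step_lr(5) -> 0.1, multi_step_lr(15) -> 0.01, multi_step_lr(35) -> 0.0001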
brainstate/optim/_lr_scheduler_test.py
@@ -1,50 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from __future__ import annotations
-
- import unittest
-
- import jax.numpy as jnp
-
- import brainstate
-
-
- class TestMultiStepLR(unittest.TestCase):
-     def test1(self):
-         lr = brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
-         for i in range(40):
-             r = lr(i)
-             if i < 10:
-                 self.assertEqual(r, 0.1)
-             elif i < 20:
-                 self.assertTrue(jnp.allclose(r, 0.01))
-             elif i < 30:
-                 self.assertTrue(jnp.allclose(r, 0.001))
-             else:
-                 self.assertTrue(jnp.allclose(r, 0.0001))
-
-     def test2(self):
-         lr = brainstate.compile.jit(brainstate.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
-         for i in range(40):
-             r = lr(i)
-             if i < 10:
-                 self.assertEqual(r, 0.1)
-             elif i < 20:
-                 self.assertTrue(jnp.allclose(r, 0.01))
-             elif i < 30:
-                 self.assertTrue(jnp.allclose(r, 0.001))
-             else:
-                 self.assertTrue(jnp.allclose(r, 0.0001))
brainstate/optim/_optax_optimizer.py
@@ -1,152 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- import importlib.util
- from typing import Hashable, Dict, Optional
-
- from brainstate._state import ShortTermState, State, StateDictManager
- from brainstate.typing import PyTree
- from ._base import Optimizer
-
- optax_installed = importlib.util.find_spec('optax') is not None
-
- __all__ = [
-     'OptaxOptimizer',
-     'LBFGS',
- ]
-
-
- class OptaxOptimizer(Optimizer):
-     """Simple train state for the common case with a single Optax optimizer.
-
-     Example usage::
-
-         >>> import jax
-         >>> import jax.numpy as jnp
-         >>> import brainstate as brainstate
-         >>> import optax
-         ...
-         >>> class Model(brainstate.nn.Module):
-         ...     def __init__(self):
-         ...         super().__init__()
-         ...         self.linear1 = brainstate.nn.Linear(2, 3)
-         ...         self.linear2 = brainstate.nn.Linear(3, 4)
-         ...     def __call__(self, x):
-         ...         return self.linear2(self.linear1(x))
-         ...
-         >>> x = brainstate.random.randn(1, 2)
-         >>> y = jnp.ones((1, 4))
-         ...
-         >>> model = Model()
-         >>> tx = optax.adam(1e-3)
-         >>> optimizer = brainstate.optim.OptaxOptimizer(tx)
-         >>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
-         ...
-         >>> loss_fn = lambda: ((model(x) - y) ** 2).mean()
-         >>> loss_fn()
-         Array(1.7055722, dtype=float32)
-         >>> grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
-         >>> optimizer.update(grads)
-         >>> loss_fn()
-         Array(1.6925814, dtype=float32)
-
-     For more exotic usecases (e.g. multiple optimizers) it's probably best to
-     fork the class and modify it.
-
-     Attributes:
-         param_states: The parameter states to update.
-         tx: An Optax gradient transformation.
-     """
-
-     param_states: StateDictManager
-     opt_state: Optional[ShortTermState]
-
-     def __init__(
-         self,
-         tx: 'optax.GradientTransformation',
-     ):
-         """
-         Instantiate the class and wrap the :class:`FlattedDict` and Optax gradient
-         transformation. Instantiate the optimizer state to keep track of
-         :class:`State`.
-
-         Args:
-             tx: An Optax gradient transformation.
-         """
-         super().__init__()
-
-         # tx must be an instance of optax.GradientTransformation
-         import optax  # type: ignore[import-not-found,import-untyped]
-         if not isinstance(tx, optax.GradientTransformation):
-             raise TypeError(f"tx must be an instance of optax.GradientTransformation, got {tx}")
-         self.tx = tx
-
-         # optimizer state
-         self.opt_state = None
-
-     def register_trainable_weights(self, param_states: Dict[Hashable, State]):
-         # model
-         if not isinstance(param_states, dict):
-             raise TypeError(f"states must be a dict, got {param_states}")
-         for k, v in param_states.items():
-             if not isinstance(v, State):
-                 raise TypeError(f"states values must be ParamState, got {v}")
-         self.param_states.update(param_states)
-         self.param_states.unique_()
-
-         # wrt
-         self.opt_state = ShortTermState(self.tx.init({k: v.value for k, v in self.param_states.items()}))
-         return self
-
-     def update(self, grads: Dict[Hashable, PyTree]):
-         """Update the model states with the gradients.
-
-         Args:
-             grads: the gradients derived from ``brainstate.augment.grad``.
-         """
-         if self.opt_state is None:
-             raise ValueError("register_trainable_weights must be called before update.")
-
-         import optax  # type: ignore[import-not-found,import-untyped]
-         grads = {k: grads[k] for k in self.param_states.keys()}
-         states = {k: v.value for k, v in self.param_states.items()}
-
-         # compute updates
-         updates, new_opt_state = self.tx.update(grads, self.opt_state.value, states)
-         new_params = optax.apply_updates(states, updates)
-
-         # update model states and optimizer states
-         for k, v in self.param_states.items():
-             v.value = new_params[k]
-         self.opt_state.value = new_opt_state
-
-
- class LBFGS(OptaxOptimizer):
-     def __init__(
-         self,
-         lr: float,
-         memory_size: int = 10,
-         scale_init_precond: bool = True,
-     ):
-         import optax  # type: ignore[import-not-found,import-untyped]
-         super().__init__(
-             optax.lbfgs(
-                 lr,
-                 memory_size=memory_size,
-                 scale_init_precond=scale_init_precond,
-                 linesearch=None,
-             )
-         )
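The removed OptaxOptimizer above is essentially a thin wrapper that stores the optax optimizer state in a ShortTermState and the parameters in State objects. For reference, the underlying optax pattern it wraps (init, update, apply_updates) looks like this in plain optax, shown as a sketch with an illustrative parameter pytree:

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}

def loss_fn(p):
    return jnp.sum((p['w'] + p['b']) ** 2)

tx = optax.adam(1e-3)
opt_state = tx.init(params)                                 # roughly what register_trainable_weights sets up

grads = jax.grad(loss_fn)(params)
updates, opt_state = tx.update(grads, opt_state, params)    # the core of update()
params = optax.apply_updates(params, updates)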
brainstate/optim/_optax_optimizer_test.py
@@ -1,53 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- import unittest
-
- import jax
- import optax
-
- import brainstate
-
-
- class TestOptaxOptimizer(unittest.TestCase):
-     def test1(self):
-         class Model(brainstate.nn.Module):
-             def __init__(self):
-                 super().__init__()
-                 self.linear1 = brainstate.nn.Linear(2, 3)
-                 self.linear2 = brainstate.nn.Linear(3, 4)
-
-             def __call__(self, x):
-                 return self.linear2(self.linear1(x))
-
-         x = brainstate.random.randn(1, 2)
-         y = jax.numpy.ones((1, 4))
-
-         model = Model()
-         tx = optax.adam(1e-3)
-         optimizer = brainstate.optim.OptaxOptimizer(tx)
-         optimizer.register_trainable_weights(model.states(brainstate.ParamState))
-
-         loss_fn = lambda: ((model(x) - y) ** 2).mean()
-         prev_loss = loss_fn()
-
-         grads = brainstate.augment.grad(loss_fn, model.states(brainstate.ParamState))()
-         optimizer.update(grads)
-
-         new_loss = loss_fn()
-
-         print(new_loss, prev_loss)
-         self.assertLess(new_loss, prev_loss)