brainstate-0.0.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/optim/__init__.py
@@ -0,0 +1,23 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from ._lr_scheduler import *
+ from ._lr_scheduler import __all__ as scheduler_all
+ from ._sgd_optimizer import *
+ from ._sgd_optimizer import __all__ as optimizer_all
+
+ __all__ = scheduler_all + optimizer_all
+
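For orientation, the `__init__.py` above simply aggregates the public names of its two submodules, so everything defined in `_lr_scheduler.py` (shown next) and `_sgd_optimizer.py` (not shown in this section) is reachable from `brainstate.optim`. A minimal sketch of what that implies for users; the `bst` alias follows the package's own tests:

import brainstate as bst

# Names re-exported from _lr_scheduler.py come first in the aggregated __all__:
print(bst.optim.__all__[:3])   # ['LearningRateScheduler', 'ConstantLR', 'StepLR']

# Any scheduler can therefore be reached directly from brainstate.optim:
constant = bst.optim.ConstantLR(1e-3)
print(constant())              # 0.001 (as a JAX scalar)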
brainstate/optim/_lr_scheduler.py
@@ -0,0 +1,486 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from typing import Sequence, Union
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from .. import environ
+ from .._module import Module
+ from .._state import State, LongTermState
+
+ __all__ = [
+   'LearningRateScheduler',
+   'ConstantLR',
+   'StepLR',
+   'MultiStepLR',
+   'CosineAnnealingLR',
+   'CosineAnnealingWarmRestarts',
+   'ExponentialLR',
+   'ExponentialDecayLR',
+   'InverseTimeDecayLR',
+   'PolynomialDecayLR',
+   'PiecewiseConstantLR',
+ ]
+
+
+ # learning rate schedules #
+ # ----------------------- #
+
+
+ def make_schedule(scalar_or_schedule):
+   if isinstance(scalar_or_schedule, LearningRateScheduler):
+     return scalar_or_schedule
+   elif isinstance(scalar_or_schedule, (int, float, State)):
+     return ConstantLR(scalar_or_schedule)
+   else:
+     raise TypeError(type(scalar_or_schedule))
+
+
+ class LearningRateScheduler(Module):
+   """
+   The learning rate scheduler.
+
+   Attributes
+   ----------
+   lr: float, State
+     The learning rate.
+   last_epoch: int
+     The index of last epoch.
+
+   """
+
+   def __init__(self, lr: Union[float, State], last_epoch: int = -1):
+     super().__init__()
+     if isinstance(lr, State):
+       lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
+     else:
+       lr = jnp.asarray(lr, dtype=environ.dftype())
+     self._lr = lr
+     assert last_epoch >= -1, 'last_epoch should be greater than -1.'
+     self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))
+
+   @property
+   def lr(self):
+     return self._lr.value if isinstance(self._lr, State) else self._lr
+
+   @lr.setter
+   def lr(self, value):
+     if isinstance(value, State):
+       value = value.value
+     assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
+     if isinstance(self._lr, State):
+       self._lr.value = value
+     else:
+       self._lr = value
+
+   def step_epoch(self):
+     """
+     Update the epoch count.
+     """
+     self.last_epoch.value += 1
+
+   def step_call(self):
+     """
+     Update the call count.
+     """
+     pass
+
+   def __repr__(self):
+     return f'{self.__class__.__name__}(lr={self.lr}, last_epoch={self.last_epoch.value}{self.extra_repr()})'
+
+   def extra_repr(self):
+     return ''
+
+   def __call__(self, i=None):
+     raise NotImplementedError
+
+
+ class ConstantLR(LearningRateScheduler):
+   """
+   Constant learning rate scheduler.
+   """
+
+   def __call__(self, i=None):
+     return self.lr
+
+
+ class CallBasedLRScheduler(LearningRateScheduler):
+   """
+   The learning rate scheduler based on the call count.
+
+   Parameters
+   ----------
+   lr: float
+     The learning rate.
+   last_epoch: int
+     The index of last epoch.
+   last_call: int
+     The index of last call.
+
+   """
+
+   def __init__(self, lr: Union[float, State], last_epoch: int = -1, last_call: int = -1):
+     super().__init__(lr=lr, last_epoch=last_epoch)
+
+     assert last_call >= -1, 'last_call should be greater than -1.'
+     self.last_call = LongTermState(jnp.asarray(last_call, dtype=environ.ditype()))
+
+   def step_call(self):
+     """
+     Update the call count.
+     """
+     self.last_call.value += 1
+
+   def __repr__(self):
+     return (f'{self.__class__.__name__}(lr={self.lr}, '
+             f'last_epoch={self.last_epoch.value}, '
+             f'last_call={self.last_call.value}{self.extra_repr()})')
+
+
+ class StepLR(LearningRateScheduler):
+   """Decays the learning rate of each parameter group by gamma every
+   `step_size` epochs.
+
+   Parameters
+   ----------
+   lr: float
+     Initial learning rate.
+   step_size: int
+     Period of learning rate decay.
+   gamma: float
+     Multiplicative factor of learning rate decay.
+     Default: 0.1.
+   last_epoch: int
+     The index of last epoch. Default: -1.
+   """
+
+   def __init__(
+       self,
+       lr: float,
+       step_size: int,
+       gamma: float = 0.1,
+       last_epoch: int = -1
+   ):
+     super().__init__(lr=lr, last_epoch=last_epoch)
+
+     assert step_size >= 1, 'step_size should be greater than or equal to 1.'
+     assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+     self.step_size = step_size
+     self.gamma = gamma
+
+   def __call__(self, i=None):
+     i = (self.last_epoch.value + 1) if i is None else i
+     return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))
+
+   def extra_repr(self):
+     return f', gamma={self.gamma}, step_size={self.step_size}'
+
+
+ class MultiStepLR(LearningRateScheduler):
+   """Decays the learning rate of each parameter group by gamma once the
+   number of epoch reaches one of the milestones. Notice that such decay can
+   happen simultaneously with other changes to the learning rate from outside
+   this scheduler. When last_epoch=-1, sets initial lr as lr.
+
+   Parameters
+   ----------
+   lr: float
+     Initial learning rate.
+   milestones: sequence of int
+     List of epoch indices. Must be increasing.
+   gamma: float
+     Multiplicative factor of learning rate decay.
+     Default: 0.1.
+   last_epoch: int
+     The index of last epoch. Default: -1.
+   """
+
+   def __init__(
+       self,
+       lr: float,
+       milestones: Sequence[int],
+       gamma: float = 0.1,
+       last_epoch: int = -1
+   ):
+     super().__init__(lr=lr, last_epoch=last_epoch)
+
+     assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
+     assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
+       'milestones should be a sequence of increasing integers.'
+     )
+     assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+     self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=environ.ditype())
+     self.gamma = gamma
+
+   def __call__(self, i=None):
+     i = (self.last_epoch.value + 1) if i is None else i
+     conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
+     p = jnp.argmax(conditions)
+     return self.lr * self.gamma ** p
+
+   def extra_repr(self):
+     return f', milestones={self.milestones}, gamma={self.gamma}'
+
+
+ class CosineAnnealingLR(LearningRateScheduler):
+   r"""Set the learning rate of each parameter group using a cosine annealing
+   schedule, where :math:`\eta_{max}` is set to the initial lr and
+   :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+
+   .. math::
+     \begin{aligned}
+       \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+       + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
+       & T_{cur} \neq (2k+1)T_{max}; \\
+       \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
+       \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
+       & T_{cur} = (2k+1)T_{max}.
+     \end{aligned}
+
+   When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
+   is defined recursively, the learning rate can be simultaneously modified
+   outside this scheduler by other operators. If the learning rate is set
+   solely by this scheduler, the learning rate at each step becomes:
+
+   .. math::
+     \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+     \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
+
+   It has been proposed in
+   `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+   implements the cosine annealing part of SGDR, and not the restarts.
+
+   Parameters
+   ----------
+   lr: float
+     Initial learning rate.
+   T_max: int
+     Maximum number of iterations.
+   eta_min: float
+     Minimum learning rate. Default: 0.
+   last_epoch: int
+     The index of last epoch. Default: -1.
+
+   .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+     https://arxiv.org/abs/1608.03983
+   """
+
+   def __init__(
+       self,
+       lr: float,
+       T_max: int,
+       eta_min: float = 0.,
+       last_epoch: int = -1,
+   ):
+     super().__init__(lr=lr, last_epoch=last_epoch)
+
+     assert T_max >= 1, 'T_max should be greater than or equal to 1.'
+     self._init_epoch = last_epoch
+     self.T_max = T_max
+     self.eta_min = eta_min
+
+   def __call__(self, i=None):
+     i = (self.last_epoch.value + 1) if i is None else i
+     return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2
+
+   def extra_repr(self):
+     return f', T_max={self.T_max}, eta_min={self.eta_min}'
+
+
+ class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
+   r"""Set the learning rate of each parameter group using a cosine annealing
+   schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
+   is the number of epochs since the last restart and :math:`T_{i}` is the number
+   of epochs between two warm restarts in SGDR:
+
+   .. math::
+     \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+     \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
+
+   When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
+   When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
+
+   It has been proposed in
+   `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
+
+   Parameters
+   ----------
+   lr: float
+     Initial learning rate.
+   num_call_per_epoch: int
+     The number of times the scheduler is called in each epoch.
+     This usually means the number of batches in each training epoch.
+   T_0: int
+     Number of iterations for the first restart.
+   T_mult: int
+     A factor that increases :math:`T_{i}` after a restart. Default: 1.
+   eta_min: float
+     Minimum learning rate. Default: 0.
+   last_call: int
+     The index of last call. Default: -1.
+
+   .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+     https://arxiv.org/abs/1608.03983
+   """
+
+   def __init__(
+       self,
+       lr: float,
+       num_call_per_epoch: int,
+       T_0: int,
+       T_mult: int = 1,
+       eta_min: float = 0.,
+       last_epoch: int = -1,
+       last_call: int = -1
+   ):
+     super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
+     if T_0 <= 0 or not isinstance(T_0, int):
+       raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+     if T_mult < 1 or not isinstance(T_mult, int):
+       raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
+
+     self.T_mult = T_mult
+     self.eta_min = eta_min
+     self.T_0 = T_0
+     self.num_call_per_epoch = num_call_per_epoch
+
+   def _cond1(self, epoch):
+     if self.T_mult == 1:
+       T_cur = epoch % self.T_0
+       T_i = self.T_0
+     else:
+       n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
+       T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
+       T_i = self.T_0 * self.T_mult ** n
+     return T_cur, T_i
+
+   def _cond2(self, epoch):
+     return epoch, self.T_0
+
+   def __call__(self, i=None):
+     epoch = self.current_epoch(i)
+     T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
+     return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2
+
+   def current_epoch(self, i=None):
+     i = (self.last_call.value + 1) if i is None else i
+     return jnp.floor(i / self.num_call_per_epoch)
+
+   def extra_repr(self):
+     return f', T_0={self.T_0}, T_mult={self.T_mult}, eta_min={self.eta_min}'
+
+
+ class ExponentialLR(LearningRateScheduler):
+   """Decays the learning rate of each parameter group by gamma every epoch.
+   When last_epoch=-1, sets initial lr as lr.
+
+   Parameters
+   ----------
+   lr: float
+     Initial learning rate.
+   gamma: float
+     Multiplicative factor of learning rate decay.
+   last_epoch: int
+     The index of last epoch. Default: -1.
+   """
+
+   def __init__(self,
+                lr: float,
+                gamma: float,
+                last_epoch: int = -1):
+     super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
+     assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+     self.gamma = gamma
+
+   def __call__(self, i: int = None):
+     i = (self.last_epoch.value + 1) if i is None else i
+     return self.lr * self.gamma ** i
+
+   def extra_repr(self):
+     return f', gamma={self.gamma}'
+
+
+ class ExponentialDecayLR(CallBasedLRScheduler):
+   def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
+     super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
+     self.decay_steps = decay_steps
+     self.decay_rate = decay_rate
+
+   def __call__(self, i=None):
+     i = (self.last_call.value + 1) if i is None else i
+     return self.lr * self.decay_rate ** (i / self.decay_steps)
+
+   def extra_repr(self):
+     return f', decay_steps={self.decay_steps}, decay_rate={self.decay_rate}'
+
+
+ class InverseTimeDecayLR(ExponentialDecayLR):
+   def __init__(self, lr, decay_steps, decay_rate, staircase=False,
+                last_epoch: int = -1, last_call: int = -1):
+     super().__init__(lr, decay_steps, decay_rate, last_epoch=last_epoch, last_call=last_call)
+     self.staircase = staircase
+
+   def __call__(self, i=None):
+     i = (self.last_call.value + 1) if i is None else i
+     if self.staircase:
+       return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
+     else:
+       return self.lr / (1 + self.decay_rate * i / self.decay_steps)
+
+   def extra_repr(self):
+     return f', decay_steps={self.decay_steps}, decay_rate={self.decay_rate}, staircase={self.staircase}'
+
+
+ class PolynomialDecayLR(CallBasedLRScheduler):
+   def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
+     super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
+     self.decay_steps = decay_steps
+     self.final_lr = final_lr
+     self.power = power
+
+   def __call__(self, i=None):
+     i = (self.last_call.value + 1) if i is None else i
+     i = jnp.minimum(i, self.decay_steps)
+     step_mult = (1 - i / self.decay_steps) ** self.power
+     return step_mult * (self.lr - self.final_lr) + self.final_lr
+
+   def extra_repr(self):
+     return f', decay_steps={self.decay_steps}, final_lr={self.final_lr}, power={self.power}'
+
+
+ class PiecewiseConstantLR(CallBasedLRScheduler):
+   def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
+     super().__init__(0., last_epoch=last_epoch, last_call=last_call)
+
+     boundaries = jnp.array(boundaries)
+     values = jnp.array(values)
+     if not boundaries.ndim == values.ndim == 1:
+       raise ValueError("boundaries and values must be sequences")
+     if not boundaries.shape[0] == values.shape[0] - 1:
+       raise ValueError("boundaries length must be one shorter than values length")
+     self.boundaries = boundaries
+     self.values = values
+
+   def __call__(self, i=None):
+     i = (self.last_call.value + 1) if i is None else i
+     return self.values[jnp.sum(i > self.boundaries)]
+
+   def extra_repr(self):
+     return f', boundaries={self.boundaries}, values={self.values}'
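Taken together, the epoch-based schedulers above are driven by `last_epoch` and advanced with `step_epoch()`, while an explicit index can be passed to query the schedule without touching that state. A minimal usage sketch based only on the API in this diff; the `bst` alias follows the package's own tests, and the training-loop scaffolding is illustrative, not part of the package:

import brainstate as bst

# Halve the learning rate every 10 epochs.
sched = bst.optim.StepLR(lr=0.1, step_size=10, gamma=0.5)

for epoch in range(30):
  current_lr = sched()    # uses last_epoch + 1 when no index is given
  # ... run one training epoch with current_lr ...
  sched.step_epoch()      # advance the internal epoch counter

# Querying at an explicit epoch index leaves the counter untouched:
print(sched(25))          # 0.1 * 0.5 ** (25 // 10) = 0.025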
brainstate/optim/_lr_scheduler_test.py
@@ -0,0 +1,36 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import unittest
+
+ import jax.numpy as jnp
+
+ import brainstate as bst
+
+
+ class TestMultiStepLR(unittest.TestCase):
+   def test1(self):
+     lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
+     for i in range(40):
+       r = lr(i)
+       if i < 10:
+         self.assertEqual(r, 0.1)
+       elif i < 20:
+         self.assertTrue(jnp.allclose(r, 0.01))
+       elif i < 30:
+         self.assertTrue(jnp.allclose(r, 0.001))
+       else:
+         self.assertTrue(jnp.allclose(r, 0.0001))
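The bundled test only exercises `MultiStepLR`. The call-based schedulers from the same module are driven by `last_call` and `step_call()` instead of the epoch counter; a minimal sketch under the same assumptions as above (illustrative loop, not part of the package):

import brainstate as bst

# Decay by a factor of 0.5 every 1000 scheduler calls (e.g. training steps).
sched = bst.optim.ExponentialDecayLR(lr=1e-2, decay_steps=1000, decay_rate=0.5)

for step in range(2000):
  current_lr = sched()    # 1e-2 * 0.5 ** ((last_call + 1) / 1000)
  # ... apply one optimization step with current_lr ...
  sched.step_call()       # advance the internal call counter

# Querying at an explicit call index leaves the counter untouched:
print(sched(1000))        # 1e-2 * 0.5 ** 1.0 = 0.005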