brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/optim/_lr_scheduler.py
@@ -1,448 +1,448 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- from typing import Sequence, Union
-
- import jax
- import jax.numpy as jnp
- import numpy as np
-
- from brainstate import environ
- from brainstate._state import State, LongTermState
- from brainstate.graph import Node
-
- __all__ = [
-     'LearningRateScheduler',
-     'ConstantLR',
-     'StepLR',
-     'MultiStepLR',
-     'CosineAnnealingLR',
-     'CosineAnnealingWarmRestarts',
-     'ExponentialLR',
-     'ExponentialDecayLR',
-     'InverseTimeDecayLR',
-     'PolynomialDecayLR',
-     'PiecewiseConstantLR',
- ]
-
-
- # learning rate schedules #
- # ----------------------- #
-
-
- def make_schedule(scalar_or_schedule):
-     if isinstance(scalar_or_schedule, LearningRateScheduler):
-         return scalar_or_schedule
-     elif isinstance(scalar_or_schedule, (int, float, State)):
-         return ConstantLR(scalar_or_schedule)
-     else:
-         raise TypeError(type(scalar_or_schedule))
-
-
- class LearningRateScheduler(Node):
-     """
-     The learning rate scheduler.
-
-     Parameters
-     ----------
-     lr: float, State
-       The learning rate.
-     last_epoch: int
-       The index of last epoch.
-
-     """
-
-     def __init__(self, lr: Union[float, State], last_epoch: int = -1):
-         super().__init__()
-         if isinstance(lr, State):
-             lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
-         else:
-             lr = jnp.asarray(lr, dtype=environ.dftype())
-         self._lr = lr
-         assert last_epoch >= -1, 'last_epoch should be greater than -1.'
-         self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))
-
-     @property
-     def lr(self):
-         return self._lr.value if isinstance(self._lr, State) else self._lr
-
-     @lr.setter
-     def lr(self, value):
-         if isinstance(value, State):
-             value = value.value
-         assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
-         if isinstance(self._lr, State):
-             self._lr.value = value
-         else:
-             self._lr = value
-
-     def step_epoch(self):
-         """
-         Update the epoch count.
-         """
-         self.last_epoch.value += 1
-
-     def step_call(self):
-         """
-         Update the call count.
-         """
-         pass
-
-     def __call__(self, i=None):
-         raise NotImplementedError
-
-
- class ConstantLR(LearningRateScheduler):
-     """
-     Constant learning rate scheduler.
-     """
-
-     def __call__(self, i=None):
-         return self.lr
-
-
- class CallBasedLRScheduler(LearningRateScheduler):
-     """
-     The learning rate scheduler based on the call count.
-
-     Parameters
-     ----------
-     lr: float
-       The learning rate.
-     last_epoch: int
-       The index of last epoch.
-     last_call: int
-       The index of last call.
-
-     """
-
-     def __init__(self, lr: Union[float, State], last_epoch: int = -1, last_call: int = -1):
-         super().__init__(lr=lr, last_epoch=last_epoch)
-
-         assert last_call >= -1, 'last_call should be greater than -1.'
-         self.last_call = LongTermState(jnp.asarray(last_call, dtype=environ.ditype()))
-
-     def step_call(self):
-         """
-         Update the call count.
-         """
-         self.last_call.value += 1
-
-
- class StepLR(LearningRateScheduler):
-     """Decays the learning rate of each parameter group by gamma every
-     `step_size` epochs.
-
-     Parameters
-     ----------
-     lr: float
-       Initial learning rate.
-     step_size: int
-       Period of learning rate decay.
-     gamma: float
-       Multiplicative factor of learning rate decay.
-       Default: 0.1.
-     last_epoch: int
-       The index of last epoch. Default: -1.
-     """
-
-     def __init__(
-         self,
-         lr: float,
-         step_size: int,
-         gamma: float = 0.1,
-         last_epoch: int = -1
-     ):
-         super().__init__(lr=lr, last_epoch=last_epoch)
-
-         assert step_size >= 1, 'step_size should be greater than or equal to 1.'
-         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
-         self.step_size = step_size
-         self.gamma = gamma
-
-     def __call__(self, i=None):
-         i = (self.last_epoch.value + 1) if i is None else i
-         return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))
-
-
- class MultiStepLR(LearningRateScheduler):
-     """Decays the learning rate of each parameter group by gamma once the
-     number of epoch reaches one of the milestones. Notice that such decay can
-     happen simultaneously with other changes to the learning rate from outside
-     this scheduler. When last_epoch=-1, sets initial lr as lr.
-
-     Parameters
-     ----------
-     lr: float
-       Initial learning rate.
-     milestones: sequence of int
-       List of epoch indices. Must be increasing.
-     gamma: float
-       Multiplicative factor of learning rate decay.
-       Default: 0.1.
-     last_epoch: int
-       The index of last epoch. Default: -1.
-     """
-
-     def __init__(
-         self,
-         lr: float,
-         milestones: Sequence[int],
-         gamma: float = 0.1,
-         last_epoch: int = -1
-     ):
-         super().__init__(lr=lr, last_epoch=last_epoch)
-
-         assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
-         assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
-             'milestones should be a sequence of increasing integers.'
-         )
-         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
-         self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=environ.ditype())
-         self.gamma = gamma
-
-     def __call__(self, i=None):
-         i = (self.last_epoch.value + 1) if i is None else i
-         conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
-         p = jnp.argmax(conditions)
-         return self.lr * self.gamma ** p
-
-
- class CosineAnnealingLR(LearningRateScheduler):
-     r"""Set the learning rate of each parameter group using a cosine annealing
-     schedule, where :math:`\eta_{max}` is set to the initial lr and
-     :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
-
-     .. math::
-         \begin{aligned}
-             \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
-             + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
-             & T_{cur} \neq (2k+1)T_{max}; \\
-             \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
-             \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
-             & T_{cur} = (2k+1)T_{max}.
-         \end{aligned}
-
-     When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
-     is defined recursively, the learning rate can be simultaneously modified
-     outside this scheduler by other operators. If the learning rate is set
-     solely by this scheduler, the learning rate at each step becomes:
-
-     .. math::
-         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-         \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
-
-     It has been proposed in
-     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
-     implements the cosine annealing part of SGDR, and not the restarts.
-
-     Parameters
-     ----------
-     lr: float
-       Initial learning rate.
-     T_max: int
-       Maximum number of iterations.
-     eta_min: float
-       Minimum learning rate. Default: 0.
-     last_epoch: int
-       The index of last epoch. Default: -1.
-
-     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-         https://arxiv.org/abs/1608.03983
-     """
-
-     def __init__(
-         self,
-         lr: float,
-         T_max: int,
-         eta_min: float = 0.,
-         last_epoch: int = -1,
-     ):
-         super().__init__(lr=lr, last_epoch=last_epoch)
-
-         assert T_max >= 1, 'T_max should be greater than or equal to 1.'
-         self._init_epoch = last_epoch
-         self.T_max = T_max
-         self.eta_min = eta_min
-
-     def __call__(self, i=None):
-         i = (self.last_epoch.value + 1) if i is None else i
-         return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2
-
-
- class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
-     r"""Set the learning rate of each parameter group using a cosine annealing
-     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
-     is the number of epochs since the last restart and :math:`T_{i}` is the number
-     of epochs between two warm restarts in SGDR:
-
-     .. math::
-         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-         \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
-
-     When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
-     When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
-
-     It has been proposed in
-     `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
-
-     Parameters
-     ----------
-     lr: float
-       Initial learning rate.
-     num_call_per_epoch: int
-       The number the scheduler to call in each epoch.
-       This usually means the number of batch in each epoch training.
-     T_0: int
-       Number of iterations for the first restart.
-     T_mult: int
-       A factor increases :math:`T_{i}` after a restart. Default: 1.
-     eta_min: float
-       Minimum learning rate. Default: 0.
-     last_call: int
-       The index of last call. Default: -1.
-
-     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-         https://arxiv.org/abs/1608.03983
-     """
-
-     def __init__(
-         self,
-         lr: float,
-         num_call_per_epoch: int,
-         T_0: int,
-         T_mult: int = 1,
-         eta_min: float = 0.,
-         last_epoch: int = -1,
-         last_call: int = -1
-     ):
-         super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
-         if T_0 <= 0 or not isinstance(T_0, int):
-             raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
-         if T_mult < 1 or not isinstance(T_mult, int):
-             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
-
-         self.T_mult = T_mult
-         self.eta_min = eta_min
-         self.T_0 = T_0
-         self.num_call_per_epoch = num_call_per_epoch
-
-     def _cond1(self, epoch):
-         if self.T_mult == 1:
-             T_cur = epoch % self.T_0
-             T_i = self.T_0
-         else:
-             n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
-             T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
-             T_i = self.T_0 * self.T_mult ** n
-         return T_cur, T_i
-
-     def _cond2(self, epoch):
-         return epoch, self.T_0
-
-     def __call__(self, i=None):
-         epoch = self.current_epoch(i)
-         T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
-         return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2
-
-     def current_epoch(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         return jnp.floor(i / self.num_call_per_epoch)
-
-
- class ExponentialLR(LearningRateScheduler):
-     """Decays the learning rate of each parameter group by gamma every epoch.
-     When last_epoch=-1, sets initial lr as lr.
-
-     Parameters
-     ----------
-     lr: float
-       Initial learning rate.
-     gamma: float
-       Multiplicative factor of learning rate decay.
-     last_epoch: int
-       The index of last epoch. Default: -1.
-     """
-
-     def __init__(self,
-                  lr: float,
-                  gamma: float,
-                  last_epoch: int = -1):
-         super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
-         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
-         self.gamma = gamma
-
-     def __call__(self, i: int = None):
-         i = (self.last_epoch.value + 1) if i is None else i
-         return self.lr * self.gamma ** i
-
-
- class ExponentialDecayLR(CallBasedLRScheduler):
-     def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
-         super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
-         self.decay_steps = decay_steps
-         self.decay_rate = decay_rate
-
-     def __call__(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         return self.lr * self.decay_rate ** (i / self.decay_steps)
-
-
- class InverseTimeDecayLR(ExponentialDecayLR):
-     def __init__(self, lr, decay_steps, decay_rate, staircase=False,
-                  last_epoch: int = -1, last_call: int = -1):
-         super().__init__(lr, decay_steps, decay_rate, last_epoch=last_epoch, last_call=last_call)
-         self.staircase = staircase
-
-     def __call__(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         if self.staircase:
-             return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
-         else:
-             return self.lr / (1 + self.decay_rate * i / self.decay_steps)
-
-
- class PolynomialDecayLR(CallBasedLRScheduler):
-     def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
-         super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
-         self.decay_steps = decay_steps
-         self.final_lr = final_lr
-         self.power = power
-
-     def __call__(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         i = jnp.minimum(i, self.decay_steps)
-         step_mult = (1 - i / self.decay_steps) ** self.power
-         return step_mult * (self.lr - self.final_lr) + self.final_lr
-
-
- class PiecewiseConstantLR(CallBasedLRScheduler):
-     def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
-         super().__init__(0., last_epoch=last_epoch, last_call=last_call)
-
-         boundaries = jnp.array(boundaries)
-         values = jnp.array(values)
-         if not boundaries.ndim == values.ndim == 1:
-             raise ValueError("boundaries and values must be sequences")
-         if not boundaries.shape[0] == values.shape[0] - 1:
-             raise ValueError("boundaries length must be one shorter than values length")
-         self.boundaries = boundaries
-         self.values = values
-
-     def __call__(self, i=None):
-         i = (self.last_call.value + 1) if i is None else i
-         return self.values[jnp.sum(i > self.boundaries)]
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from typing import Sequence, Union
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate import environ
+ from brainstate._state import State, LongTermState
+ from brainstate.graph import Node
+
+ __all__ = [
+     'LearningRateScheduler',
+     'ConstantLR',
+     'StepLR',
+     'MultiStepLR',
+     'CosineAnnealingLR',
+     'CosineAnnealingWarmRestarts',
+     'ExponentialLR',
+     'ExponentialDecayLR',
+     'InverseTimeDecayLR',
+     'PolynomialDecayLR',
+     'PiecewiseConstantLR',
+ ]
+
+
+ # learning rate schedules #
+ # ----------------------- #
+
+
+ def make_schedule(scalar_or_schedule):
+     if isinstance(scalar_or_schedule, LearningRateScheduler):
+         return scalar_or_schedule
+     elif isinstance(scalar_or_schedule, (int, float, State)):
+         return ConstantLR(scalar_or_schedule)
+     else:
+         raise TypeError(type(scalar_or_schedule))
+
+
+ class LearningRateScheduler(Node):
+     """
+     The learning rate scheduler.
+
+     Parameters
+     ----------
+     lr: float, State
+       The learning rate.
+     last_epoch: int
+       The index of last epoch.
+
+     """
+
+     def __init__(self, lr: Union[float, State], last_epoch: int = -1):
+         super().__init__()
+         if isinstance(lr, State):
+             lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
+         else:
+             lr = jnp.asarray(lr, dtype=environ.dftype())
+         self._lr = lr
+         assert last_epoch >= -1, 'last_epoch should be greater than -1.'
+         self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))
+
+     @property
+     def lr(self):
+         return self._lr.value if isinstance(self._lr, State) else self._lr
+
+     @lr.setter
+     def lr(self, value):
+         if isinstance(value, State):
+             value = value.value
+         assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
+         if isinstance(self._lr, State):
+             self._lr.value = value
+         else:
+             self._lr = value
+
+     def step_epoch(self):
+         """
+         Update the epoch count.
+         """
+         self.last_epoch.value += 1
+
+     def step_call(self):
+         """
+         Update the call count.
+         """
+         pass
+
+     def __call__(self, i=None):
+         raise NotImplementedError
+
+
+ class ConstantLR(LearningRateScheduler):
+     """
+     Constant learning rate scheduler.
+     """
+
+     def __call__(self, i=None):
+         return self.lr
+
+
+ class CallBasedLRScheduler(LearningRateScheduler):
+     """
+     The learning rate scheduler based on the call count.
+
+     Parameters
+     ----------
+     lr: float
+       The learning rate.
+     last_epoch: int
+       The index of last epoch.
+     last_call: int
+       The index of last call.
+
+     """
+
+     def __init__(self, lr: Union[float, State], last_epoch: int = -1, last_call: int = -1):
+         super().__init__(lr=lr, last_epoch=last_epoch)
+
+         assert last_call >= -1, 'last_call should be greater than -1.'
+         self.last_call = LongTermState(jnp.asarray(last_call, dtype=environ.ditype()))
+
+     def step_call(self):
+         """
+         Update the call count.
+         """
+         self.last_call.value += 1
+
+
+ class StepLR(LearningRateScheduler):
+     """Decays the learning rate of each parameter group by gamma every
+     `step_size` epochs.
+
+     Parameters
+     ----------
+     lr: float
+       Initial learning rate.
+     step_size: int
+       Period of learning rate decay.
+     gamma: float
+       Multiplicative factor of learning rate decay.
+       Default: 0.1.
+     last_epoch: int
+       The index of last epoch. Default: -1.
+     """
+
+     def __init__(
+         self,
+         lr: float,
+         step_size: int,
+         gamma: float = 0.1,
+         last_epoch: int = -1
+     ):
+         super().__init__(lr=lr, last_epoch=last_epoch)
+
+         assert step_size >= 1, 'step_size should be greater than or equal to 1.'
+         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+         self.step_size = step_size
+         self.gamma = gamma
+
+     def __call__(self, i=None):
+         i = (self.last_epoch.value + 1) if i is None else i
+         return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))
+
+
+ class MultiStepLR(LearningRateScheduler):
+     """Decays the learning rate of each parameter group by gamma once the
+     number of epoch reaches one of the milestones. Notice that such decay can
+     happen simultaneously with other changes to the learning rate from outside
+     this scheduler. When last_epoch=-1, sets initial lr as lr.
+
+     Parameters
+     ----------
+     lr: float
+       Initial learning rate.
+     milestones: sequence of int
+       List of epoch indices. Must be increasing.
+     gamma: float
+       Multiplicative factor of learning rate decay.
+       Default: 0.1.
+     last_epoch: int
+       The index of last epoch. Default: -1.
+     """
+
+     def __init__(
+         self,
+         lr: float,
+         milestones: Sequence[int],
+         gamma: float = 0.1,
+         last_epoch: int = -1
+     ):
+         super().__init__(lr=lr, last_epoch=last_epoch)
+
+         assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
+         assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
+             'milestones should be a sequence of increasing integers.'
+         )
+         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+         self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=environ.ditype())
+         self.gamma = gamma
+
+     def __call__(self, i=None):
+         i = (self.last_epoch.value + 1) if i is None else i
+         conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
+         p = jnp.argmax(conditions)
+         return self.lr * self.gamma ** p
+
+
+ class CosineAnnealingLR(LearningRateScheduler):
+     r"""Set the learning rate of each parameter group using a cosine annealing
+     schedule, where :math:`\eta_{max}` is set to the initial lr and
+     :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+
+     .. math::
+         \begin{aligned}
+             \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+             + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
+             & T_{cur} \neq (2k+1)T_{max}; \\
+             \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
+             \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
+             & T_{cur} = (2k+1)T_{max}.
+         \end{aligned}
+
+     When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
+     is defined recursively, the learning rate can be simultaneously modified
+     outside this scheduler by other operators. If the learning rate is set
+     solely by this scheduler, the learning rate at each step becomes:
+
+     .. math::
+         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+         \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
+
+     It has been proposed in
+     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+     implements the cosine annealing part of SGDR, and not the restarts.
+
+     Parameters
+     ----------
+     lr: float
+       Initial learning rate.
+     T_max: int
+       Maximum number of iterations.
+     eta_min: float
+       Minimum learning rate. Default: 0.
+     last_epoch: int
+       The index of last epoch. Default: -1.
+
+     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+         https://arxiv.org/abs/1608.03983
+     """
+
+     def __init__(
+         self,
+         lr: float,
+         T_max: int,
+         eta_min: float = 0.,
+         last_epoch: int = -1,
+     ):
+         super().__init__(lr=lr, last_epoch=last_epoch)
+
+         assert T_max >= 1, 'T_max should be greater than or equal to 1.'
+         self._init_epoch = last_epoch
+         self.T_max = T_max
+         self.eta_min = eta_min
+
+     def __call__(self, i=None):
+         i = (self.last_epoch.value + 1) if i is None else i
+         return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2
+
+
+ class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
+     r"""Set the learning rate of each parameter group using a cosine annealing
+     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
+     is the number of epochs since the last restart and :math:`T_{i}` is the number
+     of epochs between two warm restarts in SGDR:
+
+     .. math::
+         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+         \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
+
+     When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
+     When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
+
+     It has been proposed in
+     `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
+
+     Parameters
+     ----------
+     lr: float
+       Initial learning rate.
+     num_call_per_epoch: int
+       The number the scheduler to call in each epoch.
+       This usually means the number of batch in each epoch training.
+     T_0: int
+       Number of iterations for the first restart.
+     T_mult: int
+       A factor increases :math:`T_{i}` after a restart. Default: 1.
+     eta_min: float
+       Minimum learning rate. Default: 0.
+     last_call: int
+       The index of last call. Default: -1.
+
+     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+         https://arxiv.org/abs/1608.03983
+     """
+
+     def __init__(
+         self,
+         lr: float,
+         num_call_per_epoch: int,
+         T_0: int,
+         T_mult: int = 1,
+         eta_min: float = 0.,
+         last_epoch: int = -1,
+         last_call: int = -1
+     ):
+         super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
+         if T_0 <= 0 or not isinstance(T_0, int):
+             raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+         if T_mult < 1 or not isinstance(T_mult, int):
+             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
+
+         self.T_mult = T_mult
+         self.eta_min = eta_min
+         self.T_0 = T_0
+         self.num_call_per_epoch = num_call_per_epoch
+
+     def _cond1(self, epoch):
+         if self.T_mult == 1:
+             T_cur = epoch % self.T_0
+             T_i = self.T_0
+         else:
+             n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
+             T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
+             T_i = self.T_0 * self.T_mult ** n
+         return T_cur, T_i
+
+     def _cond2(self, epoch):
+         return epoch, self.T_0
+
+     def __call__(self, i=None):
+         epoch = self.current_epoch(i)
+         T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
+         return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2
+
+     def current_epoch(self, i=None):
+         i = (self.last_call.value + 1) if i is None else i
+         return jnp.floor(i / self.num_call_per_epoch)
+
+
+ class ExponentialLR(LearningRateScheduler):
+     """Decays the learning rate of each parameter group by gamma every epoch.
+     When last_epoch=-1, sets initial lr as lr.
+
+     Parameters
+     ----------
+     lr: float
+       Initial learning rate.
+     gamma: float
+       Multiplicative factor of learning rate decay.
+     last_epoch: int
+       The index of last epoch. Default: -1.
+     """
+
+     def __init__(self,
+                  lr: float,
+                  gamma: float,
+                  last_epoch: int = -1):
+         super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
+         assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+         self.gamma = gamma
+
+     def __call__(self, i: int = None):
+         i = (self.last_epoch.value + 1) if i is None else i
+         return self.lr * self.gamma ** i
+
+
+ class ExponentialDecayLR(CallBasedLRScheduler):
+     def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
+         super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
+         self.decay_steps = decay_steps
+         self.decay_rate = decay_rate
+
+     def __call__(self, i=None):
+         i = (self.last_call.value + 1) if i is None else i
+         return self.lr * self.decay_rate ** (i / self.decay_steps)
+
+
+ class InverseTimeDecayLR(ExponentialDecayLR):
+     def __init__(self, lr, decay_steps, decay_rate, staircase=False,
+                  last_epoch: int = -1, last_call: int = -1):
+         super().__init__(lr, decay_steps, decay_rate, last_epoch=last_epoch, last_call=last_call)
+         self.staircase = staircase
+
+     def __call__(self, i=None):
+         i = (self.last_call.value + 1) if i is None else i
+         if self.staircase:
+             return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
+         else:
+             return self.lr / (1 + self.decay_rate * i / self.decay_steps)
+
+
+ class PolynomialDecayLR(CallBasedLRScheduler):
+     def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
+         super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
+         self.decay_steps = decay_steps
+         self.final_lr = final_lr
+         self.power = power
+
+     def __call__(self, i=None):
+         i = (self.last_call.value + 1) if i is None else i
+         i = jnp.minimum(i, self.decay_steps)
+         step_mult = (1 - i / self.decay_steps) ** self.power
+         return step_mult * (self.lr - self.final_lr) + self.final_lr
+
+
+ class PiecewiseConstantLR(CallBasedLRScheduler):
+     def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
+         super().__init__(0., last_epoch=last_epoch, last_call=last_call)
+
+         boundaries = jnp.array(boundaries)
+         values = jnp.array(values)
+         if not boundaries.ndim == values.ndim == 1:
+             raise ValueError("boundaries and values must be sequences")
+         if not boundaries.shape[0] == values.shape[0] - 1:
+             raise ValueError("boundaries length must be one shorter than values length")
+         self.boundaries = boundaries
+         self.values = values
+
+     def __call__(self, i=None):
+         i = (self.last_call.value + 1) if i is None else i
+         return self.values[jnp.sum(i > self.boundaries)]
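
The schedulers in the hunk above fall into two families: epoch-indexed schedulers whose counter advances through step_epoch() (StepLR, MultiStepLR, CosineAnnealingLR, ExponentialLR), and call-indexed schedulers built on CallBasedLRScheduler whose counter advances through step_call() (ExponentialDecayLR, InverseTimeDecayLR, PolynomialDecayLR, PiecewiseConstantLR, CosineAnnealingWarmRestarts). A minimal usage sketch for the epoch-indexed family follows; it assumes these classes are importable from brainstate.optim (they are defined in brainstate/optim/_lr_scheduler.py as shown above), and the 30-epoch loop is purely illustrative.

from brainstate.optim import StepLR

# Halve the learning rate every 10 epochs, starting from 0.1.
scheduler = StepLR(lr=0.1, step_size=10, gamma=0.5)

lrs = []
for epoch in range(30):
    lrs.append(float(scheduler()))  # lr * gamma ** (epoch // step_size)
    scheduler.step_epoch()          # advances the LongTermState last_epoch

# lrs is roughly [0.1] * 10 + [0.05] * 10 + [0.025] * 10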
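
For the call-indexed family the counter is meant to advance once per optimizer step rather than once per epoch, and an index can also be passed explicitly to bypass the internal counter. A sketch with ExponentialDecayLR, under the same brainstate.optim import assumption and with arbitrary step counts:

from brainstate.optim import ExponentialDecayLR

# lr * decay_rate ** (step / decay_steps): smooth exponential decay over steps.
scheduler = ExponentialDecayLR(lr=1e-2, decay_steps=1000, decay_rate=0.96)

for step in range(2000):
    lr_now = float(scheduler())  # uses last_call + 1 when no index is given
    # ... apply a gradient update with lr_now ...
    scheduler.step_call()        # advances the LongTermState last_call

lr_at_500 = float(scheduler(500))  # 1e-2 * 0.96 ** 0.5, independent of the counter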
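
CosineAnnealingWarmRestarts is also call-indexed, but it first converts the call count into a (possibly fractional) epoch via num_call_per_epoch and then applies the SGDR cosine formula, restarting every T_i epochs (T_0 at first, scaled by T_mult after each restart). A sketch assuming 100 batches per epoch and the same import assumption; the numbers are placeholders:

from brainstate.optim import CosineAnnealingWarmRestarts

scheduler = CosineAnnealingWarmRestarts(
    lr=0.1,                  # eta_max at the start of each restart period
    num_call_per_epoch=100,  # calls (batches) that make up one epoch
    T_0=10,                  # restart period of 10 epochs
    T_mult=1,                # default; keeps every restart period at 10 epochs
    eta_min=1e-4,            # floor of the cosine schedule
)

for step in range(5000):
    lr_now = float(scheduler())  # cosine-annealed between lr and eta_min
    scheduler.step_call()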
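
MultiStepLR and PiecewiseConstantLR both select a value by locating the current index among a set of boundaries, which is easy to check by passing explicit indices. The milestone and boundary numbers below are illustrative, and the same import assumption applies:

from brainstate.optim import MultiStepLR, PiecewiseConstantLR

multi = MultiStepLR(lr=1.0, milestones=[30, 80], gamma=0.1)
print(float(multi(10)))  # 1.0   -> before the first milestone
print(float(multi(50)))  # 0.1   -> past 30, before 80
print(float(multi(90)))  # 0.01  -> past both milestones

piecewise = PiecewiseConstantLR(boundaries=[100, 200], values=[1e-2, 1e-3, 1e-4])
print(float(piecewise(0)))    # 0.01
print(float(piecewise(150)))  # 0.001
print(float(piecewise(500)))  # 0.0001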