brainstate-0.0.1-py2.py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/optim/__init__.py
@@ -0,0 +1,23 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from ._lr_scheduler import *
from ._lr_scheduler import __all__ as scheduler_all
from ._sgd_optimizer import *
from ._sgd_optimizer import __all__ as optimizer_all

__all__ = scheduler_all + optimizer_all
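The `__init__.py` above simply aggregates the public names of the two submodules into the `brainstate.optim` namespace. A minimal sketch of the effect (assuming brainstate 0.0.1 is installed; `MultiStepLR` is defined in `_lr_scheduler.py` below):

import brainstate as bst

# MultiStepLR is re-exported from brainstate/optim/_lr_scheduler.py,
# so it is reachable directly as bst.optim.MultiStepLR.
scheduler = bst.optim.MultiStepLR(0.1, milestones=[10, 20, 30], gamma=0.1)
print(scheduler(5))   # ~0.1, before the first milestone
print(scheduler(15))  # ~0.01, after the first milestone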
brainstate/optim/_lr_scheduler.py
@@ -0,0 +1,486 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np

from .. import environ
from .._module import Module
from .._state import State, LongTermState

__all__ = [
    'LearningRateScheduler',
    'ConstantLR',
    'StepLR',
    'MultiStepLR',
    'CosineAnnealingLR',
    'CosineAnnealingWarmRestarts',
    'ExponentialLR',
    'ExponentialDecayLR',
    'InverseTimeDecayLR',
    'PolynomialDecayLR',
    'PiecewiseConstantLR',
]


# learning rate schedules #
# ----------------------- #


def make_schedule(scalar_or_schedule):
    if isinstance(scalar_or_schedule, LearningRateScheduler):
        return scalar_or_schedule
    elif isinstance(scalar_or_schedule, (int, float, State)):
        return ConstantLR(scalar_or_schedule)
    else:
        raise TypeError(type(scalar_or_schedule))

class LearningRateScheduler(Module):
    """
    The learning rate scheduler.

    Attributes
    ----------
    lr: float, State
      The learning rate.
    last_epoch: int
      The index of last epoch.
    """

    def __init__(self, lr: Union[float, State], last_epoch: int = -1):
        super().__init__()
        if isinstance(lr, State):
            lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
        else:
            lr = jnp.asarray(lr, dtype=environ.dftype())
        self._lr = lr
        assert last_epoch >= -1, 'last_epoch should be greater than or equal to -1.'
        self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))

    @property
    def lr(self):
        return self._lr.value if isinstance(self._lr, State) else self._lr

    @lr.setter
    def lr(self, value):
        if isinstance(value, State):
            value = value.value
        assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
        if isinstance(self._lr, State):
            self._lr.value = value
        else:
            self._lr = value

    def step_epoch(self):
        """
        Update the epoch count.
        """
        self.last_epoch.value += 1

    def step_call(self):
        """
        Update the call count.
        """
        pass

    def __repr__(self):
        return f'{self.__class__.__name__}(lr={self.lr}, last_epoch={self.last_epoch.value}{self.extra_repr()})'

    def extra_repr(self):
        return ''

    def __call__(self, i=None):
        raise NotImplementedError

class ConstantLR(LearningRateScheduler):
    """
    Constant learning rate scheduler.
    """

    def __call__(self, i=None):
        return self.lr


class CallBasedLRScheduler(LearningRateScheduler):
    """
    The learning rate scheduler based on the call count.

    Parameters
    ----------
    lr: float
      The learning rate.
    last_epoch: int
      The index of last epoch.
    last_call: int
      The index of last call.
    """

    def __init__(self, lr: Union[float, State], last_epoch: int = -1, last_call: int = -1):
        super().__init__(lr=lr, last_epoch=last_epoch)

        assert last_call >= -1, 'last_call should be greater than or equal to -1.'
        self.last_call = LongTermState(jnp.asarray(last_call, dtype=environ.ditype()))

    def step_call(self):
        """
        Update the call count.
        """
        self.last_call.value += 1

    def __repr__(self):
        return (f'{self.__class__.__name__}(lr={self.lr}, '
                f'last_epoch={self.last_epoch.value}, '
                f'last_call={self.last_call.value}{self.extra_repr()})')

class StepLR(LearningRateScheduler):
    """Decays the learning rate of each parameter group by gamma every
    `step_size` epochs.

    Parameters
    ----------
    lr: float
      Initial learning rate.
    step_size: int
      Period of learning rate decay.
    gamma: float
      Multiplicative factor of learning rate decay. Default: 0.1.
    last_epoch: int
      The index of last epoch. Default: -1.
    """

    def __init__(
        self,
        lr: float,
        step_size: int,
        gamma: float = 0.1,
        last_epoch: int = -1
    ):
        super().__init__(lr=lr, last_epoch=last_epoch)

        assert step_size >= 1, 'step_size should be greater than or equal to 1.'
        assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
        self.step_size = step_size
        self.gamma = gamma

    def __call__(self, i=None):
        i = (self.last_epoch.value + 1) if i is None else i
        return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))

    def extra_repr(self):
        return f', gamma={self.gamma}, step_size={self.step_size}'

class MultiStepLR(LearningRateScheduler):
    """Decays the learning rate of each parameter group by gamma once the
    number of epochs reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside
    this scheduler. When last_epoch=-1, sets initial lr as lr.

    Parameters
    ----------
    lr: float
      Initial learning rate.
    milestones: sequence of int
      List of epoch indices. Must be increasing.
    gamma: float
      Multiplicative factor of learning rate decay. Default: 0.1.
    last_epoch: int
      The index of last epoch. Default: -1.
    """

    def __init__(
        self,
        lr: float,
        milestones: Sequence[int],
        gamma: float = 0.1,
        last_epoch: int = -1
    ):
        super().__init__(lr=lr, last_epoch=last_epoch)

        assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
        assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
            'milestones should be a sequence of increasing integers.'
        )
        assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
        self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=environ.ditype())
        self.gamma = gamma

    def __call__(self, i=None):
        i = (self.last_epoch.value + 1) if i is None else i
        conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
        p = jnp.argmax(conditions)
        return self.lr * self.gamma ** p

    def extra_repr(self):
        return f', milestones={self.milestones}, gamma={self.gamma}'

class CosineAnnealingLR(LearningRateScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max}; \\
        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Parameters
    ----------
    lr: float
      Initial learning rate.
    T_max: int
      Maximum number of iterations.
    eta_min: float
      Minimum learning rate. Default: 0.
    last_epoch: int
      The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(
        self,
        lr: float,
        T_max: int,
        eta_min: float = 0.,
        last_epoch: int = -1,
    ):
        super().__init__(lr=lr, last_epoch=last_epoch)

        assert T_max >= 1, 'T_max should be greater than or equal to 1.'
        self._init_epoch = last_epoch
        self.T_max = T_max
        self.eta_min = eta_min

    def __call__(self, i=None):
        i = (self.last_epoch.value + 1) if i is None else i
        return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2

    def extra_repr(self):
        return f', T_max={self.T_max}, eta_min={self.eta_min}'

class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Parameters
    ----------
    lr: float
      Initial learning rate.
    num_call_per_epoch: int
      The number of times the scheduler is called in each epoch.
      This is usually the number of batches in each training epoch.
    T_0: int
      Number of iterations for the first restart.
    T_mult: int
      A factor by which :math:`T_{i}` increases after a restart. Default: 1.
    eta_min: float
      Minimum learning rate. Default: 0.
    last_call: int
      The index of last call. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(
        self,
        lr: float,
        num_call_per_epoch: int,
        T_0: int,
        T_mult: int = 1,
        eta_min: float = 0.,
        last_epoch: int = -1,
        last_call: int = -1
    ):
        super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))

        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_0 = T_0
        self.num_call_per_epoch = num_call_per_epoch

    def _cond1(self, epoch):
        if self.T_mult == 1:
            T_cur = epoch % self.T_0
            T_i = self.T_0
        else:
            n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
            T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
            T_i = self.T_0 * self.T_mult ** n
        return T_cur, T_i

    def _cond2(self, epoch):
        return epoch, self.T_0

    def __call__(self, i=None):
        epoch = self.current_epoch(i)
        T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
        return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2

    def current_epoch(self, i=None):
        i = (self.last_call.value + 1) if i is None else i
        return jnp.floor(i / self.num_call_per_epoch)

    def extra_repr(self):
        return f', T_0={self.T_0}, T_mult={self.T_mult}, eta_min={self.eta_min}'

class ExponentialLR(LearningRateScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Parameters
    ----------
    lr: float
      Initial learning rate.
    gamma: float
      Multiplicative factor of learning rate decay.
    last_epoch: int
      The index of last epoch. Default: -1.
    """

    def __init__(self,
                 lr: float,
                 gamma: float,
                 last_epoch: int = -1):
        super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
        assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
        self.gamma = gamma

    def __call__(self, i: int = None):
        i = (self.last_epoch.value + 1) if i is None else i
        return self.lr * self.gamma ** i

    def extra_repr(self):
        return f', gamma={self.gamma}'

class ExponentialDecayLR(CallBasedLRScheduler):
    def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
        super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate

    def __call__(self, i=None):
        i = (self.last_call.value + 1) if i is None else i
        return self.lr * self.decay_rate ** (i / self.decay_steps)

    def extra_repr(self):
        return f', decay_steps={self.decay_steps}, decay_rate={self.decay_rate}'


class InverseTimeDecayLR(ExponentialDecayLR):
    def __init__(self, lr, decay_steps, decay_rate, staircase=False,
                 last_epoch: int = -1, last_call: int = -1):
        super().__init__(lr, decay_steps, decay_rate, last_epoch=last_epoch, last_call=last_call)
        self.staircase = staircase

    def __call__(self, i=None):
        i = (self.last_call.value + 1) if i is None else i
        if self.staircase:
            return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
        else:
            return self.lr / (1 + self.decay_rate * i / self.decay_steps)

    def extra_repr(self):
        return f', decay_steps={self.decay_steps}, decay_rate={self.decay_rate}, staircase={self.staircase}'

class PolynomialDecayLR(CallBasedLRScheduler):
    def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
        super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
        self.decay_steps = decay_steps
        self.final_lr = final_lr
        self.power = power

    def __call__(self, i=None):
        i = (self.last_call.value + 1) if i is None else i
        i = jnp.minimum(i, self.decay_steps)
        step_mult = (1 - i / self.decay_steps) ** self.power
        return step_mult * (self.lr - self.final_lr) + self.final_lr

    def extra_repr(self):
        return f', decay_steps={self.decay_steps}, final_lr={self.final_lr}, power={self.power}'

class PiecewiseConstantLR(CallBasedLRScheduler):
    def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
        super().__init__(0., last_epoch=last_epoch, last_call=last_call)

        boundaries = jnp.array(boundaries)
        values = jnp.array(values)
        if not boundaries.ndim == values.ndim == 1:
            raise ValueError("boundaries and values must be sequences")
        if not boundaries.shape[0] == values.shape[0] - 1:
            raise ValueError("boundaries length must be one shorter than values length")
        self.boundaries = boundaries
        self.values = values

    def __call__(self, i=None):
        i = (self.last_call.value + 1) if i is None else i
        return self.values[jnp.sum(i > self.boundaries)]

    def extra_repr(self):
        return f', boundaries={self.boundaries}, values={self.values}'
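The schedulers above share a small protocol: `__call__(i=None)` returns the learning rate for step or epoch `i` (falling back to the internally tracked counter when `i` is None), while `step_epoch()` and `step_call()` advance those counters. A hedged sketch of how a training loop might drive an epoch-based scheduler (the loop itself is illustrative and not part of the package):

import jax.numpy as jnp
import brainstate as bst

sched = bst.optim.CosineAnnealingLR(lr=0.1, T_max=50, eta_min=0.001)

for epoch in range(100):
    current_lr = sched()   # uses last_epoch.value + 1 internally
    # ... run one epoch of training with current_lr ...
    sched.step_epoch()     # advance the epoch counter

# The value for a given epoch index can also be queried directly; it follows
# the closed-form cosine schedule from the class docstring:
eta = 0.001 + (0.1 - 0.001) * (1 + jnp.cos(jnp.pi * 25 / 50)) / 2
assert jnp.allclose(sched(25), eta)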
brainstate/optim/_lr_scheduler_test.py
@@ -0,0 +1,36 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import unittest

import jax.numpy as jnp

import brainstate as bst


class TestMultiStepLR(unittest.TestCase):
    def test1(self):
        lr = bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1)
        for i in range(40):
            r = lr(i)
            if i < 10:
                self.assertEqual(r, 0.1)
            elif i < 20:
                self.assertTrue(jnp.allclose(r, 0.01))
            elif i < 30:
                self.assertTrue(jnp.allclose(r, 0.001))
            else:
                self.assertTrue(jnp.allclose(r, 0.0001))