brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
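Several of the renames in the list above move public modules: brainstate/transform becomes brainstate/compile, brainstate/nn/event is promoted to brainstate/event, the single brainstate/util.py becomes the brainstate/util package, and new augment and graph subpackages appear. The sketch below shows what downstream imports may look like after the reorganization; the subpackage paths are taken from the rename list, but whether each name is re-exported unchanged in 0.1.0 is an assumption to verify against the installed wheel.

# Hypothetical migration sketch (not taken from the package's own docs).
from brainstate import compile   # was: from brainstate import transform
from brainstate import event     # was: from brainstate.nn import event
from brainstate import augment   # new subpackage in 0.1.0
from brainstate import graph     # new subpackage in 0.1.0
from brainstate import util      # util.py is now the util/ package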
brainstate/optim/_lr_scheduler.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 # -*- coding: utf-8 -*-
+from __future__ import annotations
 
 from typing import Sequence, Union
 
@@ -21,22 +22,22 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 
-from
-from
-from
+from brainstate import environ
+from brainstate._state import State, LongTermState
+from brainstate.graph import Node
 
 __all__ = [
-(11 removed lines, old 29-39: content not rendered in the source view)
+  'LearningRateScheduler',
+  'ConstantLR',
+  'StepLR',
+  'MultiStepLR',
+  'CosineAnnealingLR',
+  'CosineAnnealingWarmRestarts',
+  'ExponentialLR',
+  'ExponentialDecayLR',
+  'InverseTimeDecayLR',
+  'PolynomialDecayLR',
+  'PiecewiseConstantLR',
 ]
 
 
@@ -45,442 +46,404 @@ __all__ = [
 
 
 def make_schedule(scalar_or_schedule):
-
-
-
-
-  else:
-    raise TypeError(type(scalar_or_schedule))
-
-
-class LearningRateScheduler(Module):
-  """
-  The learning rate scheduler.
-
-  Attributes
-  ----------
-  lr: float, State
-    The learning rate.
-  last_epoch: int
-    The index of last epoch.
-
-  """
-
-  def __init__(self, lr: Union[float, State], last_epoch: int = -1):
-    super().__init__()
-    if isinstance(lr, State):
-      lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
+  if isinstance(scalar_or_schedule, LearningRateScheduler):
+    return scalar_or_schedule
+  elif isinstance(scalar_or_schedule, (int, float, State)):
+    return ConstantLR(scalar_or_schedule)
   else:
-
-    self._lr = lr
-    assert last_epoch >= -1, 'last_epoch should be greater than -1.'
-    self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))
-
-  @property
-  def lr(self):
-    return self._lr.value if isinstance(self._lr, State) else self._lr
-
-  @lr.setter
-  def lr(self, value):
-    if isinstance(value, State):
-      value = value.value
-    assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
-    if isinstance(self._lr, State):
-      self._lr.value = value
-    else:
-      self._lr = value
+    raise TypeError(type(scalar_or_schedule))
 
-  def step_epoch(self):
-    """
-    Update the epoch count.
-    """
-    self.last_epoch.value += 1
 
-
-    """
-    Update the call count.
+class LearningRateScheduler(Node):
   """
-
+  The learning rate scheduler.
 
-
-
+  Parameters
+  ----------
+  lr: float, State
+    The learning rate.
+  last_epoch: int
+    The index of last epoch.
 
-
-    return ''
+  """
 
-
-
+  def __init__(self, lr: Union[float, State], last_epoch: int = -1):
+    super().__init__()
+    if isinstance(lr, State):
+      lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
+    else:
+      lr = jnp.asarray(lr, dtype=environ.dftype())
+    self._lr = lr
+    assert last_epoch >= -1, 'last_epoch should be greater than -1.'
+    self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))
+
+  @property
+  def lr(self):
+    return self._lr.value if isinstance(self._lr, State) else self._lr
+
+  @lr.setter
+  def lr(self, value):
+    if isinstance(value, State):
+      value = value.value
+    assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
+    if isinstance(self._lr, State):
+      self._lr.value = value
+    else:
+      self._lr = value
+
+  def step_epoch(self):
+    """
+    Update the epoch count.
+    """
+    self.last_epoch.value += 1
+
+  def step_call(self):
+    """
+    Update the call count.
+    """
+    pass
+
+  def __call__(self, i=None):
+    raise NotImplementedError
 
 
 class ConstantLR(LearningRateScheduler):
-
-
-
+  """
+  Constant learning rate scheduler.
+  """
 
-
-
+  def __call__(self, i=None):
+    return self.lr
 
 
 class CallBasedLRScheduler(LearningRateScheduler):
-
-
+  """
+  The learning rate scheduler based on the call count.
+
+  Parameters
+  ----------
+  lr: float
+    The learning rate.
+  last_epoch: int
+    The index of last epoch.
+  last_call: int
+    The index of last call.
 
-
-  ----------
-  lr: float
-    The learning rate.
-  last_epoch: int
-    The index of last epoch.
-  last_call: int
-    The index of last call.
+  """
 
-
+  def __init__(self, lr: Union[float, State], last_epoch: int = -1, last_call: int = -1):
+    super().__init__(lr=lr, last_epoch=last_epoch)
 
-
-
+    assert last_call >= -1, 'last_call should be greater than -1.'
+    self.last_call = LongTermState(jnp.asarray(last_call, dtype=environ.ditype()))
 
-
-
+  def step_call(self):
+    """
+    Update the call count.
+    """
+    self.last_call.value += 1
 
-
-
-
+
+class StepLR(LearningRateScheduler):
+  """Decays the learning rate of each parameter group by gamma every
+  `step_size` epochs.
+
+  Parameters
+  ----------
+  lr: float
+    Initial learning rate.
+  step_size: int
+    Period of learning rate decay.
+  gamma: float
+    Multiplicative factor of learning rate decay.
+    Default: 0.1.
+  last_epoch: int
+    The index of last epoch. Default: -1.
   """
-    self.last_call.value += 1
 
-
-
-
-
+  def __init__(
+      self,
+      lr: float,
+      step_size: int,
+      gamma: float = 0.1,
+      last_epoch: int = -1
+  ):
+    super().__init__(lr=lr, last_epoch=last_epoch)
 
+    assert step_size >= 1, 'step_size should be greater than or equal to 1.'
+    assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+    self.step_size = step_size
+    self.gamma = gamma
 
-
-
-
-
-  Parameters
-  ----------
-  lr: float
-    Initial learning rate.
-  step_size: int
-    Period of learning rate decay.
-  gamma: float
-    Multiplicative factor of learning rate decay.
-    Default: 0.1.
-  last_epoch: int
-    The index of last epoch. Default: -1.
-  """
-
-  def __init__(
-      self,
-      lr: float,
-      step_size: int,
-      gamma: float = 0.1,
-      last_epoch: int = -1
-  ):
-    super().__init__(lr=lr, last_epoch=last_epoch)
-
-    assert step_size >= 1, 'step_size should be greater than or equal to 1.'
-    assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
-    self.step_size = step_size
-    self.gamma = gamma
-
-  def __call__(self, i=None):
-    i = (self.last_epoch.value + 1) if i is None else i
-    return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))
-
-  def extra_repr(self):
-    return f', gamma={self.gamma}, step_size={self.step_size}'
+  def __call__(self, i=None):
+    i = (self.last_epoch.value + 1) if i is None else i
+    return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))
 
 
 class MultiStepLR(LearningRateScheduler):
-(41 removed lines, old 197-237: content not rendered in the source view)
-  def extra_repr(self):
-    return f', milestones={self.milestones}, gamma={self.gamma}'
+  """Decays the learning rate of each parameter group by gamma once the
+  number of epoch reaches one of the milestones. Notice that such decay can
+  happen simultaneously with other changes to the learning rate from outside
+  this scheduler. When last_epoch=-1, sets initial lr as lr.
+
+  Parameters
+  ----------
+  lr: float
+    Initial learning rate.
+  milestones: sequence of int
+    List of epoch indices. Must be increasing.
+  gamma: float
+    Multiplicative factor of learning rate decay.
+    Default: 0.1.
+  last_epoch: int
+    The index of last epoch. Default: -1.
+  """
+
+  def __init__(
+      self,
+      lr: float,
+      milestones: Sequence[int],
+      gamma: float = 0.1,
+      last_epoch: int = -1
+  ):
+    super().__init__(lr=lr, last_epoch=last_epoch)
+
+    assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
+    assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
+      'milestones should be a sequence of increasing integers.'
+    )
+    assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+    self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=environ.ditype())
+    self.gamma = gamma
+
+  def __call__(self, i=None):
+    i = (self.last_epoch.value + 1) if i is None else i
+    conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
+    p = jnp.argmax(conditions)
+    return self.lr * self.gamma ** p
 
 
 class CosineAnnealingLR(LearningRateScheduler):
-(42 removed lines, old 243-284: content not rendered in the source view)
-  def __init__(
-      self,
-      lr: float,
-      T_max: int,
-      eta_min: float = 0.,
-      last_epoch: int = -1,
-  ):
-    super().__init__(lr=lr, last_epoch=last_epoch)
-
-    assert T_max >= 1, 'T_max should be greater than or equal to 1.'
-    self._init_epoch = last_epoch
-    self.T_max = T_max
-    self.eta_min = eta_min
-
-  def __call__(self, i=None):
-    i = (self.last_epoch.value + 1) if i is None else i
-    return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2
-
-  def extra_repr(self):
-    return f', T_max={self.T_max}, eta_min={self.eta_min}'
+  r"""Set the learning rate of each parameter group using a cosine annealing
+  schedule, where :math:`\eta_{max}` is set to the initial lr and
+  :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+
+  .. math::
+    \begin{aligned}
+      \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+      + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
+      & T_{cur} \neq (2k+1)T_{max}; \\
+      \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
+      \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
+      & T_{cur} = (2k+1)T_{max}.
+    \end{aligned}
+
+  When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
+  is defined recursively, the learning rate can be simultaneously modified
+  outside this scheduler by other operators. If the learning rate is set
+  solely by this scheduler, the learning rate at each step becomes:
+
+  .. math::
+    \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+    \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
+
+  It has been proposed in
+  `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+  implements the cosine annealing part of SGDR, and not the restarts.
+
+  Parameters
+  ----------
+  lr: float
+    Initial learning rate.
+  T_max: int
+    Maximum number of iterations.
+  eta_min: float
+    Minimum learning rate. Default: 0.
+  last_epoch: int
+    The index of last epoch. Default: -1.
+
+  .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+    https://arxiv.org/abs/1608.03983
+  """
 
+  def __init__(
+      self,
+      lr: float,
+      T_max: int,
+      eta_min: float = 0.,
+      last_epoch: int = -1,
+  ):
+    super().__init__(lr=lr, last_epoch=last_epoch)
 
-
-
-
-
-  of epochs between two warm restarts in SGDR:
-
-  .. math::
-    \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-    \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
-
-  When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
-  When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
-
-  It has been proposed in
-  `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
-
-  Parameters
-  ----------
-  lr: float
-    Initial learning rate.
-  num_call_per_epoch: int
-    The number the scheduler to call in each epoch.
-    This usually means the number of batch in each epoch training.
-  T_0: int
-    Number of iterations for the first restart.
-  T_mult: int
-    A factor increases :math:`T_{i}` after a restart. Default: 1.
-  eta_min: float
-    Minimum learning rate. Default: 0.
-  last_call: int
-    The index of last call. Default: -1.
-
-  .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-    https://arxiv.org/abs/1608.03983
-  """
-
-  def __init__(
-      self,
-      lr: float,
-      num_call_per_epoch: int,
-      T_0: int,
-      T_mult: int = 1,
-      eta_min: float = 0.,
-      last_epoch: int = -1,
-      last_call: int = -1
-  ):
-    super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
-    if T_0 <= 0 or not isinstance(T_0, int):
-      raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
-    if T_mult < 1 or not isinstance(T_mult, int):
-      raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
-
-    self.T_mult = T_mult
-    self.eta_min = eta_min
-    self.T_0 = T_0
-    self.num_call_per_epoch = num_call_per_epoch
-
-  def _cond1(self, epoch):
-    if self.T_mult == 1:
-      T_cur = epoch % self.T_0
-      T_i = self.T_0
-    else:
-      n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
-      T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
-      T_i = self.T_0 * self.T_mult ** n
-    return T_cur, T_i
+    assert T_max >= 1, 'T_max should be greater than or equal to 1.'
+    self._init_epoch = last_epoch
+    self.T_max = T_max
+    self.eta_min = eta_min
 
-
-
+  def __call__(self, i=None):
+    i = (self.last_epoch.value + 1) if i is None else i
+    return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2
 
-  def __call__(self, i=None):
-    epoch = self.current_epoch(i)
-    T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
-    return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2
 
-
-
-
+class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
+  """Set the learning rate of each parameter group using a cosine annealing
+  schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
+  is the number of epochs since the last restart and :math:`T_{i}` is the number
+  of epochs between two warm restarts in SGDR:
+
+  .. math::
+    \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+    \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
+
+  When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
+  When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
+
+  It has been proposed in
+  `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
+
+  Parameters
+  ----------
+  lr: float
+    Initial learning rate.
+  num_call_per_epoch: int
+    The number the scheduler to call in each epoch.
+    This usually means the number of batch in each epoch training.
+  T_0: int
+    Number of iterations for the first restart.
+  T_mult: int
+    A factor increases :math:`T_{i}` after a restart. Default: 1.
+  eta_min: float
+    Minimum learning rate. Default: 0.
+  last_call: int
+    The index of last call. Default: -1.
+
+  .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+    https://arxiv.org/abs/1608.03983
+  """
 
-
-
+  def __init__(
+      self,
+      lr: float,
+      num_call_per_epoch: int,
+      T_0: int,
+      T_mult: int = 1,
+      eta_min: float = 0.,
+      last_epoch: int = -1,
+      last_call: int = -1
+  ):
+    super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
+    if T_0 <= 0 or not isinstance(T_0, int):
+      raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+    if T_mult < 1 or not isinstance(T_mult, int):
+      raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
+
+    self.T_mult = T_mult
+    self.eta_min = eta_min
+    self.T_0 = T_0
+    self.num_call_per_epoch = num_call_per_epoch
+
+  def _cond1(self, epoch):
+    if self.T_mult == 1:
+      T_cur = epoch % self.T_0
+      T_i = self.T_0
+    else:
+      n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
+      T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
+      T_i = self.T_0 * self.T_mult ** n
+    return T_cur, T_i
+
+  def _cond2(self, epoch):
+    return epoch, self.T_0
+
+  def __call__(self, i=None):
+    epoch = self.current_epoch(i)
+    T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
+    return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2
+
+  def current_epoch(self, i=None):
+    i = (self.last_call.value + 1) if i is None else i
+    return jnp.floor(i / self.num_call_per_epoch)
 
 
 class ExponentialLR(LearningRateScheduler):
-(13 removed lines, old 391-403: content not rendered in the source view)
-  def __init__(self,
-               lr: float,
-               gamma: float,
-               last_epoch: int = -1):
-    super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
-    assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
-    self.gamma = gamma
-
-  def __call__(self, i: int = None):
-    i = (self.last_epoch.value + 1) if i is None else i
-    return self.lr * self.gamma ** i
-
-  def extra_repr(self):
-    return f', gamma={self.gamma}'
+  """Decays the learning rate of each parameter group by gamma every epoch.
+  When last_epoch=-1, sets initial lr as lr.
+
+  Parameters
+  ----------
+  lr: float
+    Initial learning rate.
+  gamma: float
+    Multiplicative factor of learning rate decay.
+  last_epoch: int
+    The index of last epoch. Default: -1.
+  """
 
+  def __init__(self,
+               lr: float,
+               gamma: float,
+               last_epoch: int = -1):
+    super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
+    assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+    self.gamma = gamma
+
+  def __call__(self, i: int = None):
+    i = (self.last_epoch.value + 1) if i is None else i
+    return self.lr * self.gamma ** i
 
-class ExponentialDecayLR(CallBasedLRScheduler):
-  def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
-    super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
-    self.decay_steps = decay_steps
-    self.decay_rate = decay_rate
 
-
-
-
+class ExponentialDecayLR(CallBasedLRScheduler):
+  def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
+    super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
+    self.decay_steps = decay_steps
+    self.decay_rate = decay_rate
 
-
-
+  def __call__(self, i=None):
+    i = (self.last_call.value + 1) if i is None else i
+    return self.lr * self.decay_rate ** (i / self.decay_steps)
 
 
 class InverseTimeDecayLR(ExponentialDecayLR):
-(5 removed lines, old 435-439: content not rendered in the source view)
-  def __call__(self, i=None):
-    i = (self.last_call.value + 1) if i is None else i
-    if self.staircase:
-      return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
-    else:
-      return self.lr / (1 + self.decay_rate * i / self.decay_steps)
+  def __init__(self, lr, decay_steps, decay_rate, staircase=False,
+               last_epoch: int = -1, last_call: int = -1):
+    super().__init__(lr, decay_steps, decay_rate, last_epoch=last_epoch, last_call=last_call)
+    self.staircase = staircase
 
-
-
+  def __call__(self, i=None):
+    i = (self.last_call.value + 1) if i is None else i
+    if self.staircase:
+      return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
+    else:
+      return self.lr / (1 + self.decay_rate * i / self.decay_steps)
 
 
 class PolynomialDecayLR(CallBasedLRScheduler):
-(6 removed lines, old 452-457: content not rendered in the source view)
-  def __call__(self, i=None):
-    i = (self.last_call.value + 1) if i is None else i
-    i = jnp.minimum(i, self.decay_steps)
-    step_mult = (1 - i / self.decay_steps) ** self.power
-    return step_mult * (self.lr - self.final_lr) + self.final_lr
+  def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
+    super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
+    self.decay_steps = decay_steps
+    self.final_lr = final_lr
+    self.power = power
 
-
-
+  def __call__(self, i=None):
+    i = (self.last_call.value + 1) if i is None else i
+    i = jnp.minimum(i, self.decay_steps)
+    step_mult = (1 - i / self.decay_steps) ** self.power
+    return step_mult * (self.lr - self.final_lr) + self.final_lr
 
 
 class PiecewiseConstantLR(CallBasedLRScheduler):
-(16 removed lines, old 469-484: content not rendered in the source view)
-  def extra_repr(self):
-    return f', boundaries={self.boundaries}, values={self.values}'
+  def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
+    super().__init__(0., last_epoch=last_epoch, last_call=last_call)
+
+    boundaries = jnp.array(boundaries)
+    values = jnp.array(values)
+    if not boundaries.ndim == values.ndim == 1:
+      raise ValueError("boundaries and values must be sequences")
+    if not boundaries.shape[0] == values.shape[0] - 1:
+      raise ValueError("boundaries length must be one shorter than values length")
+    self.boundaries = boundaries
+    self.values = values
+
+  def __call__(self, i=None):
+    i = (self.last_call.value + 1) if i is None else i
+    return self.values[jnp.sum(i > self.boundaries)]
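For orientation, the following sketch exercises the scheduler API added above. The StepLR constructor, __call__(i=None), and step_epoch() come from the new lines of this diff; the brainstate.optim import path and the training-loop wiring are assumptions for illustration only.

# Usage sketch, assuming brainstate.optim re-exports the schedulers defined in _lr_scheduler.py.
import brainstate as bst

scheduler = bst.optim.StepLR(lr=0.1, step_size=10, gamma=0.5)  # lr * gamma ** floor(i / step_size)

for epoch in range(30):
    lr = scheduler()        # with no index, evaluates at last_epoch.value + 1
    # ... run one training epoch using `lr` ...
    scheduler.step_epoch()  # advance the LongTermState epoch counter

lr_at_25 = scheduler(25)    # query an explicit epoch index: 0.1 * 0.5 ** 2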