brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,1104 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import functools
|
19
|
-
from typing import Union, Dict, Optional, Tuple, Any, TypeVar
|
20
|
-
|
21
|
-
import brainunit as u
|
22
|
-
import jax
|
23
|
-
import jax.numpy as jnp
|
24
|
-
|
25
|
-
from brainstate import environ
|
26
|
-
from brainstate._state import State, LongTermState, StateDictManager
|
27
|
-
from ._base import Optimizer
|
28
|
-
from ._lr_scheduler import make_schedule, LearningRateScheduler
|
29
|
-
|
30
|
-
# Public API of this module.
__all__ = [
    # helper that aligns parameter / gradient / buffer dictionaries
    'to_same_dict_tree',
    # ``State`` subclass used for optimizer buffers
    'OptimState',
    # optimizers
    'SGDOptimizer', 'SGD', 'Momentum', 'MomentumNesterov', 'Adagrad',
    'Adadelta', 'RMSProp', 'Adam', 'LARS', 'Adan', 'AdamW',
]

# Generic type variable used in the ``fcast`` signature.
T = TypeVar('T')
|
51
|
-
|
52
|
-
|
53
|
-
def cast(value: Any, dtype: Any) -> jax.Array:
    """Convert ``value`` to a JAX array with the requested ``dtype``.

    Existing ``jax.Array`` instances are converted with ``astype``; any other
    input (Python scalars, sequences, NumPy arrays, ...) goes through
    ``jnp.asarray``.
    """
    if not isinstance(value, jax.Array):
        return jnp.asarray(value, dtype=dtype)
    return value.astype(dtype)
|
57
|
-
|
58
|
-
|
59
|
-
def fcast(value: T, dtype: Any = None) -> jax.Array:
    """Cast ``value`` to a floating-point JAX array.

    Parameters
    ----------
    value
        Scalar or array-like to convert.
    dtype
        Target dtype. When ``None`` (the default), the library-wide default
        float dtype returned by ``environ.dftype()`` is used.

    Returns
    -------
    jax.Array
        ``value`` converted through :func:`cast`.
    """
    # Compare against ``None`` explicitly instead of ``dtype or ...``: NumPy
    # dtype objects define ``__len__`` (the number of structured fields, 0 for
    # ordinary scalar dtypes), so e.g. ``np.dtype('float32')`` can evaluate as
    # falsy and would otherwise be silently replaced by the default dtype.
    if dtype is None:
        dtype = environ.dftype()
    return cast(value, dtype=dtype)
|
61
|
-
|
62
|
-
|
63
|
-
def _to_dict_value(old_dict: Dict) -> Dict:
    """Return a plain ``dict`` copy of ``old_dict`` with ``State`` entries unwrapped.

    Every value that is a :class:`State` is replaced by its ``.value``; all
    other values are copied through unchanged. Key order is preserved and the
    input mapping is not modified.
    """
    return {
        key: (item.value if isinstance(item, State) else item)
        for key, item in old_dict.items()
    }
|
71
|
-
|
72
|
-
|
73
|
-
def to_same_dict_tree(*dicts: Dict):
    """
    Convert multiple dictionaries to plain dicts sharing the same key set.

    Every dictionary must have exactly the same keys as the first one.
    :class:`State` entries are unwrapped to their ``.value`` so the results
    can be fed directly to ``jax.tree.map``.

    Parameters
    ----------
    *dicts: dict
        The dictionaries to be converted.

    Returns
    -------
    dict or tuple of dict
        A single converted dict when exactly one argument is given, otherwise
        a tuple of converted dicts in argument order. ``None`` when called
        with no arguments.

    Raises
    ------
    ValueError
        If the key sets of the dictionaries differ.
    """
    if not dicts:
        return None

    # Compare complete key sets. A one-sided ``first - other`` check would
    # accept dictionaries that carry *extra* keys (e.g. a gradient for a
    # parameter that was never registered).
    reference_keys = set(dicts[0].keys())
    for d in dicts[1:]:
        if set(d.keys()) != reference_keys:
            raise ValueError('Dictionary does not match.')

    # flatten to normal python dict
    r = [_to_dict_value(d) for d in dicts]
    if len(dicts) == 1:
        return r[0]
    return tuple(r)
|
101
|
-
|
102
|
-
|
103
|
-
def _sgd(prev_weight, gradient, weight_decay, lr=None):
|
104
|
-
"""
|
105
|
-
The update function for SGD learning.
|
106
|
-
|
107
|
-
Parameters
|
108
|
-
----------
|
109
|
-
prev_weight: jax.Array
|
110
|
-
The previous weight.
|
111
|
-
gradient: jax.Array
|
112
|
-
The gradient.
|
113
|
-
weight_decay: float
|
114
|
-
The weight decay.
|
115
|
-
lr: float
|
116
|
-
The learning rate.
|
117
|
-
"""
|
118
|
-
if weight_decay is None:
|
119
|
-
if lr is None:
|
120
|
-
return prev_weight - gradient
|
121
|
-
else:
|
122
|
-
return prev_weight - lr * gradient
|
123
|
-
else:
|
124
|
-
if lr is None:
|
125
|
-
return (1 - weight_decay) * prev_weight - gradient
|
126
|
-
else:
|
127
|
-
return (1 - weight_decay) * prev_weight - lr * gradient
|
128
|
-
|
129
|
-
|
130
|
-
class OptimState(LongTermState):
    """
    State container for an optimizer's per-parameter buffers.

    The optimizers in this module store their auxiliary quantities (momentum
    velocities, squared-gradient caches, moment estimates) in instances of this
    class. It adds no behavior of its own beyond :class:`LongTermState`.
    """
|
135
|
-
|
136
|
-
|
137
|
-
class SGDOptimizer(Optimizer):
    """
    Base class for the gradient-based optimizers in this module.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate. Plain numbers and ``State`` objects are wrapped into a
        scheduler by ``make_schedule``.
    """

    lr: LearningRateScheduler  # learning rate

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
    ):
        super().__init__()
        # Normalize the learning rate so subclasses can uniformly call
        # ``self.lr()`` and ``self.lr.step_call()``.
        schedule = make_schedule(lr)
        self.lr: LearningRateScheduler = schedule
|
154
|
-
|
155
|
-
|
156
|
-
class _WeightDecayOptimizer(SGDOptimizer):
    """
    Base class for optimizers that support multiplicative weight decay.

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        Learning rate; converted to a scheduler by :class:`SGDOptimizer`.
    weight_decay: float, optional
        Decay factor in ``[0, 1]``. When set, weights are scaled by
        ``1 - weight_decay`` before each update is subtracted (see ``_sgd``).
        ``None`` disables decay.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
    ):
        # ``SGDOptimizer.__init__`` already builds ``self.lr`` via
        # ``make_schedule``; building it a second time here would create a
        # redundant scheduler object and throw the first one away.
        super().__init__(lr=lr)
        assert weight_decay is None or 0. <= weight_decay <= 1., 'weight_decay must be in [0, 1].'
        self.weight_decay = (fcast(weight_decay) if weight_decay is not None else None)
|
166
|
-
|
167
|
-
|
168
|
-
class SGD(_WeightDecayOptimizer):
    r"""
    Stochastic gradient descent optimizer.

    SGD performs a parameter update for training examples :math:`x` and label
    :math:`y`:

    .. math::

        \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y)

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Multiplicative weight decay in ``[0, 1]``; ``None`` disables it.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

    def register_trainable_weights(self, states: Optional[Dict[str, State]] = None):
        """Register the ``State`` objects this optimizer should update."""
        if states is None:
            states = dict()
        assert isinstance(states, dict), '"states" must be a dict of brainstate.State.'
        for name, state in states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            self.param_states.add_unique_value(name, state)

    def update(self, grads: dict):
        """Apply one SGD step to every registered weight, then advance the scheduler."""
        lr = self.lr()
        weight_values, grad_values = to_same_dict_tree(self.param_states, grads)
        descend = functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay)
        new_values = jax.tree.map(descend, weight_values, grad_values)
        self.param_states.assign_values(new_values)
        self.lr.step_call()
|
211
|
-
|
212
|
-
|
213
|
-
class Momentum(_WeightDecayOptimizer):
    r"""
    Momentum optimizer.

    Momentum [1]_ is a method that helps accelerate SGD in the relevant direction
    and dampens oscillations. It does this by adding a fraction :math:`\gamma`
    of the update vector of the past time step to the current update vector:

    .. math::

        \begin{align}
        \begin{split}
        v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta) \\
        \theta &= \theta - v_t
        \end{split}
        \end{align}

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    momentum: float
        The momentum coefficient :math:`\gamma` (default 0.9).
    weight_decay: float, optional
        Multiplicative weight decay in ``[0, 1]``; ``None`` disables it.

    References
    ----------

    .. [1] Qian, N. (1999). On the momentum term in gradient descent learning
       algorithms. Neural Networks : The Official Journal of the International
       Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        momentum: float = 0.9,
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)
        self.momentum = fcast(momentum)
        # One zero-initialized velocity buffer per registered weight.
        self.momentum_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register weights to optimize and create a zero velocity buffer for each one."""
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'

        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            # ``add_unique_value`` presumably returns a truthy value only when
            # ``k`` was newly added -- TODO confirm against StateDictManager.
            if self.param_states.add_unique_value(k, v):
                self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """Apply one momentum step (see the class docstring) and advance the scheduler."""
        lr = self.lr()
        states_values, grad_values, momentum_values = to_same_dict_tree(
            self.param_states, grads, self.momentum_states
        )
        # v_t = gamma * v_{t-1} + lr * g_t
        momentum_values = jax.tree.map(
            lambda vv, gg: self.momentum * vv + lr * gg,
            momentum_values,
            grad_values
        )
        # theta <- theta - v_t. The velocity already carries the learning
        # rate, so ``_sgd`` is called without ``lr`` (only weight decay).
        new_weight_values = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            states_values,
            momentum_values
        )
        self.momentum_states.assign_values(momentum_values)
        self.param_states.assign_values(new_weight_values)
        self.lr.step_call()
|
281
|
-
|
282
|
-
|
283
|
-
class MomentumNesterov(_WeightDecayOptimizer):
    r"""
    Nesterov accelerated gradient optimizer [2]_.

    .. math::

        \begin{align}
        \begin{split}
        v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta - \gamma v_{t-1} ) \\
        \theta &= \theta - v_t
        \end{split}
        \end{align}

    ``update`` receives gradients evaluated at the stored parameters rather
    than at the look-ahead point, so it uses the equivalent reparameterised
    form (the stored parameters track the look-ahead point
    :math:`\theta - \gamma v`):

    .. math::

        v_t = \gamma v_{t-1} + \eta \nabla_\theta J(\theta), \qquad
        \theta \leftarrow \theta - \left(\gamma v_t + \eta \nabla_\theta J(\theta)\right)

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Multiplicative weight decay in ``[0, 1]``; ``None`` disables it.
    momentum: float
        The momentum coefficient :math:`\gamma` (default 0.9).

    References
    ----------
    .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        momentum: float = 0.9,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.momentum = fcast(momentum)
        # One zero-initialized velocity buffer per registered weight.
        self.momentum_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register weights to optimize and create a zero velocity buffer for each one."""
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            # ``add_unique_value`` presumably returns a truthy value only when
            # ``k`` was newly added -- TODO confirm against StateDictManager.
            if self.param_states.add_unique_value(k, v):
                self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """Apply one Nesterov momentum step (see the class docstring) and advance the scheduler."""
        lr = self.lr()
        states_values, grad_values, momentum_values = to_same_dict_tree(self.param_states, grads, self.momentum_states)
        # v_t = gamma * v_{t-1} + lr * g_t
        momentum_values = jax.tree.map(
            lambda mv, gv: self.momentum * mv + lr * gv,
            momentum_values,
            grad_values
        )
        # Nesterov look-ahead step: gamma * v_t + lr * g_t. It already carries
        # the learning rate, so ``_sgd`` is called without ``lr``.
        nesterov_steps = jax.tree.map(
            lambda mv, gv: self.momentum * mv + lr * gv,
            momentum_values,
            grad_values
        )
        weight_values = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            states_values,
            nesterov_steps
        )
        self.param_states.assign_values(weight_values)
        self.momentum_states.assign_values(momentum_values)
        self.lr.step_call()
|
342
|
-
|
343
|
-
|
344
|
-
class Adagrad(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the Adagrad algorithm.

    Adagrad [3]_ keeps a per-parameter learning rate that is scaled down
    according to how large the parameter's past gradients have been; parameters
    that receive many large updates get progressively smaller steps.

    .. math::

        \theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t}

    where :math:`G_t` is the element-wise sum of the squares of all past gradients.

    A practical benefit is that the global learning rate rarely needs tuning
    (0.01 is a common choice). The drawback is that :math:`G_t` only grows, so
    the effective step size keeps shrinking and can eventually become too small
    for the model to keep learning.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Multiplicative weight decay in ``[0, 1]``; ``None`` disables it.
    epsilon: float
        Constant added inside the square root for numerical stability (default 1e-6).

    References
    ----------
    .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)
        self.epsilon = fcast(epsilon)
        # Running sum of squared gradients, one buffer per registered weight.
        self.cache_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register weights to optimize and create a zero gradient cache for each one."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for name, state in train_states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            newly_added = self.param_states.add_unique_value(name, state)
            if newly_added:
                self.cache_states[name] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Run one Adagrad step and advance the learning-rate scheduler."""
        lr = self.lr()
        caches, gradients, weights = to_same_dict_tree(self.cache_states, grads, self.param_states)

        def _accumulate(cv, gv):
            # G_t = G_{t-1} + g_t^2
            return cv + gv ** 2

        def _scaled_step(cv, gv):
            return lr * gv / jnp.sqrt(cv + self.epsilon)

        caches = jax.tree.map(_accumulate, caches, gradients)
        steps = jax.tree.map(_scaled_step, caches, gradients)
        # ``steps`` already include the learning rate, so only decay is applied here.
        weights = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay), weights, steps)

        self.cache_states.assign_values(caches)
        self.param_states.assign_values(weights)
        self.lr.step_call()
|
415
|
-
|
416
|
-
|
417
|
-
class Adadelta(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the Adadelta algorithm.

    Adadelta [4]_ is a stochastic gradient descent method with a per-dimension
    adaptive learning rate. It was designed to address two drawbacks of Adagrad:

    - The continual decay of learning rates throughout training.
    - The need for a manually selected global learning rate.

    Instead of accumulating all past squared gradients, Adadelta uses a moving
    window of them, so it keeps learning even after many updates. In the
    original formulation no initial learning rate has to be set.

    .. math::

        \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\
        \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\
        \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\
        \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t.

    :math:`\rho` should lie between 0 and 1. A value close to 1 makes the moving
    averages decay slowly; a value close to 0 makes them decay quickly. The
    paper suggests :math:`\rho = 0.95` and :math:`\epsilon = 10^{-6}`, reported
    to work for several datasets (MNIST, speech). :math:`\epsilon` matters most
    on the very first update, where it keeps the numerator from being 0.

    NOTE(review): ``update`` never reads ``self.lr()``; the learning rate is not
    applied to the step (matching the paper's learning_rate = 1.0). It is only
    used to advance the scheduler through ``step_call``.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate (see the note above).
    weight_decay: float, optional
        Multiplicative weight decay in ``[0, 1]``; ``None`` disables it.
    epsilon: float
        Numerical-stability constant (default 1e-6).
    rho: float
        Decay rate of the moving averages (default 0.95).

    References
    ----------
    .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State] = 0.01,
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        rho: float = 0.95,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.epsilon = fcast(epsilon)
        self.rho = fcast(rho)
        # Moving average of squared gradients (s_t) and of squared steps (Delta x_t).
        self.cache_states = StateDictManager()
        self.delta_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register weights to optimize and create zero ``s``/``Delta x`` buffers for each one."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for name, state in train_states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            if not self.param_states.add_unique_value(name, state):
                continue
            self.cache_states[name] = OptimState(u.math.tree_zeros_like(state.value))
            self.delta_states[name] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Run one Adadelta step, then advance the scheduler's step counter."""
        weights, gradients, sq_grad_avg, sq_step_avg = to_same_dict_tree(
            self.param_states, grads, self.cache_states, self.delta_states)
        rho, eps = self.rho, self.epsilon

        # s_t = rho * s_{t-1} + (1 - rho) * g_t^2
        sq_grad_avg = jax.tree.map(lambda s, g: rho * s + (1 - rho) * g ** 2, sq_grad_avg, gradients)
        # g'_t = g_t * sqrt(Delta x_{t-1} + eps) / sqrt(s_t + eps)
        steps = jax.tree.map(
            lambda g, d, s: g * jnp.sqrt(d + eps) / jnp.sqrt(s + eps),
            gradients, sq_step_avg, sq_grad_avg,
        )
        # Delta x_t = rho * Delta x_{t-1} + (1 - rho) * g'_t^2
        sq_step_avg = jax.tree.map(lambda d, st: rho * d + (1 - rho) * st ** 2, sq_step_avg, steps)
        weights = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay), weights, steps)

        self.param_states.assign_values(weights)
        self.delta_states.assign_values(sq_step_avg)
        self.cache_states.assign_values(sq_grad_avg)
        self.lr.step_call()
|
498
|
-
|
499
|
-
|
500
|
-
class RMSProp(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the RMSprop algorithm.

    RMSprop [5]_ and Adadelta were developed independently at around the same
    time, both to fix Adagrad's rapidly diminishing learning rates.

    RMSprop:

    - keeps a moving (discounted) average of the squared gradients, and
    - divides the gradient by the square root of that average.

    .. math::

        \begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2\\
        p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split}

    NOTE: this class implements the plain (non-centered) variant only. The
    centered variant, which also tracks a moving average of the gradients to
    estimate their variance, is not provided here.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Multiplicative weight decay in ``[0, 1]``; ``None`` disables it.
    epsilon: float
        Numerical-stability constant (default 1e-6).
    rho: float
        Decay rate of the squared-gradient average (default 0.9).

    References
    ----------
    .. [5] Tieleman, T. and Hinton, G. (2012):
       Neural Networks for Machine Learning, Lecture 6.5 - rmsprop.
       Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        rho: float = 0.9,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.epsilon = fcast(epsilon)
        self.rho = fcast(rho)
        # Moving average of squared gradients, one buffer per registered weight.
        self.cache_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register weights to optimize and create a zero squared-gradient cache for each one."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for name, state in train_states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(name, state):
                zeros = u.math.tree_zeros_like(state.value)
                self.cache_states[name] = OptimState(zeros)

    def update(self, grads: dict):
        """Run one RMSprop step and advance the learning-rate scheduler."""
        lr = self.lr()
        weights, gradients, caches = to_same_dict_tree(self.param_states, grads, self.cache_states)
        rho, eps = self.rho, self.epsilon

        # c_t = rho * c_{t-1} + (1 - rho) * g^2
        caches = jax.tree.map(lambda c, g: rho * c + (1 - rho) * g ** 2, caches, gradients)
        # p_t = lr * g / sqrt(c_t + eps)
        steps = jax.tree.map(lambda g, c: lr * g / jnp.sqrt(c + eps), gradients, caches)
        weights = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay), weights, steps)

        self.param_states.assign_values(weights)
        self.cache_states.assign_values(caches)
        self.lr.step_call()
|
564
|
-
|
565
|
-
|
566
|
-
class Adam(_WeightDecayOptimizer):
    """
    Optimizer that implements the Adam algorithm.

    Adam [6]_ is a stochastic gradient descent method that derives an individual
    adaptive learning rate for each parameter from running estimates of the
    first and second moments of the gradients.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    beta1: optional, float
        A positive scalar value for beta_1, the exponential decay rate
        for the first moment estimates (default 0.9).
    beta2: optional, float
        A positive scalar value for beta_2, the exponential decay rate
        for the second moment estimates (default 0.999).
    eps: optional, float
        A positive scalar value for epsilon, a small constant for
        numerical stability (default 1e-8).
    weight_decay: float, optional
        Multiplicative weight decay in ``[0, 1]``; ``None`` disables it.

    References
    ----------
    .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
    """

    def __init__(
        self,
        lr: Union[float, State, LearningRateScheduler],
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.beta1 = fcast(beta1)
        self.beta2 = fcast(beta2)
        self.eps = fcast(eps)
        # First- and second-moment estimates, one buffer each per registered weight.
        self.m1_states = StateDictManager()
        self.m2_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register weights to optimize and create zero moment buffers for each one."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'

        for name, state in train_states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            if not self.param_states.add_unique_value(name, state):
                continue
            self.m1_states[name] = OptimState(u.math.tree_zeros_like(state.value))
            self.m2_states[name] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Run one Adam step and advance the learning-rate scheduler."""
        base_lr = self.lr()
        # Bias-correction step index. NOTE(review): ``+ 2`` yields t = 1 on the
        # first update only if the scheduler's ``last_epoch`` starts at -1 and
        # is advanced once per update by ``step_call`` -- TODO confirm.
        t = self.lr.last_epoch.value + 2
        # Fold both bias corrections into the step size:
        # lr_t = lr / (1 - beta1^t) * sqrt(1 - beta2^t)
        lr = base_lr / (1 - self.beta1 ** t) * jnp.sqrt(1 - self.beta2 ** t)

        weights, gradients, m1, m2 = to_same_dict_tree(
            self.param_states, grads, self.m1_states, self.m2_states
        )
        b1, b2, eps = self.beta1, self.beta2, self.eps

        m1 = jax.tree.map(lambda m, g: b1 * m + (1 - b1) * g, m1, gradients)
        m2 = jax.tree.map(lambda v, g: b2 * v + (1 - b2) * g ** 2, m2, gradients)
        steps = jax.tree.map(lambda m, v: lr * m / (jnp.sqrt(v) + eps), m1, m2)
        # ``steps`` already include the learning rate, so only decay is applied here.
        weights = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay), weights, steps)

        self.param_states.assign_values(weights)
        self.m1_states.assign_values(m1)
        self.m2_states.assign_values(m2)
        self.lr.step_call()
|
650
|
-
|
651
|
-
|
652
|
-
class LARS(_WeightDecayOptimizer):
    r"""
    Layer-wise adaptive rate scaling (LARS) optimizer [1]_.

    Layer-wise Adaptive Rate Scaling, or LARS, is a large batch
    optimization technique. There are two notable differences
    between LARS and other adaptive algorithms such as `Adam` or `RMSProp`:
    first, LARS uses a separate learning rate for each layer and not for
    each weight. And second, the magnitude of the update is controlled
    with respect to the weight norm for better control of training speed.

    For every parameter array :math:`x` (one "layer") this implementation computes

    .. math::

        \rho_{t} = \frac{\eta_{tc} \, || x_{t} ||}{|| g_{t} || + \lambda || x_{t} || + \epsilon} \\
        m_{t} = \mu \, m_{t-1} + \gamma_{t} \, \rho_{t} \left(g_{t} + \lambda x_{t}\right) \\
        x_{t+1} = x_{t} - m_{t}

    where :math:`\gamma_t` is the learning rate, :math:`\mu` the momentum and
    :math:`\eta_{tc}` the trust coefficient. When either norm is zero the
    trust ratio :math:`\rho_t` falls back to 1.

    Parameters
    ----------
    lr: float, State, LearningRateScheduler
        learning rate.
    momentum: float
        coefficient used for the moving average of the gradient.
    weight_decay: float, optional
        weight decay coefficient :math:`\lambda`. It enters both the trust
        ratio denominator and the momentum input; ``None`` is treated as 0.
    tc: float
        trust coefficient eta ( < 1) for trust ratio computation.
    eps: float
        epsilon used for trust ratio computation.

    References
    ----------
    .. [1] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        tc: float = 1e-3,
        eps: float = 1e-5,
    ):
        # Weight decay is part of the LARS update rule itself (see ``update``),
        # so it is accepted here rather than rejected.
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.momentum = fcast(momentum)
        self.tc = fcast(tc)
        self.eps = fcast(eps)
        self.momentum_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        lr = self.lr()
        weight_values, grad_values, momentum_values = to_same_dict_tree(self.param_states, grads, self.momentum_states)
        # ``None`` (allowed by the base class) means "no weight decay".
        wd = 0. if self.weight_decay is None else self.weight_decay

        def _lars_update(pv, gv, mv):
            p_norm = jnp.linalg.norm(pv)
            g_norm = jnp.linalg.norm(gv)
            trust_ratio = self.tc * p_norm / (g_norm + wd * p_norm + self.eps)
            # The ratio is meaningless when either norm vanishes (e.g. zero-initialized
            # weights); fall back to a unit trust ratio in that case.
            trust_ratio = jnp.where(jnp.logical_or(p_norm == 0, g_norm == 0), 1., trust_ratio)
            local_lr = lr * trust_ratio
            return self.momentum * mv + local_lr * (gv + wd * pv)

        momentum_values = jax.tree.map(_lars_update, weight_values, grad_values, momentum_values)
        weight_values = jax.tree.map(lambda pv, mv: pv - mv, weight_values, momentum_values)
        self.param_states.assign_values(weight_values)
        self.momentum_states.assign_values(momentum_values)
        self.lr.step_call()
|
727
|
-
|
728
|
-
|
729
|
-
class Adan(_WeightDecayOptimizer):
    r"""
    Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.

    The :math:`\beta` coefficients follow the paper's convention: they weight
    the *new* term, so :math:`1-\beta` weights the running average.

    .. math::

        \begin{equation}
        \begin{aligned}
        & \mathbf{m}_k=\left(1-\beta_1\right) \mathbf{m}_{k-1}+\beta_1 \mathbf{g}_k \\
        & \mathbf{v}_k=\left(1-\beta_2\right) \mathbf{v}_{k-1}+\beta_2\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right) \\
        & \mathbf{n}_k=\left(1-\beta_3\right) \mathbf{n}_{k-1}+\beta_3\left[\mathbf{g}_k+\left(1-\beta_2\right)\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right)\right]^2 \\
        & \boldsymbol{\eta}_k=\eta /\left(\sqrt{\mathbf{n}_k+\varepsilon}\right) \\
        & \boldsymbol{\theta}_{k+1}=\left(1+\lambda_k \eta\right)^{-1}\left[\boldsymbol{\theta}_k-\boldsymbol{\eta}_k \circ\left(\mathbf{m}_k+\left(1-\beta_2\right) \mathbf{v}_k\right)\right] \\
        \end{aligned}
        \end{equation}

    Parameters
    ----------
    lr: float, State, LearningRateScheduler
        learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3)
    betas : tuple
        Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01))
    eps : float
        The term added to the denominator to improve numerical stability. (default: 1e-8)
    weight_decay : float, optional
        decoupled weight decay coefficient :math:`\lambda` (default: 0.02).
        ``None`` is treated as 0.
    no_prox: bool
        how to perform the decoupled weight decay (default: False).
        It determines the update rule of parameters with weight decay.
        By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper:

        .. math::
            \boldsymbol{\theta}_{k+1} = ( 1+\lambda \eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k)\right],

        But one also can update the parameter like AdamW:

        .. math::
            \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k).

    References
    ----------
    .. [1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan.
           “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
           Deep Models.” ArXiv abs/2208.06677 (2022): n. pag.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State] = 1e-3,
        betas: Tuple[float, float, float] = (0.02, 0.08, 0.01),
        eps: float = 1e-8,
        weight_decay: float = 0.02,
        no_prox: bool = False,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        assert len(betas) == 3
        if eps < 0.:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))

        self.betas = fcast(jnp.asarray(betas))
        self.eps = fcast(eps)
        self.no_prox = no_prox
        self.exp_avg_states = StateDictManager()  # m_k: running average of g_k
        self.exp_avg_sq_states = StateDictManager()  # n_k: running average of the squared corrected gradient
        self.exp_avg_diff_states = StateDictManager()  # v_k: running average of g_k - g_{k-1}
        self.pre_grad_states = StateDictManager()  # g_{k-1}

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.exp_avg_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                self.exp_avg_sq_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                self.exp_avg_diff_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                self.pre_grad_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        lr = self.lr()
        step = self.lr.last_epoch.value + 1
        beta1, beta2, beta3 = self.betas[0], self.betas[1], self.betas[2]
        correct_m = 1 / (1 - (1 - beta1) ** (step + 1))
        correct_v = 1 / (1 - (1 - beta2) ** (step + 1))
        correct_n = 1 / (1 - (1 - beta3) ** (step + 1))
        # ``None`` (allowed by the base class) means "no weight decay".
        wd = 0. if self.weight_decay is None else self.weight_decay

        m_values, v_values, n_values, pre_g_values, weight_values, grad_values = to_same_dict_tree(
            self.exp_avg_states, self.exp_avg_diff_states, self.exp_avg_sq_states, self.pre_grad_states,
            self.param_states, grads)

        # One tree.map per quantity: a tuple-returning function would yield a single
        # tree of tuples, not separate trees.
        diff_values = jax.tree.map(lambda g, pg: g - pg, grad_values, pre_g_values)
        m_values = jax.tree.map(lambda m, g: m * (1 - beta1) + beta1 * g, m_values, grad_values)
        v_values = jax.tree.map(lambda v, gd: v * (1 - beta2) + beta2 * gd, v_values, diff_values)
        n_values = jax.tree.map(
            lambda n, g, gd: n * (1 - beta3) + beta3 * (g + (1 - beta2) * gd) ** 2,
            n_values, grad_values, diff_values)

        def _new_param(p, m, v, n):
            weighted_step_size = lr / (jnp.sqrt(n * correct_n) + self.eps)
            direction = m * correct_m + (1 - beta2) * v * correct_v
            if self.no_prox:
                # AdamW-style decoupled decay
                return p * (1 - wd * lr) - weighted_step_size * direction
            # proximal update (Algorithm 1 of the paper)
            return (p - weighted_step_size * direction) / (1 + wd * lr)

        weight_values = jax.tree.map(_new_param, weight_values, m_values, v_values, n_values)

        self.exp_avg_states.assign_values(m_values)
        self.exp_avg_diff_states.assign_values(v_values)
        self.exp_avg_sq_states.assign_values(n_values)
        # Keep g_k so the next step can form g_{k+1} - g_k.
        self.pre_grad_states.assign_values(grad_values)
        self.param_states.assign_values(weight_values)
        self.lr.step_call()
|
845
|
-
|
846
|
-
|
847
|
-
class AdamW(_WeightDecayOptimizer):
    r"""
    Adam with weight decay regularization [1]_.

    AdamW uses weight decay to regularize learning towards small weights, as
    this leads to better generalization. In SGD you can also use L2 regularization
    to implement this as an additive loss term, however L2 regularization
    does not behave as intended for adaptive gradient algorithms such as Adam.

    .. math::

       \begin{aligned}
          &\rule{110mm}{0.4pt} \\
          &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
              \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
              \: \epsilon \text{ (epsilon)} \\
          &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad} \\
          &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
              \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
          &\rule{110mm}{0.4pt} \\
          &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
          &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
          &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
          &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
          &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
          &\hspace{5mm}\widehat{\gamma_t} \leftarrow \gamma \sqrt{1-\beta_2^t} \big/ \big(1-\beta_1^t \big) \\
          &\hspace{5mm}\textbf{if} \: amsgrad \\
          &\hspace{10mm}v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max}, v_t) \\
          &\hspace{10mm}\theta_t \leftarrow \theta_t - \widehat{\gamma_t} m_t/
              \big(\sqrt{v_t^{max}} + \epsilon \big) \\
          &\hspace{5mm}\textbf{else} \\
          &\hspace{10mm}\theta_t \leftarrow \theta_t - \widehat{\gamma_t} m_t/
              \big(\sqrt{v_t} + \epsilon \big) \\
          &\rule{110mm}{0.4pt} \\[-1.ex]
          &\bf{return} \: \theta_t \\[-1.ex]
          &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    Parameters
    ----------
    lr: float, State, LearningRateScheduler
        learning rate.
    beta1: optional, float
        A positive scalar value for beta_1, the exponential decay rate
        for the first moment estimates. Generally close to 1.
    beta2: optional, float
        A positive scalar value for beta_2, the exponential decay rate
        for the second moment estimates. Generally close to 1.
    eps: optional, float
        A positive scalar value for epsilon, a small constant for
        numerical stability.
    weight_decay: float
        Strength of the weight decay regularization. Note that this
        weight decay is multiplied with the learning rate.
    amsgrad: bool
        whether to use the AMSGrad variant of this algorithm
        from the paper `On the Convergence of Adam and Beyond`.

    References
    ----------
    .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019).
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        if eps < 0.:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= beta1 < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(beta1))
        if not 0.0 <= beta2 < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(beta2))
        if weight_decay < 0.:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        self.beta1 = fcast(beta1)
        self.beta2 = fcast(beta2)
        self.eps = fcast(eps)
        self.amsgrad = amsgrad
        self.m1_states = StateDictManager()
        self.m2_states = StateDictManager()
        if self.amsgrad:
            # running maximum of the second moment (AMSGrad)
            self.vmax_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.m1_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                self.m2_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                if self.amsgrad:
                    self.vmax_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        lr_old = self.lr()
        step = self.lr.last_epoch.value + 2
        bias_correction1 = 1 - self.beta1 ** step
        bias_correction2 = 1 - self.beta2 ** step
        # Both bias corrections are folded into the step size; epsilon stays outside.
        lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1
        beta1, beta2 = self.beta1, self.beta2

        # Align gradients with the parameter/moment trees, as the other optimizers do.
        if self.amsgrad:
            weight_values, grad_values, m1_values, m2_values, vmax_values = to_same_dict_tree(
                self.param_states, grads, self.m1_states, self.m2_states, self.vmax_states)
        else:
            weight_values, grad_values, m1_values, m2_values = to_same_dict_tree(
                self.param_states, grads, self.m1_states, self.m2_states)

        # Decoupled weight decay, applied before the Adam step with the un-corrected lr.
        if self.weight_decay != 0:
            decay = 1 - lr_old * self.weight_decay
            weight_values = jax.tree.map(lambda p: p * decay, weight_values)

        # One tree.map per quantity: a tuple-returning function would yield a single
        # tree of tuples, not separate trees.
        m1_values = jax.tree.map(lambda m, g: beta1 * m + (1 - beta1) * g, m1_values, grad_values)
        m2_values = jax.tree.map(lambda v, g: beta2 * v + (1 - beta2) * g ** 2, m2_values, grad_values)

        if self.amsgrad:
            vmax_values = jax.tree.map(jnp.maximum, vmax_values, m2_values)
            second_moment = vmax_values
            self.vmax_states.assign_values(vmax_values)
        else:
            second_moment = m2_values

        weight_values = jax.tree.map(
            lambda p, m, v: p - lr * m / (jnp.sqrt(v) + self.eps),
            weight_values, m1_values, second_moment)

        self.param_states.assign_values(weight_values)
        self.m1_states.assign_values(m1_values)
        self.m2_states.assign_values(m2_values)
        self.lr.step_call()
|
993
|
-
|
994
|
-
|
995
|
-
class SM3(_WeightDecayOptimizer):
    """
    SM3 algorithm [1]_.

    The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method'
    (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar
    to Adam and Adagrad with greatly reduced memory usage for history tensors.
    For an `n x m` matrix, Adam and Adagrad use `O(nm)` memory for history
    tensors, while SM3 uses `O(n+m)` due to the chosen cover. In general, a tensor
    of shape `(n_1, n_2, ..., n_k)` optimized using Adam will use `O(prod n_i)`
    memory for storage tensors, while the optimization using SM3 will use
    `O(sum n_i)` memory. Despite storing fewer parameters, this optimization
    algorithm manages to be comparably effective.

    This advantage drastically shrinks when `momentum > 0`. The momentum is
    tracked using a tensor of the same shape as the tensor being optimized. With
    momentum, SM3 will use just over half as much memory as Adam, and a bit more
    than Adagrad.

    Parameters
    ----------
    lr: float, State, LearningRateScheduler
        learning rate.
    momentum: float
        coefficient used to scale prior updates
        before adding. This drastically increases memory usage if
        `momentum > 0.0`. (default: 0.0)
    beta: float
        coefficient used for exponential moving averages (default: 0.0)
    eps: float
        Term added to square-root in denominator to
        improve numerical stability (default: 1e-30).
    weight_decay: float, optional
        If given, every parameter is multiplied by `1 - weight_decay` before
        the step is subtracted (the decay is not scaled by the learning
        rate). (default: None)

    References
    ----------
    .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019).
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        beta: float = 0.,
        momentum: float = 0.,
        eps: float = 1e-30,
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        if not 0.0 <= momentum < 1.0:
            raise ValueError("Invalid momentum: {0}".format(momentum))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta: {0}".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid eps: {0}".format(eps))

        self.eps = fcast(eps)
        self.beta = fcast(beta)
        self.momentum = fcast(momentum)
        # Per parameter ``k``: accumulators ``{k}_m{i}`` (one per axis, broadcastable
        # to the parameter) and, with momentum, ``{k}_mbuffer``.
        self.memory_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                shape, ndim, dtype = v.value.shape, v.value.ndim, v.value.dtype
                if ndim == 0:
                    # A scalar parameter gets a single scalar accumulator.
                    self.memory_states[f'{k}_m0'] = OptimState(jnp.zeros((), dtype=dtype))
                for i in range(ndim):
                    # Accumulator i keeps axis i and has size 1 along every other axis.
                    acc_shape = [1] * ndim
                    acc_shape[i] = shape[i]
                    self.memory_states[f'{k}_m{i}'] = OptimState(jnp.zeros(acc_shape, dtype=dtype))
                if self.momentum > 0.:
                    self.memory_states[f'{k}_mbuffer'] = OptimState(jnp.zeros_like(v.value))

    def update(self, grads: dict):
        lr = self.lr()

        for k, p in self.param_states.items():
            g = grads[k]
            ndim = jnp.ndim(p.value)
            n_acc = max(1, ndim)  # scalars use a single accumulator

            # The minimum over the covering accumulators gives the per-entry estimate.
            update = self.memory_states[f'{k}_m0'].value
            for i in range(1, n_acc):
                update = jnp.minimum(update, self.memory_states[f'{k}_m{i}'].value)
            if self.beta > 0.:
                update = update * self.beta
            update = update + g * g * (1 - self.beta)

            # Refresh each accumulator with the max of ``update`` over all other axes
            # (keepdims preserves the accumulator's broadcastable shape).
            for i in range(n_acc):
                result = update
                for j in range(ndim):
                    if i != j:
                        result = jnp.max(result, axis=j, keepdims=True)
                acc = self.memory_states[f'{k}_m{i}']
                if self.beta > 0.:
                    acc.value = jnp.maximum(acc.value, result)
                else:
                    # No need to compare - nu_max is bigger because of grad ** 2
                    acc.value = result

            update = g / jnp.sqrt(update + self.eps)
            if self.momentum > 0.:
                m_buffer = self.memory_states[f'{k}_mbuffer']
                update = update * (1. - self.momentum) + m_buffer.value * self.momentum
                m_buffer.value = update

            if self.weight_decay is None:
                p.value = p.value - lr * update
            else:
                p.value = (1 - self.weight_decay) * p.value - lr * update
        self.lr.step_call()
|