brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,1104 +1,1104 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- import functools
19
- from typing import Union, Dict, Optional, Tuple, Any, TypeVar
20
-
21
- import brainunit as u
22
- import jax
23
- import jax.numpy as jnp
24
-
25
- from brainstate import environ
26
- from brainstate._state import State, LongTermState, StateDictManager
27
- from ._base import Optimizer
28
- from ._lr_scheduler import make_schedule, LearningRateScheduler
29
-
30
- __all__ = [
31
- 'to_same_dict_tree',
32
-
33
- # new class of brainstate.State for optimizer
34
- 'OptimState',
35
-
36
- # commonly used optimizers
37
- 'SGDOptimizer',
38
- 'SGD',
39
- 'Momentum',
40
- 'MomentumNesterov',
41
- 'Adagrad',
42
- 'Adadelta',
43
- 'RMSProp',
44
- 'Adam',
45
- 'LARS',
46
- 'Adan',
47
- 'AdamW',
48
- ]
49
-
50
- T = TypeVar('T')
51
-
52
-
53
def cast(value: Any, dtype: Any) -> jax.Array:
    """
    Return ``value`` as a JAX array with the requested ``dtype``.

    Parameters
    ----------
    value: Any
        A ``jax.Array`` or anything accepted by ``jnp.asarray``.
    dtype: Any
        The target data type.

    Returns
    -------
    jax.Array
        ``value.astype(dtype)`` for ``jax.Array`` inputs, otherwise
        ``jnp.asarray(value, dtype=dtype)``.
    """
    already_array = isinstance(value, jax.Array)
    return value.astype(dtype) if already_array else jnp.asarray(value, dtype=dtype)
57
-
58
-
59
def fcast(value: T, dtype: Any = None) -> jax.Array:
    """
    Cast ``value`` with :func:`cast`, defaulting the dtype to ``environ.dftype()``.

    Parameters
    ----------
    value: Any
        The value to convert.
    dtype: Any, optional
        The target data type. Any falsy value (including ``None``) selects
        ``environ.dftype()``.

    Returns
    -------
    jax.Array
        The converted array.
    """
    target_dtype = dtype or environ.dftype()
    return cast(value, dtype=target_dtype)
61
-
62
-
63
- def _to_dict_value(old_dict: Dict) -> Dict:
64
- new_dict = dict()
65
- for k, v in old_dict.items():
66
- if isinstance(v, State):
67
- new_dict[k] = v.value
68
- else:
69
- new_dict[k] = v
70
- return new_dict
71
-
72
-
73
def to_same_dict_tree(*dicts: Dict):
    """
    Convert multiple dictionaries to the same tree structure.

    Every dictionary must have exactly the same set of keys. ``State``
    entries are replaced by their values, so the results are plain Python
    dictionaries that can be passed together to ``jax.tree.map``.

    Parameters
    ----------
    *dicts: dict
        The dictionaries to be converted.

    Returns
    -------
    dict or tuple of dict
        The converted dictionary when a single one is given, otherwise a
        tuple with one converted dictionary per input (empty for no input).

    Raises
    ------
    ValueError
        If the key sets of the dictionaries differ.
    """
    if dicts:
        reference_keys = set(dicts[0].keys())
        for other in dicts[1:]:
            # Compare for equality: a one-sided ``difference`` would let
            # extra keys in the later dictionaries slip through.
            if set(other.keys()) != reference_keys:
                raise ValueError('Dictionary does not match.')

    # flatten to normal python dict
    r = [_to_dict_value(d) for d in dicts]

    if len(dicts) == 1:
        return r[0]
    return tuple(r)
101
-
102
-
103
- def _sgd(prev_weight, gradient, weight_decay, lr=None):
104
- """
105
- The update function for SGD learning.
106
-
107
- Parameters
108
- ----------
109
- prev_weight: jax.Array
110
- The previous weight.
111
- gradient: jax.Array
112
- The gradient.
113
- weight_decay: float
114
- The weight decay.
115
- lr: float
116
- The learning rate.
117
- """
118
- if weight_decay is None:
119
- if lr is None:
120
- return prev_weight - gradient
121
- else:
122
- return prev_weight - lr * gradient
123
- else:
124
- if lr is None:
125
- return (1 - weight_decay) * prev_weight - gradient
126
- else:
127
- return (1 - weight_decay) * prev_weight - lr * gradient
128
-
129
-
130
class OptimState(LongTermState):
    """
    The state for optimizer.

    A ``LongTermState`` subclass with no additions of its own; the optimizers
    in this module use it for their auxiliary slots (momentum, caches and
    moment estimates).
    """
135
-
136
-
137
class SGDOptimizer(Optimizer):
    """
    Base Optimizer Class.

    Stores the learning rate as a schedule built by ``make_schedule``.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    """

    lr: LearningRateScheduler  # learning rate

    def __init__(self, lr: Union[float, LearningRateScheduler, State]):
        super().__init__()
        schedule = make_schedule(lr)
        self.lr: LearningRateScheduler = schedule
154
-
155
-
156
class _WeightDecayOptimizer(SGDOptimizer):
    """
    Base class for the optimizers that accept an optional weight decay.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient in ``[0, 1]``. ``None`` disables the decay.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
    ):
        # ``SGDOptimizer.__init__`` already builds ``self.lr`` through
        # ``make_schedule``; it is not built a second time here.
        super().__init__(lr=lr)
        assert weight_decay is None or 0. <= weight_decay <= 1., 'weight_decay must be in [0, 1].'
        self.weight_decay = (fcast(weight_decay) if weight_decay is not None else None)
166
-
167
-
168
class SGD(_WeightDecayOptimizer):
    r"""
    Stochastic gradient descent optimizer.

    SGD performs a parameter update for training examples :math:`x` and label
    :math:`y`:

    .. math::

        \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y)


    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient; ``None`` disables the decay.

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

    def register_trainable_weights(self, states: Optional[Dict[str, State]] = None):
        """Add every ``State`` in ``states`` to ``self.param_states`` under its key."""
        if states is None:
            states = dict()
        assert isinstance(states, dict), '"states" must be a dict of brainstate.State.'
        for name, state in states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            self.param_states.add_unique_value(name, state)

    def update(self, grads: dict):
        """Apply one SGD step using ``grads`` and then advance the learning-rate schedule."""
        step_size = self.lr()
        params, gradients = to_same_dict_tree(self.param_states, grads)
        apply_step = functools.partial(_sgd, lr=step_size, weight_decay=self.weight_decay)
        new_params = jax.tree.map(apply_step, params, gradients)
        self.param_states.assign_values(new_params)
        self.lr.step_call()
211
-
212
-
213
class Momentum(_WeightDecayOptimizer):
    r"""
    Momentum optimizer.

    Momentum [1]_ is a method that helps accelerate SGD in the relevant direction
    and dampens oscillations. It does this by adding a fraction :math:`\gamma`
    of the update vector of the past time step to the current update vector:

    .. math::

        \begin{align}
        \begin{split}
        v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta) \\
        \theta &= \theta - v_t
        \end{split}
        \end{align}

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    momentum: float
        The fraction :math:`\gamma` of the previous update that is kept.
    weight_decay: float, optional
        Weight decay coefficient; ``None`` disables the decay.

    References
    ----------

    .. [1] Qian, N. (1999). On the momentum term in gradient descent learning
           algorithms. Neural Networks : The Official Journal of the International
           Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        momentum: float = 0.9,
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)
        self.momentum = fcast(momentum)
        self.momentum_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the trainable states and create a zero velocity slot for each new one."""
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'

        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """Apply one momentum step using ``grads`` and then advance the learning-rate schedule."""
        lr = self.lr()
        states_values, grad_values, momentum_values = to_same_dict_tree(
            self.param_states, grads, self.momentum_states
        )
        # v_t = gamma * v_{t-1} + eta * g_t, as in the class docstring.
        momentum_values = jax.tree.map(
            lambda vv, gg: self.momentum * vv + lr * gg,
            momentum_values,
            grad_values
        )
        # theta = (1 - weight_decay) * theta - v_t. The learning rate is
        # already folded into v_t, so it must not be applied a second time;
        # subtracting ``lr * (gamma * v - lr * g)`` moved the weights along
        # the gradient instead of against it.
        new_weight_values = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            states_values,
            momentum_values
        )
        self.momentum_states.assign_values(momentum_values)
        self.param_states.assign_values(new_weight_values)
        self.lr.step_call()
281
-
282
-
283
class MomentumNesterov(_WeightDecayOptimizer):
    r"""
    Nesterov accelerated gradient optimizer [2]_.

    .. math::

        \begin{align}
        \begin{split}
        v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta - \gamma v_{t-1} ) \\
        \theta &= \theta - v_t
        \end{split}
        \end{align}

    The gradient passed to :meth:`update` is evaluated at the current
    parameters, so the look-ahead is applied through the usual
    reformulation :math:`\theta \leftarrow \theta - (\gamma v_t + \eta g_t)`
    with :math:`v_t = \gamma v_{t-1} + \eta g_t`.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient; ``None`` disables the decay.
    momentum: float
        The momentum coefficient :math:`\gamma`.

    References
    ----------
    .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547.

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        momentum: float = 0.9,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.momentum = fcast(momentum)
        self.momentum_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the trainable states and create a zero velocity slot for each new one."""
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """Apply one Nesterov momentum step and then advance the learning-rate schedule."""
        lr = self.lr()
        states_values, grad_values, momentum_values = to_same_dict_tree(
            self.param_states, grads, self.momentum_states
        )
        # v_t = gamma * v_{t-1} + eta * g_t
        momentum_values = jax.tree.map(
            lambda mv, gv: self.momentum * mv + lr * gv,
            momentum_values,
            grad_values
        )
        # Look-ahead step: gamma * v_t + eta * g_t. The learning rate is
        # already contained in this step, so ``_sgd`` is called without ``lr``.
        step_values = jax.tree.map(
            lambda mv, gv: self.momentum * mv + lr * gv,
            momentum_values,
            grad_values
        )
        weight_values = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            states_values,
            step_values
        )
        self.param_states.assign_values(weight_values)
        self.momentum_states.assign_values(momentum_values)
        self.lr.step_call()
342
-
343
-
344
class Adagrad(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the Adagrad algorithm.

    Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are
    adapted relative to how frequently a parameter gets updated during training.
    The more updates a parameter receives, the smaller the updates.

    .. math::

        \theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t}

    where :math:`G(t)` contains the sum of the squares of the past gradients

    One of Adagrad's main benefits is that it eliminates the need to manually tune
    the learning rate. Most implementations use a default value of 0.01 and leave it at that.
    Adagrad's main weakness is its accumulation of the squared gradients in the denominator:
    Since every added term is positive, the accumulated sum keeps growing during training.
    This in turn causes the learning rate to shrink and eventually become infinitesimally
    small, at which point the algorithm is no longer able to acquire additional knowledge.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient; ``None`` disables the decay.
    epsilon: float
        Constant added to the accumulated squares inside the square root.

    References
    ----------
    .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)
        self.epsilon = fcast(epsilon)
        self.cache_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the trainable states and create a zero accumulator for each new one."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for name, state in train_states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(name, state):
                self.cache_states[name] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Apply one Adagrad step using ``grads`` and then advance the learning-rate schedule."""
        lr = self.lr()
        accumulators, gradients, params = to_same_dict_tree(self.cache_states, grads, self.param_states)

        # Running sum of squared gradients.
        accumulators = jax.tree.map(lambda acc, g: acc + g ** 2, accumulators, gradients)

        def scaled_step(acc, g):
            return lr * g / jnp.sqrt(acc + self.epsilon)

        steps = jax.tree.map(scaled_step, accumulators, gradients)
        # The learning rate is already part of ``steps``, so ``_sgd`` gets no ``lr``.
        decay_and_step = functools.partial(_sgd, weight_decay=self.weight_decay)
        params = jax.tree.map(decay_and_step, params, steps)

        self.cache_states.assign_values(accumulators)
        self.param_states.assign_values(params)
        self.lr.step_call()
415
-
416
-
417
class Adadelta(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the Adadelta algorithm.

    Adadelta [4]_ optimization is a stochastic gradient descent method that is based
    on adaptive learning rate per dimension to address two drawbacks:

    - The continual decay of learning rates throughout training.
    - The need for a manually selected global learning rate.

    Adadelta is a more robust extension of Adagrad that adapts learning rates based on
    a moving window of gradient updates, instead of accumulating all past gradients.
    This way, Adadelta continues learning even when many updates have been done. Compared
    to Adagrad, in the original version of Adadelta you don't have to set an initial
    learning rate.

    .. math::

        \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\
        \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\
        \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\
        \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t.

    :math:`\rho` should be between 0 and 1. A value of rho close to 1 will decay the
    moving average slowly and a value close to 0 will decay the moving average fast.

    :math:`\rho` = 0.95 and :math:`\epsilon`=1e-6 are suggested in the paper and reported
    to work for multiple datasets (MNIST, speech).

    In the paper, no learning rate is considered (so learning_rate=1.0). Probably best to
    keep it at this value. epsilon is important for the very first update (so the
    numerator does not become 0).

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate. The schedule is stepped by :meth:`update`, but its
        value is not used when computing the parameter update.
    weight_decay: float, optional
        Weight decay coefficient; ``None`` disables the decay.
    epsilon: float
        Constant added inside both square roots.
    rho: float
        Decay rate of the two moving averages.

    References
    ----------
    .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State] = 0.01,
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        rho: float = 0.95,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.epsilon = fcast(epsilon)
        self.rho = fcast(rho)
        self.cache_states = StateDictManager()
        self.delta_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the trainable states and create zero cache/delta slots for each new one."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for name, state in train_states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(name, state):
                self.cache_states[name] = OptimState(u.math.tree_zeros_like(state.value))
                self.delta_states[name] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Apply one Adadelta step using ``grads`` and then advance the learning-rate schedule."""
        params, gradients, sq_grad_avg, sq_step_avg = to_same_dict_tree(
            self.param_states, grads, self.cache_states, self.delta_states
        )

        def moving_average(avg, new):
            return self.rho * avg + (1 - self.rho) * new ** 2

        def rescale(g, delta, cache):
            return g * jnp.sqrt(delta + self.epsilon) / jnp.sqrt(cache + self.epsilon)

        sq_grad_avg = jax.tree.map(moving_average, sq_grad_avg, gradients)
        # The step uses the previous delta average together with the fresh gradient average.
        steps = jax.tree.map(rescale, gradients, sq_step_avg, sq_grad_avg)
        sq_step_avg = jax.tree.map(moving_average, sq_step_avg, steps)
        params = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            params,
            steps
        )

        self.param_states.assign_values(params)
        self.delta_states.assign_values(sq_step_avg)
        self.cache_states.assign_values(sq_grad_avg)
        self.lr.step_call()
498
-
499
-
500
class RMSProp(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the RMSprop algorithm.

    RMSprop [5]_ and Adadelta have both been developed independently around the same time
    stemming from the need to resolve Adagrad's radically diminishing learning rates.

    The gist of RMSprop is to:

    - Maintain a moving (discounted) average of the square of gradients
    - Divide the gradient by the root of this average

    .. math::

        \begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2\\
        p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split}

    The centered version additionally maintains a moving average of the gradients,
    and uses that average to estimate the variance.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient; ``None`` disables the decay.
    epsilon: float
        Constant added to the moving average inside the square root.
    rho: float
        Decay rate of the moving average of squared gradients.

    References
    ----------
    .. [5] Tieleman, T. and Hinton, G. (2012):
           Neural Networks for Machine Learning, Lecture 6.5 - rmsprop.
           Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        rho: float = 0.9,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.epsilon = fcast(epsilon)
        self.rho = fcast(rho)
        self.cache_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the trainable states and create a zero cache slot for each new one."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for name, state in train_states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(name, state):
                self.cache_states[name] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Apply one RMSprop step using ``grads`` and then advance the learning-rate schedule."""
        lr = self.lr()
        params, gradients, sq_grad_avg = to_same_dict_tree(self.param_states, grads, self.cache_states)

        def moving_average(avg, g):
            return self.rho * avg + (1 - self.rho) * g ** 2

        def scaled_step(g, avg):
            return lr * g / jnp.sqrt(avg + self.epsilon)

        sq_grad_avg = jax.tree.map(moving_average, sq_grad_avg, gradients)
        steps = jax.tree.map(scaled_step, gradients, sq_grad_avg)
        # The learning rate is already part of ``steps``, so ``_sgd`` gets no ``lr``.
        params = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            params,
            steps
        )

        self.param_states.assign_values(params)
        self.cache_states.assign_values(sq_grad_avg)
        self.lr.step_call()
564
-
565
-
566
class Adam(_WeightDecayOptimizer):
    """
    Optimizer that implements the Adam algorithm.

    Adam [6]_ - a stochastic gradient descent method (SGD) that computes
    individual adaptive learning rates for different parameters from estimates of
    first- and second-order moments of the gradients.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    beta1: optional, float
        A positive scalar value for beta_1, the exponential decay rate
        for the first moment estimates (default 0.9).
    beta2: optional, float
        A positive scalar value for beta_2, the exponential decay rate
        for the second moment estimates (default 0.999).
    eps: optional, float
        A positive scalar value for epsilon, a small constant for
        numerical stability (default 1e-8).
    weight_decay: optional, float
        Weight decay coefficient; ``None`` disables the decay.

    References
    ----------
    .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
    """

    def __init__(
        self,
        lr: Union[float, State, LearningRateScheduler],
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.beta1 = fcast(beta1)
        self.beta2 = fcast(beta2)
        self.eps = fcast(eps)
        self.m1_states = StateDictManager()
        self.m2_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the trainable states and create zero first/second moment slots for each new one."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'

        for name, state in train_states.items():
            assert isinstance(state, State), f'"{name}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(name, state):
                self.m1_states[name] = OptimState(u.math.tree_zeros_like(state.value))
                self.m2_states[name] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Apply one Adam step using ``grads`` and then advance the learning-rate schedule."""
        # Both bias corrections are folded into the learning rate; the
        # exponent is the scheduler's ``last_epoch`` value plus 2.
        lr = self.lr()
        lr = lr / (1 - self.beta1 ** (self.lr.last_epoch.value + 2))
        lr = lr * jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2))

        params, gradients, first_moments, second_moments = to_same_dict_tree(
            self.param_states, grads, self.m1_states, self.m2_states
        )

        def next_first_moment(m1, g):
            return self.beta1 * m1 + (1 - self.beta1) * g

        def next_second_moment(m2, g):
            return self.beta2 * m2 + (1 - self.beta2) * g ** 2

        def scaled_step(m1, m2):
            return lr * m1 / (jnp.sqrt(m2) + self.eps)

        first_moments = jax.tree.map(next_first_moment, first_moments, gradients)
        second_moments = jax.tree.map(next_second_moment, second_moments, gradients)
        steps = jax.tree.map(scaled_step, first_moments, second_moments)
        # The learning rate is already part of ``steps``, so ``_sgd`` gets no ``lr``.
        params = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            params,
            steps
        )

        self.param_states.assign_values(params)
        self.m1_states.assign_values(first_moments)
        self.m2_states.assign_values(second_moments)
        self.lr.step_call()
650
-
651
-
652
class LARS(_WeightDecayOptimizer):
    r"""
    Layer-wise adaptive rate scaling (LARS) optimizer [7]_.

    Layer-wise Adaptive Rate Scaling, or LARS, is a large batch
    optimization technique. There are two notable differences
    between LARS and other adaptive algorithms such as `Adam` or `RMSProp`:
    first, LARS uses a separate learning rate for each layer and not for
    each weight. And second, the magnitude of the update is controlled
    with respect to the weight norm for better control of training speed.

    .. math::

        m_{t} = \beta_{1}m_{t-1} + \left(1-\beta_{1}\right)\left(g_{t} + \lambda{x_{t}}\right) \\
        x_{t+1}^{\left(i\right)} = x_{t}^{\left(i\right)} - \eta_{t}\frac{\phi\left(|| x_{t}^{\left(i\right)} ||\right)}{|| m_{t}^{\left(i\right)} || }m_{t}^{\left(i\right)}

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    momentum: float
        coefficient used for the moving average of the gradient.
    weight_decay: float
        weight decay coefficient. ``None`` is treated as ``0`` in the update.
    tc: float
        trust coefficient eta ( < 1) for trust ratio computation.
    eps: float
        epsilon used for trust ratio computation.

    References
    ----------
    .. [7] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        momentum: float = 0.9,
        weight_decay: Optional[float] = 1e-4,
        tc: float = 1e-3,
        eps: float = 1e-5,
    ):
        # No ``weight_decay is None`` assertion here: the default is 1e-4 and
        # ``update`` uses the coefficient, so such an assertion would make
        # the optimizer impossible to construct with its own defaults.
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.momentum = fcast(momentum)
        self.tc = fcast(tc)
        self.eps = fcast(eps)
        self.momentum_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the trainable states and create a zero momentum slot for each new one."""
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """Apply one LARS step using ``grads`` and then advance the learning-rate schedule."""
        lr = self.lr()
        weight_values, grad_values, momentum_values = to_same_dict_tree(
            self.param_states, grads, self.momentum_states
        )
        # A disabled weight decay contributes nothing to the trust ratio or the update.
        weight_decay = 0. if self.weight_decay is None else self.weight_decay

        def _lars_update(pv, gv, mv):
            p_norm = jnp.linalg.norm(pv)
            g_norm = jnp.linalg.norm(gv)
            trust_ratio = self.tc * p_norm / (g_norm + weight_decay * p_norm + self.eps)
            # The boolean acts as a floor on the ratio: 1 when either norm
            # is zero, 0 otherwise.
            local_lr = lr * jnp.maximum(jnp.logical_or(p_norm == 0, g_norm == 0), trust_ratio)
            return self.momentum * mv + local_lr * (gv + weight_decay * pv)

        momentum_values = jax.tree.map(_lars_update, weight_values, grad_values, momentum_values)
        weight_values = jax.tree.map(lambda pv, mv: pv - mv, weight_values, momentum_values)
        self.param_states.assign_values(weight_values)
        self.momentum_states.assign_values(momentum_values)
        self.lr.step_call()
727
-
728
-
729
- class Adan(_WeightDecayOptimizer):
730
- r"""
731
- Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.
732
-
733
- .. math::
734
-
735
- \begin{equation}
736
- \begin{aligned}
737
- & \mathbf{m}_k=\left(1-\beta_1\right) \mathbf{m}_{k-1}+\beta_1 \mathbf{g}_k \\
738
- & \mathbf{v}_k=\left(1-\beta_2\right) \mathbf{v}_{k-1}+\beta_2\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right) \\
739
- & \mathbf{n}_k=\left(1-\beta_3\right) \mathbf{n}_{k-1}+\beta_3\left[\mathbf{g}_k+\left(1-\beta_2\right)\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right)\right]^2 \\
740
- & \boldsymbol{\eta}_k=\eta /\left(\sqrt{\mathbf{n}_k+\varepsilon}\right) \\
741
- & \boldsymbol{\theta}_{k+1}=\left(1+\lambda_k \eta\right)^{-1}\left[\boldsymbol{\theta}_k-\boldsymbol{\eta}_k \circ\left(\mathbf{m}_k+\left(1-\beta_2\right) \mathbf{v}_k\right)\right] \\
742
- \end{aligned}
743
- \end{equation}
744
-
745
- Parameters
746
- ----------
747
- lr: float, LearningRateScheduler
748
- learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3)
749
- betas : tuple
750
- Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01))
751
- eps : float
752
- The term added to the denominator to improve numerical stability. (default: 1e-8)
753
- weight_decay : float
754
- decoupled weight decay (L2 penalty) (default: 0)
755
- no_prox: bool
756
- how to perform the decoupled weight decay (default: False).
757
- It determines the update rule of parameters with weight decay.
758
- By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper:
759
-
760
- .. math::
761
- \boldsymbol{\theta}_{k+1} = ( 1+\lambda \eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}k)\right],
762
-
763
- But one also can update the parameter like Adamw:
764
-
765
- .. math::
766
- \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k).
767
-
768
- References
769
- ----------
770
- .. [1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan.
771
- “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
772
- Deep Models.” ArXiv abs/2208.06677 (2022): n. pag.
773
- """
774
-
775
- def __init__(
776
- self,
777
- lr: Union[float, LearningRateScheduler, State] = 1e-3,
778
- betas: Tuple[float, float, float] = (0.02, 0.08, 0.01),
779
- eps: float = 1e-8,
780
- weight_decay: float = 0.02,
781
- no_prox: bool = False,
782
- ):
783
- super().__init__(lr=lr, weight_decay=weight_decay)
784
-
785
- assert len(betas) == 3
786
- if eps < 0.:
787
- raise ValueError("Invalid epsilon value: {}".format(eps))
788
- if not 0.0 <= betas[0] < 1.0:
789
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
790
- if not 0.0 <= betas[1] < 1.0:
791
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
792
- if not 0.0 <= betas[2] < 1.0:
793
- raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
794
-
795
- self.betas = fcast(jnp.asarray(betas))
796
- self.eps = fcast(eps)
797
- self.no_prox = no_prox
798
- self.exp_avg_states = StateDictManager()
799
- self.exp_avg_sq_states = StateDictManager()
800
- self.exp_avg_diff_states = StateDictManager()
801
- self.pre_grad_states = StateDictManager()
802
-
803
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
804
- train_states = dict() if train_states is None else train_states
805
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
806
- for k, v in train_states.items():
807
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
808
- if self.param_states.add_unique_value(k, v):
809
- self.exp_avg_states[k] = OptimState(u.math.tree_zeros_like(v.value))
810
- self.exp_avg_sq_states[k] = OptimState(u.math.tree_zeros_like(v.value))
811
- self.exp_avg_diff_states[k] = OptimState(u.math.tree_zeros_like(v.value))
812
- self.pre_grad_states[k] = OptimState(u.math.tree_zeros_like(v.value))
813
-
814
- def update(self, grads: dict):
815
- lr = self.lr()
816
- step = self.lr.last_epoch.value + 1
817
- correct_m = 1 / (1 - (1 - self.betas[0]) ** (step + 1))
818
- correct_v = 1 / (1 - (1 - self.betas[1]) ** (step + 1))
819
- correct_n = 1 / (1 - (1 - self.betas[2]) ** (step + 1))
820
- m_values, n_values, v_values, pre_g_values, weight_values, grad_values = to_same_dict_tree(
821
- self.exp_avg_states, self.exp_avg_diff_states, self.exp_avg_sq_states, self.pre_grad_states,
822
- self.param_states, grads)
823
-
824
- def _adan_update(m, n, v, pre_g, g, p):
825
- m = m * (1 - self.betas[0]) + self.betas[0] * g
826
- gd = g - pre_g
827
- v = v * (1 - self.betas[1]) + self.betas[1] * gd
828
- n = n * (1 - self.betas[2]) + self.betas[2] * (g + (1 - self.betas[1]) * gd) ** 2
829
- weighted_step_size = lr / (jnp.sqrt(n * correct_n) + self.eps)
830
- if self.no_prox:
831
- p = (p * (1 - self.weight_decay * lr) -
832
- weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v))
833
- else:
834
- p = ((p - weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v)) /
835
- (1 + self.weight_decay * lr))
836
- return m, n, v, p
837
-
838
- m_values, n_values, v_values, weight_values = jax.tree.map(
839
- _adan_update, m_values, n_values, v_values, pre_g_values, grad_values, weight_values)
840
- self.exp_avg_states.assign_values(m_values)
841
- self.exp_avg_diff_states.assign_values(n_values)
842
- self.exp_avg_sq_states.assign_values(v_values)
843
- self.param_states.assign_values(weight_values)
844
- self.lr.step_call()
845
-
846
-
847
- class AdamW(_WeightDecayOptimizer):
848
- r"""
849
- Adam with weight decay regularization [1]_.
850
-
851
- AdamW uses weight decay to regularize learning towards small weights, as
852
- this leads to better generalization. In SGD you can also use L2 regularization
853
- to implement this as an additive loss term, however L2 regularization
854
- does not behave as intended for adaptive gradient algorithms such as Adam.
855
-
856
- .. math::
857
-
858
- \begin{aligned}
859
- &\rule{110mm}{0.4pt} \\
860
- &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
861
- \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
862
- \: \epsilon \text{ (epsilon)} \\
863
- &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
864
- \: \textit{maximize} \\
865
- &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
866
- \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
867
- &\rule{110mm}{0.4pt} \\
868
- &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
869
-
870
- &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
871
- &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
872
- &\hspace{5mm}\textbf{else} \\
873
- &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
874
- &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
875
- &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
876
- &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
877
- &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
878
- &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
879
- &\hspace{5mm}\textbf{if} \: amsgrad \\
880
- &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
881
- \widehat{v_t}) \\
882
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
883
- \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
884
- &\hspace{5mm}\textbf{else} \\
885
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
886
- \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
887
- &\rule{110mm}{0.4pt} \\[-1.ex]
888
- &\bf{return} \: \theta_t \\[-1.ex]
889
- &\rule{110mm}{0.4pt} \\[-1.ex]
890
- \end{aligned}
891
-
892
-
893
- Parameters
894
- ----------
895
- lr: float, LearningRateScheduler
896
- learning rate.
897
- beta1: optional, float
898
- A positive scalar value for beta_1, the exponential decay rate
899
- for the first moment estimates. Generally close to 1.
900
- beta2: optional, float
901
- A positive scalar value for beta_2, the exponential decay rate
902
- for the second moment estimates. Generally close to 1.
903
- eps: optional, float
904
- A positive scalar value for epsilon, a small constant for
905
- numerical stability.
906
- weight_decay: float
907
- Strength of the weight decay regularization. Note that this
908
- weight decay is multiplied with the learning rate.
909
- amsgrad: bool
910
- whether to use the AMSGrad variant of this algorithm
911
- from the paper `On the Convergence of Adam and Beyond`.
912
-
913
- References
914
- ----------
915
- .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019).
916
-
917
- """
918
-
919
- def __init__(
920
- self,
921
- lr: Union[float, LearningRateScheduler, State],
922
- beta1: float = 0.9,
923
- beta2: float = 0.999,
924
- eps: float = 1e-8,
925
- weight_decay: float = 1e-2,
926
- amsgrad: bool = False,
927
- ):
928
- super().__init__(lr=lr, weight_decay=weight_decay)
929
-
930
- if eps < 0.:
931
- raise ValueError("Invalid epsilon value: {}".format(eps))
932
- if not 0.0 <= beta1 < 1.0:
933
- raise ValueError("Invalid beta parameter at index 0: {}".format(beta1))
934
- if not 0.0 <= beta2 < 1.0:
935
- raise ValueError("Invalid beta parameter at index 1: {}".format(beta2))
936
- if weight_decay < 0.:
937
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
938
-
939
- self.beta1 = fcast(beta1)
940
- self.beta2 = fcast(beta2)
941
- self.eps = fcast(eps)
942
- self.amsgrad = amsgrad
943
- self.m1_states = StateDictManager()
944
- self.m2_states = StateDictManager()
945
- if self.amsgrad:
946
- self.vmax_states = StateDictManager()
947
-
948
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
949
- train_states = dict() if train_states is None else train_states
950
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
951
- for k, v in train_states.items():
952
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
953
- if self.param_states.add_unique_value(k, v):
954
- self.m1_states[k] = OptimState(u.math.tree_zeros_like(v.value))
955
- self.m2_states[k] = OptimState(u.math.tree_zeros_like(v.value))
956
- if self.amsgrad:
957
- self.vmax_states[k] = OptimState(u.math.tree_zeros_like(v.value))
958
-
959
- def update(self, grads: dict):
960
- lr_old = self.lr()
961
- step = self.lr.last_epoch.value + 2
962
- bias_correction1 = 1 - self.beta1 ** step
963
- bias_correction2 = 1 - self.beta2 ** step
964
- lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1
965
-
966
- def _adamw_update(p, m, v, g, vmax=None):
967
- if self.weight_decay != 0:
968
- p *= (1 - lr_old * self.weight_decay)
969
- m = self.beta1 * m + (1 - self.beta1) * g
970
- v = self.beta2 * v + (1 - self.beta2) * g ** 2
971
- if self.amsgrad:
972
- vmax = jnp.maximum(vmax, v)
973
- denom = jnp.sqrt(vmax) + self.eps
974
- return p - lr * m / denom, m, v, vmax
975
- else:
976
- denom = jnp.sqrt(v.value) + self.eps
977
- return p - lr * m / denom, m, v
978
-
979
- if self.amsgrad:
980
- weight_values, m1_values, m2_values, vmax_values = to_same_dict_tree(
981
- self.param_states, self.m1_states, self.m2_states, self.vmax_states)
982
- weight_values, m1_values, m2_values, vmax_values = jax.tree.map(
983
- _adamw_update, weight_values, m1_values, m2_values, grads, vmax_values)
984
- self.vmax_states.assign_values(vmax_values)
985
- else:
986
- weight_values, m1_values, m2_values = to_same_dict_tree(self.param_states, self.m1_states, self.m2_states)
987
- weight_values, m1_values, m2_values = jax.tree.map(
988
- _adamw_update, weight_values, m1_values, m2_values, grads)
989
- self.param_states.assign_values(weight_values)
990
- self.m1_states.assign_values(m1_values)
991
- self.m2_states.assign_values(m2_values)
992
- self.lr.step_call()
993
-
994
-
995
- class SM3(_WeightDecayOptimizer):
996
- """
997
- SM3 algorithm [1]_.
998
-
999
- The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method'
1000
- (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar
1001
- to Adam and Adagrad with greatly reduced memory usage for history tensors.
1002
- For an `n x m` matrix, Adam and Adagrad use `O(nm)` memory for history
1003
- tensors, while SM3 uses `O(n+m)` due to the chosen cover. In general, a tensor
1004
- of shape `(n_1, n_2, ..., n_k)` optimized using Adam will use `O(prod n_i)`
1005
- memory for storage tensors, while the optimization using SM3 will use
1006
- `O(sum n_i)` memory. Despite storing fewer parameters, this optimization
1007
- algorithm manages to be comparably effective.
1008
-
1009
- This advantage drastically shrinks when `momentum > 0`. The momentum is
1010
- tracked using a tensor of the same shape as the tensor being optimized. With
1011
- momentum, SM3 will use just over half as much memory as Adam, and a bit more
1012
- than Adagrad.
1013
-
1014
- Parameters
1015
- ----------
1016
- lr: float, LearningRateScheduler
1017
- learning rate.
1018
- momentum: float
1019
- coefficient used to scale prior updates
1020
- before adding. This drastically increases memory usage if
1021
- `momentum > 0.0`. (default: 0.0)
1022
- beta: float
1023
- coefficient used for exponential moving averages (default: 0.0)
1024
- eps: float
1025
- Term added to square-root in denominator to
1026
- improve numerical stability (default: 1e-30).
1027
-
1028
- References
1029
- ----------
1030
- .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019).
1031
-
1032
- """
1033
-
1034
- def __init__(
1035
- self,
1036
- lr: Union[float, LearningRateScheduler, State],
1037
- beta: float = 0.,
1038
- momentum: float = 0.,
1039
- eps: float = 1e-30,
1040
- weight_decay: Optional[float] = None,
1041
- ):
1042
- super().__init__(lr=lr, weight_decay=weight_decay)
1043
-
1044
- if not 0.0 <= momentum < 1.0:
1045
- raise ValueError("Invalid momentum: {0}".format(momentum))
1046
- if not 0.0 <= beta < 1.0:
1047
- raise ValueError("Invalid beta: {0}".format(beta))
1048
- if not 0.0 <= eps:
1049
- raise ValueError("Invalid eps: {0}".format(eps))
1050
-
1051
- self.eps = fcast(eps)
1052
- self.beta = fcast(beta)
1053
- self.momentum = fcast(momentum)
1054
- self.memory_states = StateDictManager()
1055
-
1056
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
1057
- train_states = dict() if train_states is None else train_states
1058
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
1059
- for k, v in train_states.items():
1060
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
1061
- if self.param_states.add_unique_value(k, v):
1062
- rank, ndim, dtype = v.value.shape, v.value.ndim, v.value.dtype
1063
- for i in range(ndim):
1064
- shape = [1] * ndim
1065
- shape[i] = rank[i]
1066
- self.memory_states[f'{k}_m{i}'] = State(jnp.zeros(shape, dtype=dtype))
1067
- if self.momentum > 0.:
1068
- self.memory_states[f'{k}_mbuffer'] = State(jnp.zeros_like(v.value))
1069
-
1070
- def update(self, grads: dict):
1071
- lr = self.lr()
1072
-
1073
- for k, p in self.param_states.items():
1074
- g = grads[k]
1075
- ndim = p.ndim
1076
- update = self.memory_states[f'{k}_m0'].value
1077
- for i in range(1, ndim):
1078
- update = jnp.minimum(update, self.memory_states[f'{k}_m{i}'].value)
1079
- if self.beta > 0.:
1080
- update *= self.beta
1081
- update += g * g * (1 - self.beta)
1082
- # Computes max along all dimensions except the given dim.
1083
- # If tensor is a scalar, it returns tensor.
1084
- for i in range(ndim):
1085
- result = update
1086
- for j in range(ndim):
1087
- if i != j:
1088
- result = jnp.maximum(result, axis=j, keepdim=True)
1089
- acc = self.memory_states[f'{k}_m{i}'].value
1090
- if self.beta > 0.:
1091
- acc.value = jnp.maximum(acc, result)
1092
- else:
1093
- # No need to compare - nu_max is bigger because of grad ** 2
1094
- acc.value = result
1095
- update = g / jnp.sqrt(update + self.eps)
1096
- if self.momentum > 0.:
1097
- m_buffer = self.memory_states[f'{k}_mbuffer'].value
1098
- update = update * (1. - self.momentum) + m_buffer * self.momentum
1099
- m_buffer.value = update
1100
- if self.weight_decay is None:
1101
- p.value -= lr * update
1102
- else:
1103
- p.value = (1 - self.weight_decay) * p - lr * update
1104
- self.lr.step_call()
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ import functools
19
+ from typing import Union, Dict, Optional, Tuple, Any, TypeVar
20
+
21
+ import brainunit as u
22
+ import jax
23
+ import jax.numpy as jnp
24
+
25
+ from brainstate import environ
26
+ from brainstate._state import State, LongTermState, StateDictManager
27
+ from ._base import Optimizer
28
+ from ._lr_scheduler import make_schedule, LearningRateScheduler
29
+
30
+ __all__ = [
31
+ 'to_same_dict_tree',
32
+
33
+ # new class of brainstate.State for optimizer
34
+ 'OptimState',
35
+
36
+ # commonly used optimizers
37
+ 'SGDOptimizer',
38
+ 'SGD',
39
+ 'Momentum',
40
+ 'MomentumNesterov',
41
+ 'Adagrad',
42
+ 'Adadelta',
43
+ 'RMSProp',
44
+ 'Adam',
45
+ 'LARS',
46
+ 'Adan',
47
+ 'AdamW',
48
+ ]
49
+
50
+ T = TypeVar('T')
51
+
52
+
53
+ def cast(value: Any, dtype: Any) -> jax.Array:
54
+ if isinstance(value, jax.Array):
55
+ return value.astype(dtype)
56
+ return jnp.asarray(value, dtype=dtype)
57
+
58
+
59
+ def fcast(value: T, dtype: Any = None) -> jax.Array:
60
+ return cast(value, dtype=dtype or environ.dftype())
61
+
62
+
63
+ def _to_dict_value(old_dict: Dict) -> Dict:
64
+ new_dict = dict()
65
+ for k, v in old_dict.items():
66
+ if isinstance(v, State):
67
+ new_dict[k] = v.value
68
+ else:
69
+ new_dict[k] = v
70
+ return new_dict
71
+
72
+
73
+ def to_same_dict_tree(*dicts: Dict):
74
+ """
75
+ Convert multiple dictionaries to the same tree structure.
76
+
77
+ Parameters
78
+ ----------
79
+ *dicts: dict
80
+ The dictionaries to be converted.
81
+
82
+ Returns
83
+ -------
84
+ dict
85
+ The converted dictionary.
86
+ """
87
+ if len(dicts):
88
+ # all keys
89
+ all_keys = tuple(set(d.keys()) for d in dicts)
90
+ for keys in all_keys[1:]:
91
+ if len(all_keys[0].difference(keys)) > 0:
92
+ raise ValueError('Dictionary does not match.')
93
+
94
+ # flatten to normal python dict
95
+ r = [_to_dict_value(d) for d in dicts]
96
+
97
+ if len(dicts) == 1:
98
+ return r[0]
99
+ else:
100
+ return tuple(r)
101
+
102
+
103
+ def _sgd(prev_weight, gradient, weight_decay, lr=None):
104
+ """
105
+ The update function for SGD learning.
106
+
107
+ Parameters
108
+ ----------
109
+ prev_weight: jax.Array
110
+ The previous weight.
111
+ gradient: jax.Array
112
+ The gradient.
113
+ weight_decay: float
114
+ The weight decay.
115
+ lr: float
116
+ The learning rate.
117
+ """
118
+ if weight_decay is None:
119
+ if lr is None:
120
+ return prev_weight - gradient
121
+ else:
122
+ return prev_weight - lr * gradient
123
+ else:
124
+ if lr is None:
125
+ return (1 - weight_decay) * prev_weight - gradient
126
+ else:
127
+ return (1 - weight_decay) * prev_weight - lr * gradient
128
+
129
+
130
+ class OptimState(LongTermState):
131
+ """
132
+ The state for optimizer.
133
+ """
134
+ pass
135
+
136
+
137
+ class SGDOptimizer(Optimizer):
138
+ """
139
+ Base Optimizer Class.
140
+
141
+ Parameters
142
+ ----------
143
+ lr: float, LearningRateScheduler
144
+ learning rate.
145
+ """
146
+
147
+ lr: LearningRateScheduler # learning rate
148
+
149
+ def __init__(
150
+ self, lr: Union[float, LearningRateScheduler, State],
151
+ ):
152
+ super().__init__()
153
+ self.lr: LearningRateScheduler = make_schedule(lr)
154
+
155
+
156
+ class _WeightDecayOptimizer(SGDOptimizer):
157
+ def __init__(
158
+ self,
159
+ lr: Union[float, LearningRateScheduler, State],
160
+ weight_decay: Optional[float] = None,
161
+ ):
162
+ super().__init__(lr=lr)
163
+ self.lr: LearningRateScheduler = make_schedule(lr)
164
+ assert weight_decay is None or 0. <= weight_decay <= 1., 'weight_decay must be in [0, 1].'
165
+ self.weight_decay = (fcast(weight_decay) if weight_decay is not None else None)
166
+
167
+
168
+ class SGD(_WeightDecayOptimizer):
169
+ r"""
170
+ Stochastic gradient descent optimizer.
171
+
172
+ SGD performs a parameter update for training examples :math:`x` and label
173
+ :math:`y`:
174
+
175
+ .. math::
176
+
177
+ \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y)
178
+
179
+
180
+ Parameters
181
+ ----------
182
+ lr: float, LearningRateScheduler
183
+ learning rate.
184
+
185
+ """
186
+
187
+ def __init__(
188
+ self,
189
+ lr: Union[float, LearningRateScheduler, State],
190
+ weight_decay: Optional[float] = None,
191
+ ):
192
+ super().__init__(lr=lr, weight_decay=weight_decay)
193
+
194
+ def register_trainable_weights(self, states: Optional[Dict[str, State]] = None):
195
+ states = dict() if states is None else states
196
+ assert isinstance(states, dict), '"states" must be a dict of brainstate.State.'
197
+ for k, v in states.items():
198
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
199
+ self.param_states.add_unique_value(k, v)
200
+
201
+ def update(self, grads: dict):
202
+ lr = self.lr()
203
+ weight_values, grad_values = to_same_dict_tree(self.param_states, grads)
204
+ updates = jax.tree.map(
205
+ functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
206
+ weight_values,
207
+ grad_values
208
+ )
209
+ self.param_states.assign_values(updates)
210
+ self.lr.step_call()
211
+
212
+
213
+ class Momentum(_WeightDecayOptimizer):
214
+ r"""
215
+ Momentum optimizer.
216
+
217
+ Momentum [1]_ is a method that helps accelerate SGD in the relevant direction
218
+ and dampens oscillations. It does this by adding a fraction :math:`\gamma`
219
+ of the update vector of the past time step to the current update vector:
220
+
221
+ .. math::
222
+
223
+ \begin{align}
224
+ \begin{split}
225
+ v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta) \\
226
+ \theta &= \theta - v_t
227
+ \end{split}
228
+ \end{align}
229
+
230
+ Parameters
231
+ ----------
232
+ lr: float, LearningRateScheduler
233
+ learning rate.
234
+
235
+ References
236
+ ----------
237
+
238
+ .. [1] Qian, N. (1999). On the momentum term in gradient descent learning
239
+ algorithms. Neural Networks : The Official Journal of the International
240
+ Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6
241
+
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ lr: Union[float, LearningRateScheduler, State],
247
+ momentum: float = 0.9,
248
+ weight_decay: Optional[float] = None,
249
+ ):
250
+ super().__init__(lr=lr, weight_decay=weight_decay)
251
+ self.momentum = fcast(momentum)
252
+ self.momentum_states = StateDictManager()
253
+
254
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
255
+ train_states = dict() if train_states is None else train_states
256
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
257
+
258
+ for k, v in train_states.items():
259
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
260
+ if self.param_states.add_unique_value(k, v):
261
+ self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))
262
+
263
+ def update(self, grads: dict):
264
+ lr = self.lr()
265
+ states_values, grad_values, momentum_values = to_same_dict_tree(
266
+ self.param_states, grads, self.momentum_states
267
+ )
268
+ momentum_values = jax.tree.map(
269
+ lambda vv, gg: self.momentum * vv - lr * gg,
270
+ momentum_values,
271
+ grad_values
272
+ )
273
+ new_weight_values = jax.tree.map(
274
+ functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
275
+ states_values,
276
+ momentum_values
277
+ )
278
+ self.momentum_states.assign_values(momentum_values)
279
+ self.param_states.assign_values(new_weight_values)
280
+ self.lr.step_call()
281
+
282
+
283
+ class MomentumNesterov(_WeightDecayOptimizer):
284
+ r"""
285
+ Nesterov accelerated gradient optimizer [2]_.
286
+
287
+ .. math::
288
+
289
+ \begin{align}
290
+ \begin{split}
291
+ v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta - \gamma v_{t-1} ) \\
292
+ \theta &= \theta - v_t
293
+ \end{split}
294
+ \end{align}
295
+
296
+ Parameters
297
+ ----------
298
+ lr: float, LearningRateScheduler
299
+ learning rate.
300
+
301
+ References
302
+ ----------
303
+ .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547.
304
+
305
+ """
306
+
307
+ def __init__(
308
+ self,
309
+ lr: Union[float, LearningRateScheduler, State],
310
+ weight_decay: Optional[float] = None,
311
+ momentum: float = 0.9,
312
+ ):
313
+ super().__init__(lr=lr, weight_decay=weight_decay)
314
+
315
+ self.momentum = fcast(momentum)
316
+ self.momentum_states = StateDictManager()
317
+
318
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
319
+ train_states = dict() if train_states is None else train_states
320
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
321
+ for k, v in train_states.items():
322
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
323
+ if self.param_states.add_unique_value(k, v):
324
+ self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))
325
+
326
+ def update(self, grads: dict):
327
+ lr = self.lr()
328
+ states_values, grad_values, momentum_values = to_same_dict_tree(self.param_states, grads, self.momentum_states)
329
+ momentum_values = jax.tree.map(
330
+ lambda mv, gv: self.momentum * mv - lr * gv,
331
+ momentum_values,
332
+ grad_values
333
+ )
334
+ weight_values = jax.tree.map(
335
+ functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
336
+ states_values,
337
+ momentum_values
338
+ )
339
+ self.param_states.assign_values(weight_values)
340
+ self.momentum_states.assign_values(momentum_values)
341
+ self.lr.step_call()
342
+
343
+
344
+ class Adagrad(_WeightDecayOptimizer):
345
+ r"""
346
+ Optimizer that implements the Adagrad algorithm.
347
+
348
+ Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are
349
+ adapted relative to how frequently a parameter gets updated during training.
350
+ The more updates a parameter receives, the smaller the updates.
351
+
352
+ .. math::
353
+
354
+ \theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t}
355
+
356
+ where :math:`G(t)` contains the sum of the squares of the past gradients
357
+
358
+ One of Adagrad's main benefits is that it eliminates the need to manually tune
359
+ the learning rate. Most implementations use a default value of 0.01 and leave it at that.
360
+ Adagrad's main weakness is its accumulation of the squared gradients in the denominator:
361
+ Since every added term is positive, the accumulated sum keeps growing during training.
362
+ This in turn causes the learning rate to shrink and eventually become infinitesimally
363
+ small, at which point the algorithm is no longer able to acquire additional knowledge.
364
+
365
+ Parameters
366
+ ----------
367
+ lr: float, LearningRateScheduler
368
+ learning rate.
369
+
370
+ References
371
+ ----------
372
+ .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html
373
+
374
+ """
375
+
376
+ def __init__(
377
+ self,
378
+ lr: Union[float, LearningRateScheduler, State],
379
+ weight_decay: Optional[float] = None,
380
+ epsilon: float = 1e-6,
381
+ ):
382
+ super().__init__(lr=lr, weight_decay=weight_decay)
383
+ self.epsilon = fcast(epsilon)
384
+ self.cache_states = StateDictManager()
385
+
386
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
387
+ train_states = dict() if train_states is None else train_states
388
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
389
+ for k, v in train_states.items():
390
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
391
+ if self.param_states.add_unique_value(k, v):
392
+ self.cache_states[k] = OptimState(u.math.tree_zeros_like(v.value))
393
+
394
+ def update(self, grads: dict):
395
+ lr = self.lr()
396
+ cache_values, grad_values, weight_values = to_same_dict_tree(self.cache_states, grads, self.param_states)
397
+ cache_values = jax.tree.map(
398
+ lambda cv, gv: cv + gv ** 2,
399
+ cache_values,
400
+ grad_values
401
+ )
402
+ updates = jax.tree.map(
403
+ lambda cv, gv: lr * gv / jnp.sqrt(cv + self.epsilon),
404
+ cache_values,
405
+ grad_values
406
+ )
407
+ weight_values = jax.tree.map(
408
+ functools.partial(_sgd, weight_decay=self.weight_decay),
409
+ weight_values,
410
+ updates
411
+ )
412
+ self.cache_states.assign_values(cache_values)
413
+ self.param_states.assign_values(weight_values)
414
+ self.lr.step_call()
415
+
416
+
417
class Adadelta(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the Adadelta algorithm.

    Adadelta [4]_ is a stochastic gradient descent method with a
    per-dimension adaptive learning rate. It was designed to address two
    drawbacks:

    - The continual decay of learning rates throughout training.
    - The need for a manually selected global learning rate.

    Adadelta is a more robust extension of Adagrad that adapts learning rates
    based on a moving window of gradient updates, instead of accumulating all
    past gradients. This way, Adadelta continues learning even when many
    updates have been done. Compared to Adagrad, the original version of
    Adadelta does not require an initial learning rate.

    .. math::

        \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\
        \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\
        \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\
        \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t.

    :math:`\rho` should be between 0 and 1. A value close to 1 decays the
    moving average slowly and a value close to 0 decays it fast.

    :math:`\rho` = 0.95 and :math:`\epsilon` = 1e-6 are suggested in the paper
    and reported to work for multiple datasets (MNIST, speech).

    In the paper no learning rate is considered (so learning_rate=1.0).
    ``epsilon`` is important for the very first update (so the numerator does
    not become 0).

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate. It is stored by the parent class and its schedule is
        stepped on every ``update``, but the step computed in ``update`` is
        not multiplied by it.
    weight_decay: float, optional
        weight decay coefficient forwarded to ``_sgd``.
    epsilon: float
        constant added inside both square roots of the update rule.
    rho: float
        decay rate of the two moving averages.

    References
    ----------
    .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State] = 0.01,
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        rho: float = 0.95,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.epsilon = fcast(epsilon)
        self.rho = fcast(rho)
        # Moving average of squared gradients, one entry per parameter.
        self.cache_states = StateDictManager()
        # Moving average of squared updates, one entry per parameter.
        self.delta_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the states to optimize and allocate both moving averages.

        ``None`` is treated as an empty mapping. An ``AssertionError`` is
        raised when ``train_states`` is not a dict or holds a non-``State``.
        """
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, state in train_states.items():
            assert isinstance(state, State), f'"{k}" must be an instance of brainstate.State.'
            if not self.param_states.add_unique_value(k, state):
                continue
            self.cache_states[k] = OptimState(u.math.tree_zeros_like(state.value))
            self.delta_states[k] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """
        Apply one Adadelta step using ``grads`` (keyed like the registered states).
        """
        rho, eps = self.rho, self.epsilon
        weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
            self.param_states, grads, self.cache_states, self.delta_states)

        def _decayed_square(avg, x):
            # Exponential moving average of ``x ** 2``.
            return rho * avg + (1 - rho) * x ** 2

        def _rescale(gv, dv, cv):
            # Ratio of the RMS of past updates to the RMS of gradients.
            return gv * jnp.sqrt(dv + eps) / jnp.sqrt(cv + eps)

        cache_values = jax.tree.map(_decayed_square, cache_values, grad_values)
        # The step is built from the *previous* update average ...
        updates = jax.tree.map(_rescale, grad_values, delta_values, cache_values)
        # ... which is refreshed only afterwards.
        delta_values = jax.tree.map(_decayed_square, delta_values, updates)
        apply_step = functools.partial(_sgd, weight_decay=self.weight_decay)
        weight_values = jax.tree.map(apply_step, weight_values, updates)

        self.param_states.assign_values(weight_values)
        self.delta_states.assign_values(delta_values)
        self.cache_states.assign_values(cache_values)
        self.lr.step_call()
498
+
499
+
500
class RMSProp(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the RMSprop algorithm.

    RMSprop [5]_ and Adadelta have both been developed independently around
    the same time, stemming from the need to resolve Adagrad's radically
    diminishing learning rates.

    The gist of RMSprop is to:

    - Maintain a moving (discounted) average of the square of gradients
    - Divide the gradient by the root of this average

    .. math::

        \begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2\\
        p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split}

    The centered version additionally maintains a moving average of the
    gradients, and uses that average to estimate the variance. Only the
    non-centered rule shown above is computed by ``update``.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    weight_decay: float, optional
        weight decay coefficient forwarded to ``_sgd``.
    epsilon: float
        constant added to the moving average inside the square root.
    rho: float
        decay rate of the moving average of squared gradients.

    References
    ----------
    .. [5] Tieleman, T. and Hinton, G. (2012):
           Neural Networks for Machine Learning, Lecture 6.5 - rmsprop.
           Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        rho: float = 0.9,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.epsilon = fcast(epsilon)
        self.rho = fcast(rho)
        # Moving average of squared gradients, one entry per parameter.
        self.cache_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the states to optimize and allocate their moving averages.

        ``None`` is treated as an empty mapping. An ``AssertionError`` is
        raised when ``train_states`` is not a dict or holds a non-``State``.
        """
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, state in train_states.items():
            assert isinstance(state, State), f'"{k}" must be an instance of brainstate.State.'
            if not self.param_states.add_unique_value(k, state):
                continue
            self.cache_states[k] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """
        Apply one RMSprop step using ``grads`` (keyed like the registered states).
        """
        lr = self.lr()
        rho, eps = self.rho, self.epsilon
        weight_values, grad_values, cache_values = to_same_dict_tree(self.param_states, grads, self.cache_states)

        def _decayed_square(cv, gv):
            # Discounted average of squared gradients.
            return rho * cv + (1 - rho) * gv ** 2

        def _scaled_step(gv, cv):
            # Gradient divided by the root of the refreshed average.
            return lr * gv / jnp.sqrt(cv + eps)

        cache_values = jax.tree.map(_decayed_square, cache_values, grad_values)
        steps = jax.tree.map(_scaled_step, grad_values, cache_values)
        apply_step = functools.partial(_sgd, weight_decay=self.weight_decay)
        weight_values = jax.tree.map(apply_step, weight_values, steps)

        self.param_states.assign_values(weight_values)
        self.cache_states.assign_values(cache_values)
        self.lr.step_call()
564
+
565
+
566
class Adam(_WeightDecayOptimizer):
    """
    Optimizer that implements the Adam algorithm.

    Adam [6]_ - a stochastic gradient descent method (SGD) that computes
    individual adaptive learning rates for different parameters from estimates of
    first- and second-order moments of the gradients.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    beta1: optional, float
        A positive scalar value for beta_1, the exponential decay rate
        for the first moment estimates (default 0.9).
    beta2: optional, float
        A positive scalar value for beta_2, the exponential decay rate
        for the second moment estimates (default 0.999).
    eps: optional, float
        A positive scalar value for epsilon, a small constant for
        numerical stability (default 1e-8).
    weight_decay: optional, float
        Weight decay coefficient forwarded to ``_sgd`` (default None).

    References
    ----------
    .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
    """

    def __init__(
        self,
        lr: Union[float, State, LearningRateScheduler],
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        self.beta1 = fcast(beta1)
        self.beta2 = fcast(beta2)
        self.eps = fcast(eps)
        # First- and second-moment estimates, one entry per parameter.
        self.m1_states = StateDictManager()
        self.m2_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the states to optimize and allocate both moment estimates.

        ``None`` is treated as an empty mapping. An ``AssertionError`` is
        raised when ``train_states`` is not a dict or holds a non-``State``.
        """
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'

        for k, state in train_states.items():
            assert isinstance(state, State), f'"{k}" must be an instance of brainstate.State.'
            if not self.param_states.add_unique_value(k, state):
                continue
            self.m1_states[k] = OptimState(u.math.tree_zeros_like(state.value))
            self.m2_states[k] = OptimState(u.math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """
        Apply one Adam step using ``grads`` (keyed like the registered states).
        """
        beta1, beta2, eps = self.beta1, self.beta2, self.eps

        # Both bias corrections are folded into the step size.
        lr = self.lr()
        # NOTE(review): the exponent is ``last_epoch + 2``; this presumably
        # assumes the scheduler counter starts at -1 -- confirm.
        step = self.lr.last_epoch.value + 2
        lr = lr / (1 - beta1 ** step)
        lr = lr * jnp.sqrt(1 - beta2 ** step)

        weight_values, grad_values, m1_values, m2_values = to_same_dict_tree(
            self.param_states, grads, self.m1_states, self.m2_states
        )

        def _first_moment(m1, gv):
            return beta1 * m1 + (1 - beta1) * gv

        def _second_moment(m2, gv):
            return beta2 * m2 + (1 - beta2) * gv ** 2

        def _scaled_step(m1, m2):
            return lr * m1 / (jnp.sqrt(m2) + eps)

        m1_values = jax.tree.map(_first_moment, m1_values, grad_values)
        m2_values = jax.tree.map(_second_moment, m2_values, grad_values)
        steps = jax.tree.map(_scaled_step, m1_values, m2_values)
        apply_step = functools.partial(_sgd, weight_decay=self.weight_decay)
        weight_values = jax.tree.map(apply_step, weight_values, steps)

        self.param_states.assign_values(weight_values)
        self.m1_states.assign_values(m1_values)
        self.m2_states.assign_values(m2_values)
        self.lr.step_call()
650
+
651
+
652
class LARS(_WeightDecayOptimizer):
    r"""
    Layer-wise adaptive rate scaling (LARS) optimizer [1]_.

    Layer-wise Adaptive Rate Scaling, or LARS, is a large batch
    optimization technique. There are two notable differences
    between LARS and other adaptive algorithms such as `Adam` or `RMSProp`:
    first, LARS uses a separate learning rate for each layer and not for
    each weight. And second, the magnitude of the update is controlled
    with respect to the weight norm for better control of training speed.

    .. math::

        m_{t} = \beta_{1}m_{t-1} + \left(1-\beta_{1}\right)\left(g_{t} + \lambda{x_{t}}\right) \\
        x_{t+1}^{\left(i\right)} = x_{t}^{\left(i\right)} - \eta_{t}\frac{\phi\left(|| x_{t}^{\left(i\right)} ||\right)}{|| m_{t}^{\left(i\right)} || }m_{t}^{\left(i\right)}

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    momentum: float
        coefficient used for the moving average of the gradient.
    weight_decay: float
        weight decay coefficient. It enters both the trust ratio and the
        gradient as an L2 term; ``None`` is treated as ``0``.
    tc: float
        trust coefficient eta ( < 1) for trust ratio computation.
    eps: float
        epsilon used for trust ratio computation.

    References
    ----------
    .. [1] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        tc: float = 1e-3,
        eps: float = 1e-5,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)
        # No ``weight_decay is None`` assertion here: ``update`` needs the
        # coefficient, and such a check contradicts the non-None default.

        self.momentum = fcast(momentum)
        self.tc = fcast(tc)
        self.eps = fcast(eps)
        # Momentum buffer, one entry per parameter.
        self.momentum_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the states to optimize and allocate their momentum buffers.

        ``None`` is treated as an empty mapping. An ``AssertionError`` is
        raised when ``train_states`` is not a dict or holds a non-``State``.
        """
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """
        Apply one LARS step using ``grads`` (keyed like the registered states).
        """
        lr = self.lr()
        # A missing coefficient simply disables the L2 term.
        weight_decay = 0. if self.weight_decay is None else self.weight_decay
        weight_values, grad_values, momentum_values = to_same_dict_tree(
            self.param_states, grads, self.momentum_states)

        def _lars_update(pv, gv, mv):
            p_norm = jnp.linalg.norm(pv)
            g_norm = jnp.linalg.norm(gv)
            trust_ratio = self.tc * p_norm / (g_norm + weight_decay * p_norm + self.eps)
            # When either norm is zero the boolean acts as a floor of 1, so
            # the layer never ends up with a smaller ratio than 1.
            local_lr = lr * jnp.maximum(jnp.logical_or(p_norm == 0, g_norm == 0), trust_ratio)
            return self.momentum * mv + local_lr * (gv + weight_decay * pv)

        momentum_values = jax.tree.map(_lars_update, weight_values, grad_values, momentum_values)
        weight_values = jax.tree.map(lambda pv, mv: pv - mv, weight_values, momentum_values)
        self.param_states.assign_values(weight_values)
        self.momentum_states.assign_values(momentum_values)
        self.lr.step_call()
727
+
728
+
729
class Adan(_WeightDecayOptimizer):
    r"""
    Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.

    .. math::

        \begin{equation}
        \begin{aligned}
        & \mathbf{m}_k=\left(1-\beta_1\right) \mathbf{m}_{k-1}+\beta_1 \mathbf{g}_k \\
        & \mathbf{v}_k=\left(1-\beta_2\right) \mathbf{v}_{k-1}+\beta_2\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right) \\
        & \mathbf{n}_k=\left(1-\beta_3\right) \mathbf{n}_{k-1}+\beta_3\left[\mathbf{g}_k+\left(1-\beta_2\right)\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right)\right]^2 \\
        & \boldsymbol{\eta}_k=\eta /\left(\sqrt{\mathbf{n}_k+\varepsilon}\right) \\
        & \boldsymbol{\theta}_{k+1}=\left(1+\lambda_k \eta\right)^{-1}\left[\boldsymbol{\theta}_k-\boldsymbol{\eta}_k \circ\left(\mathbf{m}_k+\left(1-\beta_2\right) \mathbf{v}_k\right)\right] \\
        \end{aligned}
        \end{equation}

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3)
    betas : tuple
        Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01))
    eps : float
        The term added to the denominator to improve numerical stability. (default: 1e-8)
    weight_decay : float
        decoupled weight decay (L2 penalty) (default: 0.02). ``None`` is
        treated as ``0``.
    no_prox: bool
        how to perform the decoupled weight decay (default: False).
        It determines the update rule of parameters with weight decay.
        By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper:

        .. math::

            \boldsymbol{\theta}_{k+1} = ( 1+\lambda \eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k)\right],

        But one also can update the parameter like Adamw:

        .. math::

            \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k).

    References
    ----------
    .. [1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan.
           “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
           Deep Models.” ArXiv abs/2208.06677 (2022): n. pag.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State] = 1e-3,
        betas: Tuple[float, float, float] = (0.02, 0.08, 0.01),
        eps: float = 1e-8,
        weight_decay: float = 0.02,
        no_prox: bool = False,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        assert len(betas) == 3
        if eps < 0.:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))

        self.betas = fcast(jnp.asarray(betas))
        self.eps = fcast(eps)
        self.no_prox = no_prox
        self.exp_avg_states = StateDictManager()       # m: average of gradients
        self.exp_avg_sq_states = StateDictManager()    # n: average of squared terms
        self.exp_avg_diff_states = StateDictManager()  # v: average of gradient differences
        self.pre_grad_states = StateDictManager()      # gradient of the previous step

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the states to optimize and allocate the four per-parameter buffers.

        ``None`` is treated as an empty mapping. An ``AssertionError`` is
        raised when ``train_states`` is not a dict or holds a non-``State``.
        """
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.exp_avg_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                self.exp_avg_sq_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                self.exp_avg_diff_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                self.pre_grad_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """
        Apply one Adan step using ``grads`` (keyed like the registered states).

        Every buffer is computed with its own ``jax.tree.map`` call, so each
        result is a tree shaped like the parameters. The current gradient is
        stored afterwards so the next call can form the gradient difference.
        """
        lr = self.lr()
        step = self.lr.last_epoch.value + 1
        beta1, beta2, beta3 = self.betas[0], self.betas[1], self.betas[2]
        eps = self.eps
        # A missing coefficient simply disables the decay.
        weight_decay = 0. if self.weight_decay is None else self.weight_decay
        correct_m = 1 / (1 - (1 - beta1) ** (step + 1))
        correct_v = 1 / (1 - (1 - beta2) ** (step + 1))
        correct_n = 1 / (1 - (1 - beta3) ** (step + 1))

        m_values, v_values, n_values, pre_g_values, weight_values, grad_values = to_same_dict_tree(
            self.exp_avg_states, self.exp_avg_diff_states, self.exp_avg_sq_states, self.pre_grad_states,
            self.param_states, grads)

        # NOTE(review): the previous-gradient buffer starts at zero, so the
        # first difference equals the gradient itself; confirm whether the
        # first step should instead use a zero difference.
        diff_values = jax.tree.map(lambda g, pre_g: g - pre_g, grad_values, pre_g_values)
        m_values = jax.tree.map(lambda m, g: m * (1 - beta1) + beta1 * g, m_values, grad_values)
        v_values = jax.tree.map(lambda v, gd: v * (1 - beta2) + beta2 * gd, v_values, diff_values)
        n_values = jax.tree.map(
            lambda n, g, gd: n * (1 - beta3) + beta3 * (g + (1 - beta2) * gd) ** 2,
            n_values, grad_values, diff_values)

        def _apply(p, m, v, n):
            weighted_step_size = lr / (jnp.sqrt(n * correct_n) + eps)
            direction = m * correct_m + (1 - beta2) * v * correct_v
            if self.no_prox:
                # AdamW-style decoupled decay.
                return p * (1 - weight_decay * lr) - weighted_step_size * direction
            # Proximal form from Algorithm 1 of the paper.
            return (p - weighted_step_size * direction) / (1 + weight_decay * lr)

        weight_values = jax.tree.map(_apply, weight_values, m_values, v_values, n_values)

        self.exp_avg_states.assign_values(m_values)
        self.exp_avg_diff_states.assign_values(v_values)
        self.exp_avg_sq_states.assign_values(n_values)
        # Remember this gradient for the next difference.
        self.pre_grad_states.assign_values(grad_values)
        self.param_states.assign_values(weight_values)
        self.lr.step_call()
845
+
846
+
847
class AdamW(_WeightDecayOptimizer):
    r"""
    Adam with weight decay regularization [1]_.

    AdamW uses weight decay to regularize learning towards small weights, as
    this leads to better generalization. In SGD you can also use L2 regularization
    to implement this as an additive loss term, however L2 regularization
    does not behave as intended for adaptive gradient algorithms such as Adam.

    .. math::

       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{(lr)}, \: \beta_1, \beta_2
                \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
                \: \epsilon \text{ (epsilon)}                                                    \\
            &\hspace{13mm}      \lambda \text{(weight decay)},  \: \textit{amsgrad},
                \: \textit{maximize}                                                             \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
                \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0              \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}         \\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
            &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t})                                                                   \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big)                                 \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}


    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    beta1: optional, float
        A positive scalar value for beta_1, the exponential decay rate
        for the first moment estimates. Generally close to 1.
    beta2: optional, float
        A positive scalar value for beta_2, the exponential decay rate
        for the second moment estimates. Generally close to 1.
    eps: optional, float
        A positive scalar value for epsilon, a small constant for
        numerical stability.
    weight_decay: float
        Strength of the weight decay regularization. Note that this
        weight decay is multiplied with the learning rate.
    amsgrad: bool
        whether to use the AMSGrad variant of this algorithm
        from the paper `On the Convergence of Adam and Beyond`.

    References
    ----------
    .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019).

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        if eps < 0.:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= beta1 < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(beta1))
        if not 0.0 <= beta2 < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(beta2))
        if weight_decay < 0.:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        self.beta1 = fcast(beta1)
        self.beta2 = fcast(beta2)
        self.eps = fcast(eps)
        self.amsgrad = amsgrad
        # First- and second-moment estimates, one entry per parameter.
        self.m1_states = StateDictManager()
        self.m2_states = StateDictManager()
        if self.amsgrad:
            # Running maximum of the second moment (AMSGrad only).
            self.vmax_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the states to optimize and allocate their moment buffers.

        ``None`` is treated as an empty mapping. An ``AssertionError`` is
        raised when ``train_states`` is not a dict or holds a non-``State``.
        """
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                self.m1_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                self.m2_states[k] = OptimState(u.math.tree_zeros_like(v.value))
                if self.amsgrad:
                    self.vmax_states[k] = OptimState(u.math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """
        Apply one AdamW step using ``grads`` (keyed like the registered states).

        Every buffer is computed with its own ``jax.tree.map`` call, so each
        result is a tree shaped like the parameters.
        """
        lr_old = self.lr()
        # NOTE(review): the exponent is ``last_epoch + 2``; this presumably
        # assumes the scheduler counter starts at -1 -- confirm.
        step = self.lr.last_epoch.value + 2
        bias_correction1 = 1 - self.beta1 ** step
        bias_correction2 = 1 - self.beta2 ** step
        # Both bias corrections are folded into the step size.
        lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1
        beta1, beta2, eps = self.beta1, self.beta2, self.eps

        weight_values, grad_values, m1_values, m2_values = to_same_dict_tree(
            self.param_states, grads, self.m1_states, self.m2_states)

        m1_values = jax.tree.map(lambda m, g: beta1 * m + (1 - beta1) * g, m1_values, grad_values)
        m2_values = jax.tree.map(lambda v, g: beta2 * v + (1 - beta2) * g ** 2, m2_values, grad_values)

        if self.amsgrad:
            vmax_values = to_same_dict_tree(self.param_states, self.vmax_states)[1]
            vmax_values = jax.tree.map(jnp.maximum, vmax_values, m2_values)
            denom_source = vmax_values
        else:
            denom_source = m2_values

        # Decoupled decay uses the uncorrected learning rate.
        if self.weight_decay != 0:
            decay_factor = 1 - lr_old * self.weight_decay
            weight_values = jax.tree.map(lambda p: p * decay_factor, weight_values)

        weight_values = jax.tree.map(
            lambda p, m, d: p - lr * m / (jnp.sqrt(d) + eps),
            weight_values, m1_values, denom_source)

        if self.amsgrad:
            self.vmax_states.assign_values(vmax_values)
        self.param_states.assign_values(weight_values)
        self.m1_states.assign_values(m1_values)
        self.m2_states.assign_values(m2_values)
        self.lr.step_call()
993
+
994
+
995
class SM3(_WeightDecayOptimizer):
    """
    SM3 algorithm [1]_.

    The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method'
    (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar
    to Adam and Adagrad with greatly reduced memory usage for history tensors.
    For an `n x m` matrix, Adam and Adagrad use `O(nm)` memory for history
    tensors, while SM3 uses `O(n+m)` due to the chosen cover. In general, a tensor
    of shape `(n_1, n_2, ..., n_k)` optimized using Adam will use `O(prod n_i)`
    memory for storage tensors, while the optimization using SM3 will use
    `O(sum n_i)` memory. Despite storing fewer parameters, this optimization
    algorithm manages to be comparably effective.

    This advantage drastically shrinks when `momentum > 0`. The momentum is
    tracked using a tensor of the same shape as the tensor being optimized. With
    momentum, SM3 will use just over half as much memory as Adam, and a bit more
    than Adagrad.

    Parameters
    ----------
    lr: float, LearningRateScheduler
        learning rate.
    momentum: float
        coefficient used to scale prior updates
        before adding. This drastically increases memory usage if
        `momentum > 0.0`. (default: 0.0)
    beta: float
        coefficient used for exponential moving averages (default: 0.0)
    eps: float
        Term added to square-root in denominator to
        improve numerical stability (default: 1e-30).
    weight_decay: float, optional
        When given, the weights are multiplied by ``1 - weight_decay``
        before the step is subtracted (default: None).

    References
    ----------
    .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019).

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        beta: float = 0.,
        momentum: float = 0.,
        eps: float = 1e-30,
        weight_decay: Optional[float] = None,
    ):
        super().__init__(lr=lr, weight_decay=weight_decay)

        if not 0.0 <= momentum < 1.0:
            raise ValueError("Invalid momentum: {0}".format(momentum))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta: {0}".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid eps: {0}".format(eps))

        self.eps = fcast(eps)
        self.beta = fcast(beta)
        self.momentum = fcast(momentum)
        # Holds the per-axis accumulators (``<name>_m<i>``) and, when momentum
        # is enabled, the momentum buffers (``<name>_mbuffer``).
        self.memory_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the states to optimize and allocate their accumulators.

        For a weight of shape ``(n_0, ..., n_{k-1})`` one accumulator per axis
        is created; accumulator ``i`` has size ``n_i`` along axis ``i`` and
        size 1 elsewhere. A scalar weight gets a single scalar accumulator.

        ``None`` is treated as an empty mapping. An ``AssertionError`` is
        raised when ``train_states`` is not a dict or holds a non-``State``.
        """
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            if self.param_states.add_unique_value(k, v):
                full_shape, ndim, dtype = v.value.shape, v.value.ndim, v.value.dtype
                if ndim == 0:
                    # ``update`` always reads ``_m0``, so scalars need one too.
                    self.memory_states[f'{k}_m0'] = State(jnp.zeros((), dtype=dtype))
                for i in range(ndim):
                    shape = [1] * ndim
                    shape[i] = full_shape[i]
                    self.memory_states[f'{k}_m{i}'] = State(jnp.zeros(shape, dtype=dtype))
                if self.momentum > 0.:
                    self.memory_states[f'{k}_mbuffer'] = State(jnp.zeros_like(v.value))

    def update(self, grads: dict):
        """
        Apply one SM3 step using ``grads`` (keyed like the registered states).
        """
        lr = self.lr()

        for k, p in self.param_states.items():
            g = grads[k]
            weight = p.value
            ndim = weight.ndim
            # Scalars own exactly one accumulator (see registration).
            accumulators = [self.memory_states[f'{k}_m{i}'] for i in range(max(ndim, 1))]

            # Broadcasted minimum over the per-axis accumulators.
            nu = accumulators[0].value
            for acc in accumulators[1:]:
                nu = jnp.minimum(nu, acc.value)
            if self.beta > 0.:
                nu = nu * self.beta
            nu = nu + g * g * (1 - self.beta)

            # Refresh accumulator ``i`` with the maximum of ``nu`` over every
            # axis except ``i``; with no other axis ``nu`` is used as is.
            for i, acc in enumerate(accumulators):
                other_axes = tuple(j for j in range(ndim) if j != i)
                reduced = jnp.max(nu, axis=other_axes, keepdims=True) if other_axes else nu
                if self.beta > 0.:
                    acc.value = jnp.maximum(acc.value, reduced)
                else:
                    # No need to compare - nu_max is bigger because of grad ** 2
                    acc.value = reduced

            step = g / jnp.sqrt(nu + self.eps)
            if self.momentum > 0.:
                m_buffer = self.memory_states[f'{k}_mbuffer']
                step = step * (1. - self.momentum) + m_buffer.value * self.momentum
                m_buffer.value = step

            if self.weight_decay is None:
                p.value = weight - lr * step
            else:
                p.value = (1 - self.weight_decay) * weight - lr * step
        self.lr.step_call()