brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1148 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ import functools
19
+ from typing import Union, Dict, Optional, Tuple, Any, TypeVar
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+
24
+ from ._lr_scheduler import make_schedule, LearningRateScheduler
25
+ from .. import environ, math
26
+ from .._module import Module
27
+ from .._state import State, LongTermState, StateDictManager, visible_state_dict
28
+
29
# Names exported by ``from <this module> import *``.
__all__ = [
    # helper that unwraps several state dictionaries at once
    'to_same_dict_tree',

    # new class of brainstate.State for optimizer
    'OptimState',

    # commonly used optimizers
    'Optimizer',
    'SGD',
    'Momentum',
    'MomentumNesterov',
    'Adagrad',
    'Adadelta',
    'RMSProp',
    'Adam',
    'LARS',
    'Adan',
    'AdamW',
]

# Generic type variable available to annotations in this module.
T = TypeVar('T')
50
+
51
+
52
def cast(value: Any, dtype: Any) -> jax.Array:
    """
    Convert ``value`` to a JAX array of the requested ``dtype``.

    Parameters
    ----------
    value: Any
        The data to convert.
    dtype: Any
        The target data type.

    Returns
    -------
    jax.Array
        ``value.astype(dtype)`` when ``value`` already is a ``jax.Array``,
        otherwise ``jnp.asarray(value, dtype=dtype)``.
    """
    if not isinstance(value, jax.Array):
        # Python scalars, lists, NumPy arrays, ... are converted in one go.
        return jnp.asarray(value, dtype=dtype)
    return value.astype(dtype)
56
+
57
+
58
def fcast(value: Any, dtype: Any = None) -> jax.Array:
    """
    Convert ``value`` to a JAX array, using the default float type if needed.

    Parameters
    ----------
    value: Any
        The data to convert.
    dtype: Any, optional
        The target data type. When it is ``None``, ``environ.dftype()`` is
        used instead.

    Returns
    -------
    jax.Array
        The result of ``cast(value, dtype=...)``.
    """
    # Test against ``None`` explicitly rather than relying on truthiness:
    # ``dtype or default`` silently discards any falsy dtype object, and
    # unstructured ``numpy.dtype`` instances report ``len() == 0`` and may
    # therefore evaluate as falsy.
    if dtype is None:
        dtype = environ.dftype()
    return cast(value, dtype=dtype)
60
+
61
+
62
def _to_dict_value(old_dict: Dict) -> Dict:
    """
    Build a plain dictionary in which every ``State`` is replaced by its value.

    Entries that are not ``State`` instances are copied over unchanged. The
    input dictionary itself is not modified.
    """
    return {
        key: (item.value if isinstance(item, State) else item)
        for key, item in old_dict.items()
    }
70
+
71
+
72
def to_same_dict_tree(*dicts: Dict):
    """
    Convert multiple dictionaries to the same tree structure.

    All dictionaries must have exactly the same set of keys. Each one is
    then converted with ``_to_dict_value`` so that ``State`` entries are
    replaced by their ``.value``.

    Parameters
    ----------
    *dicts: dict
        The dictionaries to be converted.

    Returns
    -------
    dict or tuple of dict or None
        The single converted dictionary when one dictionary is given, a
        tuple of converted dictionaries (in argument order) when several
        are given, and ``None`` when called without arguments.

    Raises
    ------
    ValueError
        If the key sets of the dictionaries differ.
    """
    if not dicts:
        # Nothing to convert.
        return None

    reference_keys = set(dicts[0].keys())
    for other in dicts[1:]:
        # Compare symmetrically: a one-sided ``difference`` would accept a
        # dictionary that carries extra keys, and the mismatch would only
        # surface later as a tree-structure error inside ``jax.tree.map``.
        if reference_keys != set(other.keys()):
            raise ValueError('Dictionary does not match.')

    # flatten to normal python dict
    converted = [_to_dict_value(d) for d in dicts]
    if len(converted) == 1:
        return converted[0]
    return tuple(converted)
100
+
101
+
102
def _sgd(prev_weight, gradient, weight_decay, lr=None):
    """
    Compute one SGD-style weight update.

    The result is ``decayed_weight - step`` where ``decayed_weight`` is
    ``prev_weight`` (no weight decay) or ``(1 - weight_decay) * prev_weight``,
    and ``step`` is ``gradient`` (no learning rate) or ``lr * gradient``.

    Parameters
    ----------
    prev_weight: jax.Array
        The previous weight.
    gradient: jax.Array
        The gradient, or any already-scaled update to subtract.
    weight_decay: float, optional
        The weight decay. ``None`` disables the decay.
    lr: float, optional
        The learning rate. ``None`` means the gradient is subtracted as is.

    Returns
    -------
    jax.Array
        The updated weight.
    """
    # Part of the update that depends on the old weight.
    if weight_decay is None:
        decayed_weight = prev_weight
    else:
        decayed_weight = (1 - weight_decay) * prev_weight

    # Part of the update that depends on the gradient.
    step = gradient if lr is None else lr * gradient

    return decayed_weight - step
127
+
128
+
129
class OptimState(LongTermState):
    """
    The state for optimizer.

    A marker subclass of ``LongTermState``: it adds no attributes or methods
    of its own and is used in this module to wrap the auxiliary values the
    optimizers keep per weight (momentum, caches, moment estimates, ...).
    """
134
+
135
+
136
class Optimizer(Module):
    """Base Optimizer Class.

    The constructor converts ``lr`` into a learning-rate scheduler with
    ``make_schedule`` and creates an empty ``StateDictManager`` that will
    hold the weights to train. Concrete optimizers implement
    :meth:`register_trainable_weights` and :meth:`update`.

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    name: str, optional
        The optimizer name, forwarded to the ``Module`` constructor.
    """

    lr: LearningRateScheduler  # learning rate
    weight_states: StateDictManager  # states to train, invisible to ``.states()``

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        name: Optional[str] = None
    ):
        super().__init__(name=name)
        self.lr = make_schedule(lr)
        self.weight_states = StateDictManager()

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the states to optimize. Must be provided by subclasses."""
        raise NotImplementedError

    def update(self, grads: dict):
        """Apply one optimization step from ``grads``. Must be provided by subclasses."""
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Return extra text appended inside the parentheses of ``repr(self)``."""
        return ''

    def __repr__(self):
        class_name = self.__class__.__name__
        suffix = self.extra_repr()
        return f"{class_name}(lr={self.lr}{suffix})"
168
+
169
+
170
class _WeightDecayOptimizer(Optimizer):
    """
    Base class of the optimizers that accept an optional weight decay.

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient in ``[0, 1]``. ``None`` disables weight
        decay. A given value is stored as an array through ``fcast``.
    name: str, optional
        The optimizer name.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        name: Optional[str] = None
    ):
        # The base constructor already builds ``self.lr`` with
        # ``make_schedule(lr)``; building a second scheduler here would only
        # throw the first one away.
        super().__init__(lr=lr, name=name)
        assert weight_decay is None or 0. <= weight_decay <= 1., 'weight_decay must be in [0, 1].'
        self.weight_decay = (fcast(weight_decay) if weight_decay is not None else None)

    def extra_repr(self) -> str:
        """Return extra text appended inside the parentheses of ``repr(self)``."""
        return ''

    def __repr__(self):
        return f"{self.__class__.__name__}(lr={self.lr}, weight_decay={self.weight_decay}{self.extra_repr()})"
187
+
188
+
189
class SGD(_WeightDecayOptimizer):
    r"""
    Stochastic gradient descent optimizer.

    For training examples :math:`x` and labels :math:`y`, each call to
    :meth:`update` performs

    .. math::

        \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y)

    When ``weight_decay`` is given, the weights are scaled by
    ``1 - weight_decay`` before the gradient step is subtracted.

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient in ``[0, 1]``.
    name: str, optional
        The optimizer name.

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        name: Optional[str] = None
    ):
        super().__init__(lr=lr, weight_decay=weight_decay, name=name)

    def register_trainable_weights(self, states: Optional[Dict[str, State]] = None):
        """Add every ``State`` of ``states`` to the weights to train."""
        if states is None:
            states = dict()
        assert isinstance(states, dict), '"states" must be a dict of brainstate.State.'
        for key, state in states.items():
            assert isinstance(state, State), f'"{key}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(key, state)

    def update(self, grads: dict):
        """Apply one SGD step from ``grads`` and advance the scheduler."""
        step_size = self.lr()
        params, gradients = to_same_dict_tree(self.weight_states, grads)
        apply_step = functools.partial(_sgd, lr=step_size, weight_decay=self.weight_decay)
        new_params = jax.tree.map(apply_step, params, gradients)
        self.weight_states.assign_values(new_params)
        self.lr.step_call()
231
+
232
+
233
class Momentum(_WeightDecayOptimizer):
    r"""
    Momentum optimizer.

    Momentum [1]_ is a method that helps accelerate SGD in the relevant direction
    and dampens oscillations. It does this by adding a fraction :math:`\gamma`
    of the update vector of the past time step to the current update vector:

    .. math::

        \begin{align}
        \begin{split}
        v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta) \\
        \theta &= \theta - v_t
        \end{split}
        \end{align}

    When ``weight_decay`` is given, :math:`\theta` is scaled by
    ``1 - weight_decay`` before :math:`v_t` is subtracted.

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    momentum: float
        The fraction :math:`\gamma` of the previous update that is kept.
    weight_decay: float, optional
        Weight decay coefficient in ``[0, 1]``.
    name: str, optional
        The optimizer name.

    References
    ----------

    .. [1] Qian, N. (1999). On the momentum term in gradient descent learning
           algorithms. Neural Networks : The Official Journal of the International
           Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        momentum: float = 0.9,
        weight_decay: Optional[float] = None,
        name: Optional[str] = None
    ):
        super(Momentum, self).__init__(lr=lr, weight_decay=weight_decay, name=name)
        self.momentum = fcast(momentum)
        self.momentum_states = visible_state_dict()

    def extra_repr(self) -> str:
        return f", momentum={self.momentum}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the weights to train and create a zero velocity for each."""
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'

        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(k, v)
            self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """Apply one momentum step from ``grads`` and advance the scheduler."""
        lr = self.lr()
        weight_values, grad_values, momentum_values = to_same_dict_tree(
            self.weight_states, grads, self.momentum_states
        )
        # v_t = gamma * v_{t-1} + eta * g
        momentum_values = jax.tree.map(
            lambda vv, gg: self.momentum * vv + lr * gg,
            momentum_values,
            grad_values
        )
        # theta = (1 - weight_decay) * theta - v_t.
        # The learning rate is already folded into v_t, so it must NOT be
        # passed to ``_sgd`` again: doing so (with the velocity stored as
        # ``gamma * v - eta * g``) produced ``theta + eta**2 * g`` on the
        # first step, i.e. a move *along* the gradient.
        new_weight_values = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            weight_values,
            momentum_values
        )
        self.momentum_states.assign_values(momentum_values)
        self.weight_states.assign_values(new_weight_values)
        self.lr.step_call()
305
+
306
+
307
class MomentumNesterov(_WeightDecayOptimizer):
    r"""
    Nesterov accelerated gradient optimizer [2]_.

    .. math::

        \begin{align}
        \begin{split}
        v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta - \gamma v_{t-1} ) \\
        \theta &= \theta - v_t
        \end{split}
        \end{align}

    The gradients are supplied by the caller, so the look-ahead cannot be
    applied to the point at which they are evaluated. The commonly used
    reformulation is implemented instead: with :math:`g` the supplied
    gradient,

    .. math::

        \begin{align}
        \begin{split}
        v_t &= \gamma v_{t-1} + \eta g \\
        \theta &= \theta - (\gamma v_t + \eta g)
        \end{split}
        \end{align}

    When ``weight_decay`` is given, :math:`\theta` is scaled by
    ``1 - weight_decay`` before the step is subtracted.

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient in ``[0, 1]``.
    momentum: float
        The momentum coefficient :math:`\gamma`.
    name: str, optional
        The optimizer name.

    References
    ----------
    .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547.

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        momentum: float = 0.9,
        name: Optional[str] = None
    ):
        super(MomentumNesterov, self).__init__(lr=lr, weight_decay=weight_decay, name=name)

        self.momentum = fcast(momentum)
        self.momentum_states = visible_state_dict()

    def extra_repr(self) -> str:
        return f", momentum={self.momentum}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the weights to train and create a zero velocity for each."""
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(k, v)
            self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """Apply one Nesterov momentum step from ``grads`` and advance the scheduler."""
        lr = self.lr()
        weight_values, grad_values, momentum_values = to_same_dict_tree(
            self.weight_states, grads, self.momentum_states
        )
        # v_t = gamma * v_{t-1} + eta * g
        momentum_values = jax.tree.map(
            lambda mv, gv: self.momentum * mv + lr * gv,
            momentum_values,
            grad_values
        )
        # Look-ahead step: gamma * v_t + eta * g.
        steps = jax.tree.map(
            lambda mv, gv: self.momentum * mv + lr * gv,
            momentum_values,
            grad_values
        )
        # theta = (1 - weight_decay) * theta - step. The learning rate is
        # already part of ``steps`` and is therefore not passed to ``_sgd``;
        # passing it again (on a velocity of opposite sign) made the weights
        # move along the gradient and skipped the look-ahead entirely.
        weight_values = jax.tree.map(
            functools.partial(_sgd, weight_decay=self.weight_decay),
            weight_values,
            steps
        )
        self.weight_states.assign_values(weight_values)
        self.momentum_states.assign_values(momentum_values)
        self.lr.step_call()
366
+
367
+
368
class Adagrad(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the Adagrad algorithm.

    Adagrad [3]_ uses parameter-specific learning rates that are adapted to
    how frequently a parameter gets updated during training: the more
    updates a parameter receives, the smaller its updates become.

    .. math::

        \theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t}

    where :math:`G(t)` contains the sum of the squares of the past gradients.

    A main benefit of Adagrad is that the learning rate rarely needs manual
    tuning; most implementations use a default value of 0.01. Its main
    weakness is the accumulation of squared gradients in the denominator:
    every added term is positive, so the sum keeps growing, the effective
    learning rate keeps shrinking, and learning eventually stalls.

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient in ``[0, 1]``.
    epsilon: float
        Constant added to the accumulated squared gradients inside the
        square root.
    name: str, optional
        The optimizer name.

    References
    ----------
    .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        name: Optional[str] = None
    ):
        super().__init__(lr=lr, weight_decay=weight_decay, name=name)
        self.epsilon = fcast(epsilon)
        self.cache_states = visible_state_dict()

    def extra_repr(self) -> str:
        return f", epsilon={self.epsilon}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the weights to train and a zero gradient cache for each."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for key, state in train_states.items():
            assert isinstance(state, State), f'"{key}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(key, state)
            self.cache_states[key] = OptimState(math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Apply one Adagrad step from ``grads`` and advance the scheduler."""
        lr = self.lr()
        caches, gradients, params = to_same_dict_tree(self.cache_states, grads, self.weight_states)

        def accumulate(cv, gv):
            # running sum of squared gradients
            return cv + gv ** 2

        def scaled_step(cv, gv):
            # learning rate is applied here, so ``_sgd`` gets no ``lr``
            return lr * gv / jnp.sqrt(cv + self.epsilon)

        caches = jax.tree.map(accumulate, caches, gradients)
        steps = jax.tree.map(scaled_step, caches, gradients)
        apply_step = functools.partial(_sgd, weight_decay=self.weight_decay)
        params = jax.tree.map(apply_step, params, steps)
        self.cache_states.assign_values(caches)
        self.weight_states.assign_values(params)
        self.lr.step_call()
433
+
434
+
435
class Adadelta(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the Adadelta algorithm.

    Adadelta [4]_ optimization is a stochastic gradient descent method based
    on an adaptive learning rate per dimension. It addresses two drawbacks:

    - The continual decay of learning rates throughout training.
    - The need for a manually selected global learning rate.

    Adadelta is a more robust extension of Adagrad that adapts learning rates
    based on a moving window of gradient updates instead of accumulating all
    past gradients, so it keeps learning even after many updates. In the
    original version of Adadelta no initial learning rate has to be set.

    .. math::

        \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\
        \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\
        \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\
        \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t.

    :math:`\rho` should be between 0 and 1. A value of rho close to 1 will decay the
    moving average slowly and a value close to 0 will decay the moving average fast.

    :math:`\rho` = 0.95 and :math:`\epsilon`=1e-6 are suggested in the paper and reported
    to work for multiple datasets (MNIST, speech).

    In the paper, no learning rate is considered. Accordingly, :meth:`update`
    does not read the value of ``lr``; it only advances the scheduler.
    epsilon is important for the very first update (so the numerator does
    not become 0).

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate. Not used by the update rule itself.
    weight_decay: float, optional
        Weight decay coefficient in ``[0, 1]``.
    epsilon: float
        Constant added inside both square roots.
    rho: float
        Decay rate of the two moving averages.
    name: str, optional
        The optimizer name.

    References
    ----------
    .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State] = 0.01,
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        rho: float = 0.95,
        name: Optional[str] = None
    ):
        super().__init__(lr=lr, weight_decay=weight_decay, name=name)

        self.epsilon = fcast(epsilon)
        self.rho = fcast(rho)
        self.cache_states = visible_state_dict()
        self.delta_states = visible_state_dict()

    def extra_repr(self) -> str:
        return f", epsilon={self.epsilon}, rho={self.rho}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the weights to train plus a zero cache and zero delta for each."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for key, state in train_states.items():
            assert isinstance(state, State), f'"{key}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(key, state)
            self.cache_states[key] = OptimState(math.tree_zeros_like(state.value))
            self.delta_states[key] = OptimState(math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Apply one Adadelta step from ``grads`` and advance the scheduler."""
        params, gradients, caches, deltas = to_same_dict_tree(
            self.weight_states, grads, self.cache_states, self.delta_states)

        def moving_average(avg, sample):
            # shared rule of both running averages (squared gradients / squared steps)
            return self.rho * avg + (1 - self.rho) * sample ** 2

        def rescale(gv, dv, cv):
            return gv * jnp.sqrt(dv + self.epsilon) / jnp.sqrt(cv + self.epsilon)

        caches = jax.tree.map(moving_average, caches, gradients)
        steps = jax.tree.map(rescale, gradients, deltas, caches)
        deltas = jax.tree.map(moving_average, deltas, steps)
        apply_step = functools.partial(_sgd, weight_decay=self.weight_decay)
        params = jax.tree.map(apply_step, params, steps)
        self.weight_states.assign_values(params)
        self.delta_states.assign_values(deltas)
        self.cache_states.assign_values(caches)
        self.lr.step_call()
520
+
521
+
522
class RMSProp(_WeightDecayOptimizer):
    r"""
    Optimizer that implements the RMSprop algorithm.

    RMSprop [5]_ and Adadelta were developed independently around the same
    time, both to resolve Adagrad's radically diminishing learning rates.

    The gist of RMSprop is to:

    - Maintain a moving (discounted) average of the square of gradients
    - Divide the gradient by the root of this average

    .. math::

        \begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2\\
        p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split}

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    weight_decay: float, optional
        Weight decay coefficient in ``[0, 1]``.
    epsilon: float
        Constant added to the moving average inside the square root.
    rho: float
        Decay rate of the moving average of squared gradients.
    name: str, optional
        The optimizer name.

    References
    ----------
    .. [5] Tieleman, T. and Hinton, G. (2012):
           Neural Networks for Machine Learning, Lecture 6.5 - rmsprop.
           Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        weight_decay: Optional[float] = None,
        epsilon: float = 1e-6,
        rho: float = 0.9,
        name: Optional[str] = None
    ):
        super(RMSProp, self).__init__(lr=lr, weight_decay=weight_decay, name=name)

        self.epsilon = fcast(epsilon)
        self.rho = fcast(rho)
        self.cache_states = visible_state_dict()

    def extra_repr(self) -> str:
        return f", epsilon={self.epsilon}, rho={self.rho}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the weights to train and a zero gradient cache for each."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for key, state in train_states.items():
            assert isinstance(state, State), f'"{key}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(key, state)
            self.cache_states[key] = OptimState(math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Apply one RMSprop step from ``grads`` and advance the scheduler."""
        lr = self.lr()
        params, gradients, caches = to_same_dict_tree(self.weight_states, grads, self.cache_states)

        def moving_average(cv, gv):
            return self.rho * cv + (1 - self.rho) * gv ** 2

        def scaled_step(gv, cv):
            # learning rate is applied here, so ``_sgd`` gets no ``lr``
            return lr * gv / jnp.sqrt(cv + self.epsilon)

        caches = jax.tree.map(moving_average, caches, gradients)
        steps = jax.tree.map(scaled_step, gradients, caches)
        apply_step = functools.partial(_sgd, weight_decay=self.weight_decay)
        params = jax.tree.map(apply_step, params, steps)
        self.weight_states.assign_values(params)
        self.cache_states.assign_values(caches)
        self.lr.step_call()
590
+
591
+
592
class Adam(_WeightDecayOptimizer):
    """
    Optimizer that implements the Adam algorithm.

    Adam [6]_ is a stochastic gradient descent method (SGD) that computes
    individual adaptive learning rates for different parameters from
    estimates of first- and second-order moments of the gradients.

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    beta1: optional, float
        A positive scalar value for beta_1, the exponential decay rate
        for the first moment estimates (default 0.9).
    beta2: optional, float
        A positive scalar value for beta_2, the exponential decay rate
        for the second moment estimates (default 0.999).
    eps: optional, float
        A positive scalar value for epsilon, a small constant for
        numerical stability (default 1e-8).
    weight_decay: optional, float
        Weight decay coefficient in ``[0, 1]``.
    name : optional, str
        The optimizer name.

    References
    ----------
    .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
    """

    def __init__(
        self,
        lr: Union[float, State, LearningRateScheduler],
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: Optional[float] = None,
        name: Optional[str] = None
    ):
        super(Adam, self).__init__(lr=lr, weight_decay=weight_decay, name=name)

        self.beta1 = fcast(beta1)
        self.beta2 = fcast(beta2)
        self.eps = fcast(eps)
        self.m1_states = visible_state_dict()
        self.m2_states = visible_state_dict()

    def extra_repr(self) -> str:
        return f", beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the weights to train plus zero first and second moments for each."""
        if train_states is None:
            train_states = dict()
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'

        for key, state in train_states.items():
            assert isinstance(state, State), f'"{key}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(key, state)
            self.m1_states[key] = OptimState(math.tree_zeros_like(state.value))
            self.m2_states[key] = OptimState(math.tree_zeros_like(state.value))

    def update(self, grads: dict):
        """Apply one Adam step from ``grads`` and advance the scheduler."""
        lr = self.lr()
        # Exponent of the bias-correction terms.
        # NOTE(review): presumably ``last_epoch`` starts at -1 so that this is
        # 1 on the first call — confirm against LearningRateScheduler.
        step_index = self.lr.last_epoch.value + 2
        lr = lr / (1 - self.beta1 ** step_index)
        lr = lr * jnp.sqrt(1 - self.beta2 ** step_index)
        params, gradients, first_moments, second_moments = to_same_dict_tree(
            self.weight_states, grads, self.m1_states, self.m2_states)

        def update_first(m1, gv):
            return self.beta1 * m1 + (1 - self.beta1) * gv

        def update_second(m2, gv):
            return self.beta2 * m2 + (1 - self.beta2) * gv ** 2

        def scaled_step(m1, m2):
            # bias-corrected learning rate is applied here, so ``_sgd`` gets no ``lr``
            return lr * m1 / (jnp.sqrt(m2) + self.eps)

        first_moments = jax.tree.map(update_first, first_moments, gradients)
        second_moments = jax.tree.map(update_second, second_moments, gradients)
        steps = jax.tree.map(scaled_step, first_moments, second_moments)
        apply_step = functools.partial(_sgd, weight_decay=self.weight_decay)
        params = jax.tree.map(apply_step, params, steps)
        self.weight_states.assign_values(params)
        self.m1_states.assign_values(first_moments)
        self.m2_states.assign_values(second_moments)
        self.lr.step_call()
669
+
670
+
671
class LARS(_WeightDecayOptimizer):
    r"""
    Layer-wise adaptive rate scaling (LARS) optimizer [1]_.

    Layer-wise Adaptive Rate Scaling, or LARS, is a large batch
    optimization technique. There are two notable differences
    between LARS and other adaptive algorithms such as `Adam` or `RMSProp`:
    first, LARS uses a separate learning rate for each layer and not for
    each weight. And second, the magnitude of the update is controlled
    with respect to the weight norm for better control of training speed.

    .. math::

        m_{t} = \beta_{1}m_{t-1} + \left(1-\beta_{1}\right)\left(g_{t} + \lambda{x_{t}}\right) \\
        x_{t+1}^{\left(i\right)} = x_{t}^{\left(i\right)} - \eta_{t}\frac{\phi\left(|| x_{t}^{\left(i\right)} ||\right)}{|| m_{t}^{\left(i\right)} || }m_{t}^{\left(i\right)}

    Parameters
    ----------
    lr: float, LearningRateScheduler, State
        learning rate.
    momentum: float
        coefficient used for the moving average of the gradient.
    weight_decay: float, optional
        weight decay coefficient. ``None`` is treated as zero.
    tc: float
        trust coefficient eta ( < 1) for trust ratio computation.
    eps: float
        epsilon used for trust ratio computation.
    name: str, optional
        The optimizer name.

    References
    ----------
    .. [1] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        tc: float = 1e-3,
        eps: float = 1e-5,
        name: Optional[str] = None
    ):
        # No ``assert self.weight_decay is None`` here: the weight decay is a
        # documented parameter with a non-None default and is used by
        # ``update``, so such an assertion made the optimizer impossible to
        # construct with its default arguments.
        super(LARS, self).__init__(lr=lr, weight_decay=weight_decay, name=name)

        self.momentum = fcast(momentum)
        self.tc = fcast(tc)
        self.eps = fcast(eps)
        self.momentum_states = visible_state_dict()

    def extra_repr(self) -> str:
        return f", momentum={self.momentum}, tc={self.tc}, eps={self.eps}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """Register the weights to train and create a zero momentum for each."""
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(k, v)
            self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """Apply one LARS step from ``grads`` and advance the scheduler."""
        lr = self.lr()
        # ``None`` means "no weight decay"; use zero so the arithmetic below
        # stays valid instead of failing on ``None * p_norm``.
        weight_decay = 0. if self.weight_decay is None else self.weight_decay
        weight_values, grad_values, momentum_values = to_same_dict_tree(
            self.weight_states, grads, self.momentum_states
        )

        def _lars_update(pv, gv, mv):
            p_norm = jnp.linalg.norm(pv)
            g_norm = jnp.linalg.norm(gv)
            trust_ratio = self.tc * p_norm / (g_norm + weight_decay * p_norm + self.eps)
            # The boolean evaluates to 1 when either norm is zero, so the
            # local rate is then at least ``lr``; otherwise it is
            # ``lr * trust_ratio``.
            local_lr = lr * jnp.maximum(jnp.logical_or(p_norm == 0, g_norm == 0), trust_ratio)
            mv = self.momentum * mv + local_lr * (gv + weight_decay * pv)
            return mv

        momentum_values = jax.tree.map(_lars_update, weight_values, grad_values, momentum_values)
        weight_values = jax.tree.map(lambda pv, mv: pv - mv, weight_values, momentum_values)
        self.weight_states.assign_values(weight_values)
        self.momentum_states.assign_values(momentum_values)
        self.lr.step_call()
752
+
753
+
754
class Adan(_WeightDecayOptimizer):
    r"""
    Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.

    .. math::

       \begin{equation}
       \begin{aligned}
       & \mathbf{m}_k=\left(1-\beta_1\right) \mathbf{m}_{k-1}+\beta_1 \mathbf{g}_k \\
       & \mathbf{v}_k=\left(1-\beta_2\right) \mathbf{v}_{k-1}+\beta_2\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right) \\
       & \mathbf{n}_k=\left(1-\beta_3\right) \mathbf{n}_{k-1}+\beta_3\left[\mathbf{g}_k+\left(1-\beta_2\right)\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right)\right]^2 \\
       & \boldsymbol{\eta}_k=\eta /\left(\sqrt{\mathbf{n}_k+\varepsilon}\right) \\
       & \boldsymbol{\theta}_{k+1}=\left(1+\lambda_k \eta\right)^{-1}\left[\boldsymbol{\theta}_k-\boldsymbol{\eta}_k \circ\left(\mathbf{m}_k+\left(1-\beta_2\right) \mathbf{v}_k\right)\right] \\
       \end{aligned}
       \end{equation}

    Parameters
    ----------
    lr: float, LearningRateScheduler
      learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3)
    betas : tuple
      Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01))
    eps : float
      The term added to the denominator to improve numerical stability. (default: 1e-8)
    weight_decay : float
      decoupled weight decay (L2 penalty) (default: 0.02)
    no_prox: bool
      how to perform the decoupled weight decay (default: False).
      It determines the update rule of parameters with weight decay.
      By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper:

      .. math::
         \boldsymbol{\theta}_{k+1} = ( 1+\lambda \eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k)\right],

      But one also can update the parameter like Adamw:

      .. math::
         \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k).
    name : optional, str
      The optimizer name.

    References
    ----------
    .. [1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan.
           “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
           Deep Models.” ArXiv abs/2208.06677 (2022): n. pag.
    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State] = 1e-3,
        betas: Tuple[float, float, float] = (0.02, 0.08, 0.01),
        eps: float = 1e-8,
        weight_decay: float = 0.02,
        no_prox: bool = False,
        name: Optional[str] = None,
    ):
        super(Adan, self).__init__(lr=lr, weight_decay=weight_decay, name=name)

        assert len(betas) == 3
        if eps < 0.:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))

        self.betas = fcast(jnp.asarray(betas))
        self.eps = fcast(eps)
        self.no_prox = no_prox
        # Per-weight optimizer slots, populated by ``register_trainable_weights``:
        self.exp_avg_states = visible_state_dict()       # m: running average of gradients
        self.exp_avg_sq_states = visible_state_dict()    # n: running average of the squared term
        self.exp_avg_diff_states = visible_state_dict()  # v: running average of gradient differences
        self.pre_grad_states = visible_state_dict()      # gradient seen at the previous step

    def extra_repr(self) -> str:
        return f", betas={self.betas}, eps={self.eps}, weight_decay={self.weight_decay}, no_prox={self.no_prox}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the weights to optimize and allocate zero-initialized slots.

        Parameters
        ----------
        train_states: dict of str to State, optional
          Weights keyed by name. ``None`` is treated as an empty dict.
        """
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(k, v)
            self.exp_avg_states[k] = OptimState(math.tree_zeros_like(v.value))
            self.exp_avg_sq_states[k] = OptimState(math.tree_zeros_like(v.value))
            self.exp_avg_diff_states[k] = OptimState(math.tree_zeros_like(v.value))
            self.pre_grad_states[k] = OptimState(math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """
        Apply one Adan step to every registered weight.

        Parameters
        ----------
        grads: dict
          Gradients keyed like the registered weights.
        """
        lr = self.lr()
        step = self.lr.last_epoch.value + 1
        beta1, beta2, beta3 = self.betas[0], self.betas[1], self.betas[2]
        correct_m = 1 / (1 - (1 - beta1) ** (step + 1))
        correct_v = 1 / (1 - (1 - beta2) ** (step + 1))
        correct_n = 1 / (1 - (1 - beta3) ** (step + 1))
        # A missing coefficient (``None``) means "no decay".
        weight_decay = 0. if self.weight_decay is None else self.weight_decay

        # v is the gradient-difference average and n the squared average, so
        # each is read from (and written back to) the matching state dict.
        m_values, v_values, n_values, pre_g_values, weight_values, grad_values = to_same_dict_tree(
            self.exp_avg_states, self.exp_avg_diff_states, self.exp_avg_sq_states,
            self.pre_grad_states, self.weight_states, grads)

        def _adan_update(m, v, n, pre_g, g, p):
            m = m * (1 - beta1) + beta1 * g
            # ``pre_g`` starts as zeros, so the first difference is the gradient itself.
            gd = g - pre_g
            v = v * (1 - beta2) + beta2 * gd
            n = n * (1 - beta3) + beta3 * (g + (1 - beta2) * gd) ** 2
            weighted_step_size = lr / (jnp.sqrt(n * correct_n) + self.eps)
            direction = m * correct_m + (1 - beta2) * v * correct_v
            if self.no_prox:
                # AdamW-style decoupled decay.
                p = p * (1 - weight_decay * lr) - weighted_step_size * direction
            else:
                # Proximal form from Algorithm 1 of the paper.
                p = (p - weighted_step_size * direction) / (1 + weight_decay * lr)
            return m, v, n, p

        # ``jax.tree.map`` yields ONE tree whose leaves are 4-tuples; transpose
        # it into four trees instead of unpacking the outer container.
        outer_def = jax.tree.structure(weight_values)
        inner_def = jax.tree.structure((0, 0, 0, 0))
        packed = jax.tree.map(
            _adan_update, m_values, v_values, n_values, pre_g_values, grad_values, weight_values)
        m_values, v_values, n_values, weight_values = jax.tree.transpose(outer_def, inner_def, packed)

        self.exp_avg_states.assign_values(m_values)
        self.exp_avg_diff_states.assign_values(v_values)
        self.exp_avg_sq_states.assign_values(n_values)
        # Remember this gradient so the next step can form g_k - g_{k-1}.
        self.pre_grad_states.assign_values(grad_values)
        self.weight_states.assign_values(weight_values)
        self.lr.step_call()
874
+
875
+
876
class AdamW(_WeightDecayOptimizer):
    r"""
    Adam with weight decay regularization [1]_.

    AdamW uses weight decay to regularize learning towards small weights, as
    this leads to better generalization. In SGD you can also use L2 regularization
    to implement this as an additive loss term, however L2 regularization
    does not behave as intended for adaptive gradient algorithms such as Adam.

    .. math::

       \begin{aligned}
          &\rule{110mm}{0.4pt}                                                                 \\
          &\textbf{input}      : \gamma \text{(lr)}, \: \beta_1, \beta_2
              \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
              \: \epsilon \text{ (epsilon)}                                                    \\
          &\hspace{13mm}      \lambda \text{(weight decay)},  \: \textit{amsgrad},
              \: \textit{maximize}                                                             \\
          &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
              \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0              \\[-1.ex]
          &\rule{110mm}{0.4pt}                                                                 \\
          &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\

          &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
          &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})          \\
          &\hspace{5mm}\textbf{else}                                                           \\
          &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
          &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}         \\
          &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
          &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
          &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
          &\hspace{5mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                   \\
          &\hspace{5mm}\textbf{if} \: amsgrad                                                  \\
          &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
              \widehat{v_t})                                                                   \\
          &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
              \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big)                                 \\
          &\hspace{5mm}\textbf{else}                                                           \\
          &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
              \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
          &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
          &\bf{return} \:  \theta_t                                                     \\[-1.ex]
          &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}


    Parameters
    ----------
    lr: float, LearningRateScheduler
      learning rate.
    beta1: optional, float
      A positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates. Generally close to 1.
    beta2: optional, float
      A positive scalar value for beta_2, the exponential decay rate
      for the second moment estimates. Generally close to 1.
    eps: optional, float
      A positive scalar value for epsilon, a small constant for
      numerical stability.
    weight_decay: float
      Strength of the weight decay regularization. Note that this
      weight decay is multiplied with the learning rate.
    amsgrad: bool
      whether to use the AMSGrad variant of this algorithm
      from the paper `On the Convergence of Adam and Beyond`.
    name : optional, str
      The optimizer name.

    References
    ----------
    .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019).

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        amsgrad: bool = False,
        name: Optional[str] = None,
    ):
        super(AdamW, self).__init__(lr=lr,
                                    weight_decay=weight_decay,
                                    name=name)

        if eps < 0.:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= beta1 < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(beta1))
        if not 0.0 <= beta2 < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(beta2))
        if weight_decay < 0.:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        self.beta1 = fcast(beta1)
        self.beta2 = fcast(beta2)
        self.eps = fcast(eps)
        self.amsgrad = amsgrad
        # Per-weight first and second moment slots.
        self.m1_states = visible_state_dict()
        self.m2_states = visible_state_dict()
        if self.amsgrad:
            # Running maximum of the second moment; only exists for AMSGrad.
            self.vmax_states = visible_state_dict()

    def extra_repr(self) -> str:
        return (f", beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}"
                f", weight_decay={self.weight_decay}, amsgrad={self.amsgrad}")

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the weights to optimize and allocate zero-initialized slots.

        Parameters
        ----------
        train_states: dict of str to State, optional
          Weights keyed by name. ``None`` is treated as an empty dict.
        """
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(k, v)
            self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
            self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
            if self.amsgrad:
                self.vmax_states[k] = OptimState(math.tree_zeros_like(v.value))

    def update(self, grads: dict):
        """
        Apply one AdamW step to every registered weight.

        Parameters
        ----------
        grads: dict
          Gradients keyed like the registered weights.
        """
        lr_old = self.lr()
        step = self.lr.last_epoch.value + 2
        bias_correction1 = 1 - self.beta1 ** step
        bias_correction2 = 1 - self.beta2 ** step
        # Both bias corrections are folded into the step size.
        lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1

        def _adamw_update(p, m, v, g, vmax=None):
            if self.weight_decay != 0:
                # Decoupled decay uses the uncorrected learning rate.
                p = p * (1 - lr_old * self.weight_decay)
            m = self.beta1 * m + (1 - self.beta1) * g
            v = self.beta2 * v + (1 - self.beta2) * g ** 2
            if self.amsgrad:
                vmax = jnp.maximum(vmax, v)
                denom = jnp.sqrt(vmax) + self.eps
                return p - lr * m / denom, m, v, vmax
            # ``v`` is a plain array here, not a State, so it has no ``.value``.
            denom = jnp.sqrt(v) + self.eps
            return p - lr * m / denom, m, v

        # ``jax.tree.map`` yields ONE tree whose leaves are tuples; transpose
        # it into separate trees instead of unpacking the outer container.
        if self.amsgrad:
            weight_values, m1_values, m2_values, grad_values, vmax_values = to_same_dict_tree(
                self.weight_states, self.m1_states, self.m2_states, grads, self.vmax_states)
            outer_def = jax.tree.structure(weight_values)
            inner_def = jax.tree.structure((0, 0, 0, 0))
            packed = jax.tree.map(
                _adamw_update, weight_values, m1_values, m2_values, grad_values, vmax_values)
            weight_values, m1_values, m2_values, vmax_values = jax.tree.transpose(
                outer_def, inner_def, packed)
            self.vmax_states.assign_values(vmax_values)
        else:
            weight_values, m1_values, m2_values, grad_values = to_same_dict_tree(
                self.weight_states, self.m1_states, self.m2_states, grads)
            outer_def = jax.tree.structure(weight_values)
            inner_def = jax.tree.structure((0, 0, 0))
            packed = jax.tree.map(
                _adamw_update, weight_values, m1_values, m2_values, grad_values)
            weight_values, m1_values, m2_values = jax.tree.transpose(
                outer_def, inner_def, packed)
        self.weight_states.assign_values(weight_values)
        self.m1_states.assign_values(m1_values)
        self.m2_states.assign_values(m2_values)
        self.lr.step_call()
1031
+
1032
+
1033
class SM3(_WeightDecayOptimizer):
    """
    SM3 algorithm [1]_.

    The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method'
    (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar
    to Adam and Adagrad with greatly reduced memory usage for history tensors.
    For an `n x m` matrix, Adam and Adagrad use `O(nm)` memory for history
    tensors, while SM3 uses `O(n+m)` due to the chosen cover. In general, a tensor
    of shape `(n_1, n_2, ..., n_k)` optimized using Adam will use `O(prod n_i)`
    memory for storage tensors, while the optimization using SM3 will use
    `O(sum n_i)` memory. Despite storing fewer parameters, this optimization
    algorithm manages to be comparably effective.

    This advantage drastically shrinks when `momentum > 0`. The momentum is
    tracked using a tensor of the same shape as the tensor being optimized. With
    momentum, SM3 will use just over half as much memory as Adam, and a bit more
    than Adagrad.

    Parameters
    ----------
    lr: float, LearningRateScheduler
      learning rate.
    momentum: float
      coefficient used to scale prior updates
      before adding. This drastically increases memory usage if
      `momentum > 0.0`. (default: 0.0)
    beta: float
      coefficient used for exponential moving averages (default: 0.0)
    eps: float
      Term added to square-root in denominator to
      improve numerical stability (default: 1e-30).
    weight_decay: optional, float
      When not None, the weights are scaled by `1 - weight_decay`
      on every update (default: None).
    name : optional, str
      The optimizer name.

    References
    ----------
    .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019).

    """

    def __init__(
        self,
        lr: Union[float, LearningRateScheduler, State],
        beta: float = 0.,
        momentum: float = 0.,
        eps: float = 1e-30,
        weight_decay: Optional[float] = None,
        name: Optional[str] = None,
    ):
        super(SM3, self).__init__(lr=lr,
                                  weight_decay=weight_decay,
                                  name=name)

        if not 0.0 <= momentum < 1.0:
            raise ValueError("Invalid momentum: {0}".format(momentum))
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta: {0}".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid eps: {0}".format(eps))

        self.eps = fcast(eps)
        self.beta = fcast(beta)
        self.momentum = fcast(momentum)
        # Holds, per weight ``k``, the accumulators ``{k}_m{i}`` and, when
        # momentum is enabled, the buffer ``{k}_mbuffer``.
        self.memory_states = visible_state_dict()

    def extra_repr(self) -> str:
        return f", beta={self.beta}, momentum={self.momentum}, eps={self.eps}"

    def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
        """
        Register the weights to optimize and allocate their accumulators.

        For a weight of shape ``(n_0, ..., n_{d-1})`` one accumulator per axis
        is created, with that axis kept and all others of size one. A scalar
        weight gets a single scalar accumulator.

        Parameters
        ----------
        train_states: dict of str to State, optional
          Weights keyed by name. ``None`` is treated as an empty dict.
        """
        train_states = dict() if train_states is None else train_states
        assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
        for k, v in train_states.items():
            assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
            self.weight_states.add_unique_elem(k, v)
            full_shape, ndim, dtype = v.value.shape, v.value.ndim, v.value.dtype
            if ndim == 0:
                # ``update`` always reads ``{k}_m0``, so scalars need one too.
                self.memory_states[f'{k}_m0'] = State(jnp.zeros(full_shape, dtype=dtype))
            for i in range(ndim):
                shape = [1] * ndim
                shape[i] = full_shape[i]
                self.memory_states[f'{k}_m{i}'] = State(jnp.zeros(shape, dtype=dtype))
            if self.momentum > 0.:
                self.memory_states[f'{k}_mbuffer'] = State(jnp.zeros_like(v.value))

    def update(self, grads: dict):
        """
        Apply one SM3 step to every registered weight.

        Parameters
        ----------
        grads: dict
          Gradients keyed like the registered weights.
        """
        lr = self.lr()

        for k, p in self.weight_states.items():
            g = grads[k]
            ndim = p.value.ndim
            # Keep the State objects so that new values can be written back.
            accumulators = [self.memory_states[f'{k}_m{i}'] for i in range(max(ndim, 1))]

            # Broadcasting the per-axis accumulators against each other gives
            # an estimate with the full shape of the weight.
            update = accumulators[0].value
            for acc in accumulators[1:]:
                update = jnp.minimum(update, acc.value)
            if self.beta > 0.:
                update = update * self.beta
            # The squared gradient is added whether or not averaging is on.
            update = update + g * g * (1 - self.beta)

            for i, acc in enumerate(accumulators):
                # Computes max along all dimensions except the given dim.
                # If tensor is a scalar, it returns tensor.
                result = update
                for j in range(ndim):
                    if i != j:
                        result = jnp.max(result, axis=j, keepdims=True)
                if self.beta > 0.:
                    acc.value = jnp.maximum(acc.value, result)
                else:
                    # No need to compare - nu_max is bigger because of grad ** 2
                    acc.value = result

            update = g / jnp.sqrt(update + self.eps)
            if self.momentum > 0.:
                m_buffer = self.memory_states[f'{k}_mbuffer']
                update = update * (1. - self.momentum) + m_buffer.value * self.momentum
                m_buffer.value = update
            if self.weight_decay is None:
                p.value = p.value - lr * update
            else:
                p.value = (1 - self.weight_decay) * p.value - lr * update
        self.lr.step_call()