brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -14,1136 +14,1061 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
+ from __future__ import annotations
17
18
 
18
19
  import functools
19
20
  from typing import Union, Dict, Optional, Tuple, Any, TypeVar
20
21
 
21
- import brainunit as bu
22
+ import brainunit as u
22
23
  import jax
23
24
  import jax.numpy as jnp
24
25
 
26
+ from brainstate import environ
27
+ from brainstate._state import State, LongTermState, StateDictManager
28
+ from ._base import Optimizer
25
29
  from ._lr_scheduler import make_schedule, LearningRateScheduler
26
- from .. import environ
27
- from .._module import Module
28
- from .._state import State, LongTermState, StateDictManager, visible_state_dict
29
30
 
30
31
  __all__ = [
31
- 'to_same_dict_tree',
32
-
33
- # new class of brainstate.State for optimizer
34
- 'OptimState',
35
-
36
- # commonly used optimizers
37
- 'Optimizer',
38
- 'SGD',
39
- 'Momentum',
40
- 'MomentumNesterov',
41
- 'Adagrad',
42
- 'Adadelta',
43
- 'RMSProp',
44
- 'Adam',
45
- 'LARS',
46
- 'Adan',
47
- 'AdamW',
32
+ 'to_same_dict_tree',
33
+
34
+ # new class of brainstate.State for optimizer
35
+ 'OptimState',
36
+
37
+ # commonly used optimizers
38
+ 'SGDOptimizer',
39
+ 'SGD',
40
+ 'Momentum',
41
+ 'MomentumNesterov',
42
+ 'Adagrad',
43
+ 'Adadelta',
44
+ 'RMSProp',
45
+ 'Adam',
46
+ 'LARS',
47
+ 'Adan',
48
+ 'AdamW',
48
49
  ]
49
50
 
50
51
  T = TypeVar('T')
51
52
 
52
53
 
53
54
  def cast(value: Any, dtype: Any) -> jax.Array:
54
- if isinstance(value, jax.Array):
55
- return value.astype(dtype)
56
- return jnp.asarray(value, dtype=dtype)
55
+ if isinstance(value, jax.Array):
56
+ return value.astype(dtype)
57
+ return jnp.asarray(value, dtype=dtype)
57
58
 
58
59
 
59
60
  def fcast(value: T, dtype: Any = None) -> jax.Array:
60
- return cast(value, dtype=dtype or environ.dftype())
61
+ return cast(value, dtype=dtype or environ.dftype())
61
62
 
62
63
 
63
64
  def _to_dict_value(old_dict: Dict) -> Dict:
64
- new_dict = dict()
65
- for k, v in old_dict.items():
66
- if isinstance(v, State):
67
- new_dict[k] = v.value
68
- else:
69
- new_dict[k] = v
70
- return new_dict
65
+ new_dict = dict()
66
+ for k, v in old_dict.items():
67
+ if isinstance(v, State):
68
+ new_dict[k] = v.value
69
+ else:
70
+ new_dict[k] = v
71
+ return new_dict
71
72
 
72
73
 
73
74
  def to_same_dict_tree(*dicts: Dict):
74
- """
75
- Convert multiple dictionaries to the same tree structure.
76
-
77
- Parameters
78
- ----------
79
- *dicts: dict
80
- The dictionaries to be converted.
81
-
82
- Returns
83
- -------
84
- dict
85
- The converted dictionary.
86
- """
87
- if len(dicts):
88
- # all keys
89
- all_keys = tuple(set(d.keys()) for d in dicts)
90
- for keys in all_keys[1:]:
91
- if len(all_keys[0].difference(keys)) > 0:
92
- raise ValueError('Dictionary does not match.')
93
-
94
- # flatten to normal python dict
95
- r = [_to_dict_value(d) for d in dicts]
96
-
97
- if len(dicts) == 1:
98
- return r[0]
99
- else:
100
- return tuple(r)
75
+ """
76
+ Convert multiple dictionaries to the same tree structure.
77
+
78
+ Parameters
79
+ ----------
80
+ *dicts: dict
81
+ The dictionaries to be converted.
82
+
83
+ Returns
84
+ -------
85
+ dict
86
+ The converted dictionary.
87
+ """
88
+ if len(dicts):
89
+ # all keys
90
+ all_keys = tuple(set(d.keys()) for d in dicts)
91
+ for keys in all_keys[1:]:
92
+ if len(all_keys[0].difference(keys)) > 0:
93
+ raise ValueError('Dictionary does not match.')
94
+
95
+ # flatten to normal python dict
96
+ r = [_to_dict_value(d) for d in dicts]
97
+
98
+ if len(dicts) == 1:
99
+ return r[0]
100
+ else:
101
+ return tuple(r)
101
102
 
102
103
 
103
104
  def _sgd(prev_weight, gradient, weight_decay, lr=None):
104
- """
105
- The update function for SGD learning.
106
-
107
- Parameters
108
- ----------
109
- prev_weight: jax.Array
110
- The previous weight.
111
- gradient: jax.Array
112
- The gradient.
113
- weight_decay: float
114
- The weight decay.
115
- lr: float
116
- The learning rate.
117
- """
118
- if weight_decay is None:
119
- if lr is None:
120
- return prev_weight - gradient
121
- else:
122
- return prev_weight - lr * gradient
123
- else:
124
- if lr is None:
125
- return (1 - weight_decay) * prev_weight - gradient
105
+ """
106
+ The update function for SGD learning.
107
+
108
+ Parameters
109
+ ----------
110
+ prev_weight: jax.Array
111
+ The previous weight.
112
+ gradient: jax.Array
113
+ The gradient.
114
+ weight_decay: float
115
+ The weight decay.
116
+ lr: float
117
+ The learning rate.
118
+ """
119
+ if weight_decay is None:
120
+ if lr is None:
121
+ return prev_weight - gradient
122
+ else:
123
+ return prev_weight - lr * gradient
126
124
  else:
127
- return (1 - weight_decay) * prev_weight - lr * gradient
125
+ if lr is None:
126
+ return (1 - weight_decay) * prev_weight - gradient
127
+ else:
128
+ return (1 - weight_decay) * prev_weight - lr * gradient
128
129
 
129
130
 
130
131
  class OptimState(LongTermState):
131
- """
132
- The state for optimizer.
133
- """
134
- pass
135
-
132
+ """
133
+ The state for optimizer.
134
+ """
135
+ pass
136
136
 
137
- class Optimizer(Module):
138
- """Base Optimizer Class.
139
137
 
140
- Parameters
141
- ----------
142
- lr: float, LearningRateScheduler
143
- learning rate.
144
- """
145
-
146
- lr: LearningRateScheduler # learning rate
147
- weight_states: StateDictManager # states to train, invisible to ``.states()``
148
-
149
- def __init__(
150
- self,
151
- lr: Union[float, LearningRateScheduler, State],
152
- name: Optional[str] = None
153
- ):
154
- super().__init__(name=name)
155
- self.lr: LearningRateScheduler = make_schedule(lr)
156
- self.weight_states = StateDictManager()
157
-
158
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
159
- raise NotImplementedError
138
+ class SGDOptimizer(Optimizer):
139
+ """
140
+ Base Optimizer Class.
160
141
 
161
- def __repr__(self):
162
- return f"{self.__class__.__name__}(lr={self.lr}{self.extra_repr()})"
142
+ Parameters
143
+ ----------
144
+ lr: float, LearningRateScheduler
145
+ learning rate.
146
+ """
163
147
 
164
- def extra_repr(self) -> str:
165
- return ''
148
+ lr: LearningRateScheduler # learning rate
166
149
 
167
- def update(self, grads: dict):
168
- raise NotImplementedError
150
+ def __init__(
151
+ self, lr: Union[float, LearningRateScheduler, State],
152
+ ):
153
+ super().__init__()
154
+ self.lr: LearningRateScheduler = make_schedule(lr)
169
155
 
170
156
 
171
- class _WeightDecayOptimizer(Optimizer):
172
- def __init__(
173
- self,
174
- lr: Union[float, LearningRateScheduler, State],
175
- weight_decay: Optional[float] = None,
176
- name: Optional[str] = None
177
- ):
178
- super().__init__(lr=lr, name=name)
179
- self.lr: LearningRateScheduler = make_schedule(lr)
180
- assert weight_decay is None or 0. <= weight_decay <= 1., 'weight_decay must be in [0, 1].'
181
- self.weight_decay = (fcast(weight_decay) if weight_decay is not None else None)
182
-
183
- def extra_repr(self) -> str:
184
- return ''
185
-
186
- def __repr__(self):
187
- return f"{self.__class__.__name__}(lr={self.lr}, weight_decay={self.weight_decay}{self.extra_repr()})"
157
+ class _WeightDecayOptimizer(SGDOptimizer):
158
+ def __init__(
159
+ self,
160
+ lr: Union[float, LearningRateScheduler, State],
161
+ weight_decay: Optional[float] = None,
162
+ ):
163
+ super().__init__(lr=lr)
164
+ self.lr: LearningRateScheduler = make_schedule(lr)
165
+ assert weight_decay is None or 0. <= weight_decay <= 1., 'weight_decay must be in [0, 1].'
166
+ self.weight_decay = (fcast(weight_decay) if weight_decay is not None else None)
188
167
 
189
168
 
190
169
  class SGD(_WeightDecayOptimizer):
191
- r"""
192
- Stochastic gradient descent optimizer.
193
-
194
- SGD performs a parameter update for training examples :math:`x` and label
195
- :math:`y`:
196
-
197
- .. math::
170
+ r"""
171
+ Stochastic gradient descent optimizer.
198
172
 
199
- \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y)
173
+ SGD performs a parameter update for training examples :math:`x` and label
174
+ :math:`y`:
200
175
 
176
+ .. math::
201
177
 
202
- Parameters
203
- ----------
204
- lr: float, LearningRateScheduler
205
- learning rate.
206
-
207
- """
208
-
209
- def __init__(
210
- self,
211
- lr: Union[float, LearningRateScheduler, State],
212
- weight_decay: Optional[float] = None,
213
- name: Optional[str] = None
214
- ):
215
- super().__init__(lr=lr, weight_decay=weight_decay, name=name)
216
-
217
- def register_trainable_weights(self, states: Optional[Dict[str, State]] = None):
218
- states = dict() if states is None else states
219
- assert isinstance(states, dict), '"states" must be a dict of brainstate.State.'
220
- for k, v in states.items():
221
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
222
- self.weight_states.add_unique_elem(k, v)
223
-
224
- def update(self, grads: dict):
225
- lr = self.lr()
226
- weight_values, grad_values = to_same_dict_tree(self.weight_states, grads)
227
- updates = jax.tree.map(functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
228
- weight_values,
229
- grad_values)
230
- self.weight_states.assign_values(updates)
231
- self.lr.step_call()
178
+ \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y)
232
179
 
233
180
 
234
- class Momentum(_WeightDecayOptimizer):
235
- r"""
236
- Momentum optimizer.
181
+ Parameters
182
+ ----------
183
+ lr: float, LearningRateScheduler
184
+ learning rate.
237
185
 
238
- Momentum [1]_ is a method that helps accelerate SGD in the relevant direction
239
- and dampens oscillations. It does this by adding a fraction :math:`\gamma`
240
- of the update vector of the past time step to the current update vector:
186
+ """
241
187
 
242
- .. math::
188
+ def __init__(
189
+ self,
190
+ lr: Union[float, LearningRateScheduler, State],
191
+ weight_decay: Optional[float] = None,
192
+ ):
193
+ super().__init__(lr=lr, weight_decay=weight_decay)
243
194
 
244
- \begin{align}
245
- \begin{split}
246
- v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta) \\
247
- \theta &= \theta - v_t
248
- \end{split}
249
- \end{align}
195
+ def register_trainable_weights(self, states: Optional[Dict[str, State]] = None):
196
+ states = dict() if states is None else states
197
+ assert isinstance(states, dict), '"states" must be a dict of brainstate.State.'
198
+ for k, v in states.items():
199
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
200
+ self.param_states.add_unique_value(k, v)
250
201
 
251
- Parameters
252
- ----------
253
- lr: float, LearningRateScheduler
254
- learning rate.
202
+ def update(self, grads: dict):
203
+ lr = self.lr()
204
+ weight_values, grad_values = to_same_dict_tree(self.param_states, grads)
205
+ updates = jax.tree.map(functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
206
+ weight_values,
207
+ grad_values)
208
+ self.param_states.assign_values(updates)
209
+ self.lr.step_call()
255
210
 
256
- References
257
- ----------
258
-
259
- .. [1] Qian, N. (1999). On the momentum term in gradient descent learning
260
- algorithms. Neural Networks : The Official Journal of the International
261
- Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6
262
-
263
- """
264
-
265
- def __init__(
266
- self,
267
- lr: Union[float, LearningRateScheduler, State],
268
- momentum: float = 0.9,
269
- weight_decay: Optional[float] = None,
270
- name: Optional[str] = None
271
- ):
272
- super(Momentum, self).__init__(lr=lr, weight_decay=weight_decay, name=name)
273
- self.momentum = fcast(momentum)
274
- self.momentum_states = visible_state_dict()
275
-
276
- def extra_repr(self) -> str:
277
- return f", momentum={self.momentum}"
278
-
279
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
280
- train_states = dict() if train_states is None else train_states
281
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
282
-
283
- for k, v in train_states.items():
284
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
285
- self.weight_states.add_unique_elem(k, v)
286
- self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
287
-
288
- def update(self, grads: dict):
289
- lr = self.lr()
290
- states_values, grad_values, momentum_values = to_same_dict_tree(
291
- self.weight_states, grads, self.momentum_states
292
- )
293
- momentum_values = jax.tree.map(
294
- lambda vv, gg: self.momentum * vv - lr * gg,
295
- momentum_values,
296
- grad_values
297
- )
298
- new_weight_values = jax.tree.map(
299
- functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
300
- states_values,
301
- momentum_values
302
- )
303
- self.momentum_states.assign_values(momentum_values)
304
- self.weight_states.assign_values(new_weight_values)
305
- self.lr.step_call()
306
211
 
212
+ class Momentum(_WeightDecayOptimizer):
213
+ r"""
214
+ Momentum optimizer.
307
215
 
308
- class MomentumNesterov(_WeightDecayOptimizer):
309
- r"""
310
- Nesterov accelerated gradient optimizer [2]_.
216
+ Momentum [1]_ is a method that helps accelerate SGD in the relevant direction
217
+ and dampens oscillations. It does this by adding a fraction :math:`\gamma`
218
+ of the update vector of the past time step to the current update vector:
311
219
 
312
- .. math::
220
+ .. math::
313
221
 
314
222
  \begin{align}
315
223
  \begin{split}
316
- v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta - \gamma v_{t-1} ) \\
224
+ v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta) \\
317
225
  \theta &= \theta - v_t
318
226
  \end{split}
319
227
  \end{align}
320
228
 
321
- Parameters
322
- ----------
323
- lr: float, LearningRateScheduler
324
- learning rate.
325
-
326
- References
327
- ----------
328
- .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547.
329
-
330
- """
331
-
332
- def __init__(
333
- self,
334
- lr: Union[float, LearningRateScheduler, State],
335
- weight_decay: Optional[float] = None,
336
- momentum: float = 0.9,
337
- name: Optional[str] = None
338
- ):
339
- super(MomentumNesterov, self).__init__(lr=lr, weight_decay=weight_decay, name=name)
340
-
341
- self.momentum = fcast(momentum)
342
- self.momentum_states = visible_state_dict()
343
-
344
- def extra_repr(self) -> str:
345
- return f", momentum={self.momentum}"
346
-
347
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
348
- train_states = dict() if train_states is None else train_states
349
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
350
- for k, v in train_states.items():
351
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
352
- self.weight_states.add_unique_elem(k, v)
353
- self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
354
-
355
- def update(self, grads: dict):
356
- lr = self.lr()
357
- states_values, grad_values, momentum_values = to_same_dict_tree(self.weight_states, grads, self.momentum_states)
358
- momentum_values = jax.tree.map(lambda mv, gv: self.momentum * mv - lr * gv,
359
- momentum_values,
360
- grad_values)
361
- weight_values = jax.tree.map(functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
362
- states_values,
363
- momentum_values)
364
- self.weight_states.assign_values(weight_values)
365
- self.momentum_states.assign_values(momentum_values)
366
- self.lr.step_call()
229
+ Parameters
230
+ ----------
231
+ lr: float, LearningRateScheduler
232
+ learning rate.
233
+
234
+ References
235
+ ----------
236
+
237
+ .. [1] Qian, N. (1999). On the momentum term in gradient descent learning
238
+ algorithms. Neural Networks : The Official Journal of the International
239
+ Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6
240
+
241
+ """
242
+
243
+ def __init__(
244
+ self,
245
+ lr: Union[float, LearningRateScheduler, State],
246
+ momentum: float = 0.9,
247
+ weight_decay: Optional[float] = None,
248
+ ):
249
+ super().__init__(lr=lr, weight_decay=weight_decay)
250
+ self.momentum = fcast(momentum)
251
+ self.momentum_states = StateDictManager()
252
+
253
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
254
+ train_states = dict() if train_states is None else train_states
255
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
256
+
257
+ for k, v in train_states.items():
258
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
259
+ if self.param_states.add_unique_value(k, v):
260
+ self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))
261
+
262
+ def update(self, grads: dict):
263
+ lr = self.lr()
264
+ states_values, grad_values, momentum_values = to_same_dict_tree(
265
+ self.param_states, grads, self.momentum_states
266
+ )
267
+ momentum_values = jax.tree.map(
268
+ lambda vv, gg: self.momentum * vv - lr * gg,
269
+ momentum_values,
270
+ grad_values
271
+ )
272
+ new_weight_values = jax.tree.map(
273
+ functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
274
+ states_values,
275
+ momentum_values
276
+ )
277
+ self.momentum_states.assign_values(momentum_values)
278
+ self.param_states.assign_values(new_weight_values)
279
+ self.lr.step_call()
367
280
 
368
281
 
369
- class Adagrad(_WeightDecayOptimizer):
370
- r"""
371
- Optimizer that implements the Adagrad algorithm.
282
+ class MomentumNesterov(_WeightDecayOptimizer):
283
+ r"""
284
+ Nesterov accelerated gradient optimizer [2]_.
372
285
 
373
- Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are
374
- adapted relative to how frequently a parameter gets updated during training.
375
- The more updates a parameter receives, the smaller the updates.
286
+ .. math::
376
287
 
377
- .. math::
288
+ \begin{align}
289
+ \begin{split}
290
+ v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta - \gamma v_{t-1} ) \\
291
+ \theta &= \theta - v_t
292
+ \end{split}
293
+ \end{align}
294
+
295
+ Parameters
296
+ ----------
297
+ lr: float, LearningRateScheduler
298
+ learning rate.
299
+
300
+ References
301
+ ----------
302
+ .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547.
303
+
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ lr: Union[float, LearningRateScheduler, State],
309
+ weight_decay: Optional[float] = None,
310
+ momentum: float = 0.9,
311
+ ):
312
+ super().__init__(lr=lr, weight_decay=weight_decay)
313
+
314
+ self.momentum = fcast(momentum)
315
+ self.momentum_states = StateDictManager()
316
+
317
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
318
+ train_states = dict() if train_states is None else train_states
319
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
320
+ for k, v in train_states.items():
321
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
322
+ if self.param_states.add_unique_value(k, v):
323
+ self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))
324
+
325
+ def update(self, grads: dict):
326
+ lr = self.lr()
327
+ states_values, grad_values, momentum_values = to_same_dict_tree(self.param_states, grads, self.momentum_states)
328
+ momentum_values = jax.tree.map(lambda mv, gv: self.momentum * mv - lr * gv,
329
+ momentum_values,
330
+ grad_values)
331
+ weight_values = jax.tree.map(functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
332
+ states_values,
333
+ momentum_values)
334
+ self.param_states.assign_values(weight_values)
335
+ self.momentum_states.assign_values(momentum_values)
336
+ self.lr.step_call()
378
337
 
379
- \theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t}
380
338
 
381
- where :math:`G(t)` contains the sum of the squares of the past gradients
339
+ class Adagrad(_WeightDecayOptimizer):
340
+ r"""
341
+ Optimizer that implements the Adagrad algorithm.
382
342
 
383
- One of Adagrad's main benefits is that it eliminates the need to manually tune
384
- the learning rate. Most implementations use a default value of 0.01 and leave it at that.
385
- Adagrad's main weakness is its accumulation of the squared gradients in the denominator:
386
- Since every added term is positive, the accumulated sum keeps growing during training.
387
- This in turn causes the learning rate to shrink and eventually become infinitesimally
388
- small, at which point the algorithm is no longer able to acquire additional knowledge.
343
+ Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are
344
+ adapted relative to how frequently a parameter gets updated during training.
345
+ The more updates a parameter receives, the smaller the updates.
389
346
 
390
- Parameters
391
- ----------
392
- lr: float, LearningRateScheduler
393
- learning rate.
347
+ .. math::
394
348
 
395
- References
396
- ----------
397
- .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html
398
-
399
- """
400
-
401
- def __init__(
402
- self,
403
- lr: Union[float, LearningRateScheduler, State],
404
- weight_decay: Optional[float] = None,
405
- epsilon: float = 1e-6,
406
- name: Optional[str] = None
407
- ):
408
- super().__init__(lr=lr, weight_decay=weight_decay, name=name)
409
- self.epsilon = fcast(epsilon)
410
- self.cache_states = visible_state_dict()
411
-
412
- def extra_repr(self) -> str:
413
- return f", epsilon={self.epsilon}"
414
-
415
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
416
- train_states = dict() if train_states is None else train_states
417
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
418
- for k, v in train_states.items():
419
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
420
- self.weight_states.add_unique_elem(k, v)
421
- self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
422
-
423
- def update(self, grads: dict):
424
- lr = self.lr()
425
- cache_values, grad_values, weight_values = to_same_dict_tree(self.cache_states, grads, self.weight_states)
426
- cache_values = jax.tree.map(lambda cv, gv: cv + gv ** 2, cache_values, grad_values)
427
- updates = jax.tree.map(lambda cv, gv: lr * gv / jnp.sqrt(cv + self.epsilon), cache_values, grad_values)
428
- weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
429
- weight_values,
430
- updates)
431
- self.cache_states.assign_values(cache_values)
432
- self.weight_states.assign_values(weight_values)
433
- self.lr.step_call()
349
+ \theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t}
350
+
351
+ where :math:`G(t)` contains the sum of the squares of the past gradients
352
+
353
+ One of Adagrad's main benefits is that it eliminates the need to manually tune
354
+ the learning rate. Most implementations use a default value of 0.01 and leave it at that.
355
+ Adagrad's main weakness is its accumulation of the squared gradients in the denominator:
356
+ Since every added term is positive, the accumulated sum keeps growing during training.
357
+ This in turn causes the learning rate to shrink and eventually become infinitesimally
358
+ small, at which point the algorithm is no longer able to acquire additional knowledge.
359
+
360
+ Parameters
361
+ ----------
362
+ lr: float, LearningRateScheduler
363
+ learning rate.
364
+
365
+ References
366
+ ----------
367
+ .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html
368
+
369
+ """
370
+
371
+ def __init__(
372
+ self,
373
+ lr: Union[float, LearningRateScheduler, State],
374
+ weight_decay: Optional[float] = None,
375
+ epsilon: float = 1e-6,
376
+ ):
377
+ super().__init__(lr=lr, weight_decay=weight_decay)
378
+ self.epsilon = fcast(epsilon)
379
+ self.cache_states = StateDictManager()
380
+
381
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
382
+ train_states = dict() if train_states is None else train_states
383
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
384
+ for k, v in train_states.items():
385
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
386
+ if self.param_states.add_unique_value(k, v):
387
+ self.cache_states[k] = OptimState(u.math.tree_zeros_like(v.value))
388
+
389
+ def update(self, grads: dict):
390
+ lr = self.lr()
391
+ cache_values, grad_values, weight_values = to_same_dict_tree(self.cache_states, grads, self.param_states)
392
+ cache_values = jax.tree.map(lambda cv, gv: cv + gv ** 2, cache_values, grad_values)
393
+ updates = jax.tree.map(lambda cv, gv: lr * gv / jnp.sqrt(cv + self.epsilon), cache_values, grad_values)
394
+ weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
395
+ weight_values,
396
+ updates)
397
+ self.cache_states.assign_values(cache_values)
398
+ self.param_states.assign_values(weight_values)
399
+ self.lr.step_call()
434
400
 
435
401
 
436
402
  class Adadelta(_WeightDecayOptimizer):
437
- r"""
438
- Optimizer that implements the Adadelta algorithm.
439
-
440
- Adadelta [4]_ optimization is a stochastic gradient descent method that is based
441
- on adaptive learning rate per dimension to address two drawbacks:
442
-
443
- - The continual decay of learning rates throughout training.
444
- - The need for a manually selected global learning rate.
445
-
446
- Adadelta is a more robust extension of Adagrad that adapts learning rates based on
447
- a moving window of gradient updates, instead of accumulating all past gradients.
448
- This way, Adadelta continues learning even when many updates have been done. Compared
449
- to Adagrad, in the original version of Adadelta you don't have to set an initial
450
- learning rate.
403
+ r"""
404
+ Optimizer that implements the Adadelta algorithm.
451
405
 
452
- .. math::
406
+ Adadelta [4]_ optimization is a stochastic gradient descent method that is based
407
+ on adaptive learning rate per dimension to address two drawbacks:
453
408
 
454
- \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\
455
- \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\
456
- \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\
457
- \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t.
409
+ - The continual decay of learning rates throughout training.
410
+ - The need for a manually selected global learning rate.
458
411
 
459
- :math:`\rho` should be between 0 and 1. A value of rho close to 1 will decay the
460
- moving average slowly and a value close to 0 will decay the moving average fast.
461
-
462
- :math:`\rho` = 0.95 and :math:`\epsilon`=1e-6 are suggested in the paper and reported
463
- to work for multiple datasets (MNIST, speech).
464
-
465
- In the paper, no learning rate is considered (so learning_rate=1.0). Probably best to
466
- keep it at this value. epsilon is important for the very first update (so the
467
- numerator does not become 0).
468
-
469
- Parameters
470
- ----------
471
- lr: float, LearningRateScheduler
412
+ Adadelta is a more robust extension of Adagrad that adapts learning rates based on
413
+ a moving window of gradient updates, instead of accumulating all past gradients.
414
+ This way, Adadelta continues learning even when many updates have been done. Compared
415
+ to Adagrad, in the original version of Adadelta you don't have to set an initial
472
416
  learning rate.
473
417
 
474
- References
475
- ----------
476
- .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701
477
-
478
- """
479
-
480
- def __init__(
481
- self,
482
- lr: Union[float, LearningRateScheduler, State] = 0.01,
483
- weight_decay: Optional[float] = None,
484
- epsilon: float = 1e-6,
485
- rho: float = 0.95,
486
- name: Optional[str] = None
487
- ):
488
- super().__init__(lr=lr, weight_decay=weight_decay, name=name)
489
-
490
- self.epsilon = fcast(epsilon)
491
- self.rho = fcast(rho)
492
- self.cache_states = visible_state_dict()
493
- self.delta_states = visible_state_dict()
494
-
495
- def extra_repr(self) -> str:
496
- return f", epsilon={self.epsilon}, rho={self.rho}"
497
-
498
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
499
- train_states = dict() if train_states is None else train_states
500
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
501
- for k, v in train_states.items():
502
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
503
- self.weight_states.add_unique_elem(k, v)
504
- self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
505
- self.delta_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
506
-
507
- def update(self, grads: dict):
508
- weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
509
- self.weight_states, grads, self.cache_states, self.delta_states)
510
- cache_values = jax.tree.map(lambda cv, gv: self.rho * cv + (1 - self.rho) * gv ** 2, cache_values, grad_values)
511
- updates = jax.tree.map(lambda gv, dv, cv: gv * jnp.sqrt(dv + self.epsilon) / jnp.sqrt(cv + self.epsilon),
512
- grad_values, delta_values, cache_values)
513
- delta_values = jax.tree.map(lambda dv, upd: self.rho * dv + (1 - self.rho) * upd ** 2, delta_values, updates)
514
- weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
515
- weight_values,
516
- updates)
517
- self.weight_states.assign_values(weight_values)
518
- self.delta_states.assign_values(delta_values)
519
- self.cache_states.assign_values(cache_values)
520
- self.lr.step_call()
521
-
522
-
523
- class RMSProp(_WeightDecayOptimizer):
524
- r"""
525
- Optimizer that implements the RMSprop algorithm.
418
+ .. math::
526
419
 
527
- RMSprop [5]_ and Adadelta have both been developed independently around the same time
528
- stemming from the need to resolve Adagrad's radically diminishing learning rates.
420
+ \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\
421
+ \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\
422
+ \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\
423
+ \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t.
424
+
425
+ :math:`\rho` should be between 0 and 1. A value of rho close to 1 will decay the
426
+ moving average slowly and a value close to 0 will decay the moving average fast.
427
+
428
+ :math:`\rho` = 0.95 and :math:`\epsilon`=1e-6 are suggested in the paper and reported
429
+ to work for multiple datasets (MNIST, speech).
430
+
431
+ In the paper, no learning rate is considered (so learning_rate=1.0). Probably best to
432
+ keep it at this value. epsilon is important for the very first update (so the
433
+ numerator does not become 0).
434
+
435
+ Parameters
436
+ ----------
437
+ lr: float, LearningRateScheduler
438
+ learning rate.
439
+
440
+ References
441
+ ----------
442
+ .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701
443
+
444
+ """
445
+
446
+ def __init__(
447
+ self,
448
+ lr: Union[float, LearningRateScheduler, State] = 0.01,
449
+ weight_decay: Optional[float] = None,
450
+ epsilon: float = 1e-6,
451
+ rho: float = 0.95,
452
+ ):
453
+ super().__init__(lr=lr, weight_decay=weight_decay)
454
+
455
+ self.epsilon = fcast(epsilon)
456
+ self.rho = fcast(rho)
457
+ self.cache_states = StateDictManager()
458
+ self.delta_states = StateDictManager()
459
+
460
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
461
+ train_states = dict() if train_states is None else train_states
462
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
463
+ for k, v in train_states.items():
464
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
465
+ if self.param_states.add_unique_value(k, v):
466
+ self.cache_states[k] = OptimState(u.math.tree_zeros_like(v.value))
467
+ self.delta_states[k] = OptimState(u.math.tree_zeros_like(v.value))
468
+
469
+ def update(self, grads: dict):
470
+ weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
471
+ self.param_states, grads, self.cache_states, self.delta_states)
472
+ cache_values = jax.tree.map(lambda cv, gv: self.rho * cv + (1 - self.rho) * gv ** 2, cache_values, grad_values)
473
+ updates = jax.tree.map(lambda gv, dv, cv: gv * jnp.sqrt(dv + self.epsilon) / jnp.sqrt(cv + self.epsilon),
474
+ grad_values, delta_values, cache_values)
475
+ delta_values = jax.tree.map(lambda dv, upd: self.rho * dv + (1 - self.rho) * upd ** 2, delta_values, updates)
476
+ weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
477
+ weight_values,
478
+ updates)
479
+ self.param_states.assign_values(weight_values)
480
+ self.delta_states.assign_values(delta_values)
481
+ self.cache_states.assign_values(cache_values)
482
+ self.lr.step_call()
529
483
 
530
- The gist of RMSprop is to:
531
484
 
532
- - Maintain a moving (discounted) average of the square of gradients
533
- - Divide the gradient by the root of this average
485
+ class RMSProp(_WeightDecayOptimizer):
486
+ r"""
487
+ Optimizer that implements the RMSprop algorithm.
534
488
 
535
- .. math::
489
+ RMSprop [5]_ and Adadelta have both been developed independently around the same time
490
+ stemming from the need to resolve Adagrad's radically diminishing learning rates.
536
491
 
537
- \begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2\\
538
- p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split}
492
+ The gist of RMSprop is to:
539
493
 
540
- The centered version additionally maintains a moving average of the gradients,
541
- and uses that average to estimate the variance.
494
+ - Maintain a moving (discounted) average of the square of gradients
495
+ - Divide the gradient by the root of this average
542
496
 
543
- Parameters
544
- ----------
545
- lr: float, LearningRateScheduler
546
- learning rate.
497
+ .. math::
547
498
 
548
- References
549
- ----------
550
- .. [5] Tieleman, T. and Hinton, G. (2012):
551
- Neural Networks for Machine Learning, Lecture 6.5 - rmsprop.
552
- Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
553
- """
554
-
555
- def __init__(
556
- self,
557
- lr: Union[float, LearningRateScheduler, State],
558
- weight_decay: Optional[float] = None,
559
- epsilon: float = 1e-6,
560
- rho: float = 0.9,
561
- name: Optional[str] = None
562
- ):
563
- super(RMSProp, self).__init__(lr=lr, weight_decay=weight_decay, name=name)
564
-
565
- self.epsilon = fcast(epsilon)
566
- self.rho = fcast(rho)
567
- self.cache_states = visible_state_dict()
568
-
569
- def extra_repr(self) -> str:
570
- return f", epsilon={self.epsilon}, rho={self.rho}"
571
-
572
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
573
- train_states = dict() if train_states is None else train_states
574
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
575
- for k, v in train_states.items():
576
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
577
- self.weight_states.add_unique_elem(k, v)
578
- self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
579
-
580
- def update(self, grads: dict):
581
- lr = self.lr()
582
- weight_values, grad_values, cache_values = to_same_dict_tree(self.weight_states, grads, self.cache_states)
583
- cache_values = jax.tree.map(lambda cv, gv: self.rho * cv + (1 - self.rho) * gv ** 2, cache_values, grad_values)
584
- update = jax.tree.map(lambda gv, cv: lr * gv / jnp.sqrt(cv + self.epsilon), grad_values, cache_values)
585
- weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
586
- weight_values,
587
- update)
588
- self.weight_states.assign_values(weight_values)
589
- self.cache_states.assign_values(cache_values)
590
- self.lr.step_call()
499
+ \begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2\\
500
+ p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split}
501
+
502
+ The centered version additionally maintains a moving average of the gradients,
503
+ and uses that average to estimate the variance.
504
+
505
+ Parameters
506
+ ----------
507
+ lr: float, LearningRateScheduler
508
+ learning rate.
509
+
510
+ References
511
+ ----------
512
+ .. [5] Tieleman, T. and Hinton, G. (2012):
513
+ Neural Networks for Machine Learning, Lecture 6.5 - rmsprop.
514
+ Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20)
515
+ """
516
+
517
+ def __init__(
518
+ self,
519
+ lr: Union[float, LearningRateScheduler, State],
520
+ weight_decay: Optional[float] = None,
521
+ epsilon: float = 1e-6,
522
+ rho: float = 0.9,
523
+ ):
524
+ super().__init__(lr=lr, weight_decay=weight_decay)
525
+
526
+ self.epsilon = fcast(epsilon)
527
+ self.rho = fcast(rho)
528
+ self.cache_states = StateDictManager()
529
+
530
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
531
+ train_states = dict() if train_states is None else train_states
532
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
533
+ for k, v in train_states.items():
534
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
535
+ if self.param_states.add_unique_value(k, v):
536
+ self.cache_states[k] = OptimState(u.math.tree_zeros_like(v.value))
537
+
538
+ def update(self, grads: dict):
539
+ lr = self.lr()
540
+ weight_values, grad_values, cache_values = to_same_dict_tree(self.param_states, grads, self.cache_states)
541
+ cache_values = jax.tree.map(lambda cv, gv: self.rho * cv + (1 - self.rho) * gv ** 2, cache_values, grad_values)
542
+ update = jax.tree.map(lambda gv, cv: lr * gv / jnp.sqrt(cv + self.epsilon), grad_values, cache_values)
543
+ weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
544
+ weight_values,
545
+ update)
546
+ self.param_states.assign_values(weight_values)
547
+ self.cache_states.assign_values(cache_values)
548
+ self.lr.step_call()
591
549
 
592
550
 
593
551
  class Adam(_WeightDecayOptimizer):
594
- """
595
- Optimizer that implements the Adam algorithm.
596
-
597
- Adam [6]_ - a stochastic gradient descent method (SGD) that computes
598
- individual adaptive learning rates for different parameters from estimates of
599
- first- and second-order moments of the gradients.
600
-
601
- Parameters
602
- ----------
603
- lr: float, LearningRateScheduler
604
- learning rate.
605
- beta1: optional, float
606
- A positive scalar value for beta_1, the exponential decay rate
607
- for the first moment estimates (default 0.9).
608
- beta2: optional, float
609
- A positive scalar value for beta_2, the exponential decay rate
610
- for the second moment estimates (default 0.999).
611
- eps: optional, float
612
- A positive scalar value for epsilon, a small constant for
613
- numerical stability (default 1e-8).
614
- name : optional, str
615
- The optimizer name.
616
-
617
- References
618
- ----------
619
- .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
620
- """
621
-
622
- def __init__(
623
- self,
624
- lr: Union[float, State, LearningRateScheduler],
625
- beta1: float = 0.9,
626
- beta2: float = 0.999,
627
- eps: float = 1e-8,
628
- weight_decay: Optional[float] = None,
629
- name: Optional[str] = None
630
- ):
631
- super(Adam, self).__init__(lr=lr,
632
- weight_decay=weight_decay,
633
- name=name)
634
-
635
- self.beta1 = fcast(beta1)
636
- self.beta2 = fcast(beta2)
637
- self.eps = fcast(eps)
638
- self.m1_states = visible_state_dict()
639
- self.m2_states = visible_state_dict()
640
-
641
- def extra_repr(self) -> str:
642
- return f", beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}"
643
-
644
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
645
- train_states = dict() if train_states is None else train_states
646
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
647
-
648
- for k, v in train_states.items():
649
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
650
- self.weight_states.add_unique_elem(k, v)
651
- self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
652
- self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
653
-
654
- def update(self, grads: dict):
655
- lr = self.lr()
656
- lr = lr / (1 - self.beta1 ** (self.lr.last_epoch.value + 2))
657
- lr = lr * jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2))
658
- weight_values, grad_values, m1_values, m2_values = to_same_dict_tree(
659
- self.weight_states, grads, self.m1_states, self.m2_states)
660
- m1_values = jax.tree.map(lambda m1, gv: self.beta1 * m1 + (1 - self.beta1) * gv, m1_values, grad_values)
661
- m2_values = jax.tree.map(lambda m2, gv: self.beta2 * m2 + (1 - self.beta2) * gv ** 2, m2_values, grad_values)
662
- update = jax.tree.map(lambda m1, m2: lr * m1 / (jnp.sqrt(m2) + self.eps), m1_values, m2_values)
663
- weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
664
- weight_values,
665
- update)
666
- self.weight_states.assign_values(weight_values)
667
- self.m1_states.assign_values(m1_values)
668
- self.m2_states.assign_values(m2_values)
669
- self.lr.step_call()
552
+ """
553
+ Optimizer that implements the Adam algorithm.
554
+
555
+ Adam [6]_ - a stochastic gradient descent method (SGD) that computes
556
+ individual adaptive learning rates for different parameters from estimates of
557
+ first- and second-order moments of the gradients.
558
+
559
+ Parameters
560
+ ----------
561
+ lr: float, LearningRateScheduler
562
+ learning rate.
563
+ beta1: optional, float
564
+ A positive scalar value for beta_1, the exponential decay rate
565
+ for the first moment estimates (default 0.9).
566
+ beta2: optional, float
567
+ A positive scalar value for beta_2, the exponential decay rate
568
+ for the second moment estimates (default 0.999).
569
+ eps: optional, float
570
+ A positive scalar value for epsilon, a small constant for
571
+ numerical stability (default 1e-8).
572
+
573
+ References
574
+ ----------
575
+ .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
576
+ """
577
+
578
+ def __init__(
579
+ self,
580
+ lr: Union[float, State, LearningRateScheduler],
581
+ beta1: float = 0.9,
582
+ beta2: float = 0.999,
583
+ eps: float = 1e-8,
584
+ weight_decay: Optional[float] = None,
585
+ ):
586
+ super().__init__(lr=lr, weight_decay=weight_decay)
587
+
588
+ self.beta1 = fcast(beta1)
589
+ self.beta2 = fcast(beta2)
590
+ self.eps = fcast(eps)
591
+ self.m1_states = StateDictManager()
592
+ self.m2_states = StateDictManager()
593
+
594
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
595
+ train_states = dict() if train_states is None else train_states
596
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
597
+
598
+ for k, v in train_states.items():
599
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
600
+ if self.param_states.add_unique_value(k, v):
601
+ self.m1_states[k] = OptimState(u.math.tree_zeros_like(v.value))
602
+ self.m2_states[k] = OptimState(u.math.tree_zeros_like(v.value))
603
+
604
+ def update(self, grads: dict):
605
+ lr = self.lr()
606
+ lr = lr / (1 - self.beta1 ** (self.lr.last_epoch.value + 2))
607
+ lr = lr * jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2))
608
+ weight_values, grad_values, m1_values, m2_values = to_same_dict_tree(
609
+ self.param_states, grads, self.m1_states, self.m2_states)
610
+ m1_values = jax.tree.map(lambda m1, gv: self.beta1 * m1 + (1 - self.beta1) * gv, m1_values, grad_values)
611
+ m2_values = jax.tree.map(lambda m2, gv: self.beta2 * m2 + (1 - self.beta2) * gv ** 2, m2_values, grad_values)
612
+ update = jax.tree.map(lambda m1, m2: lr * m1 / (jnp.sqrt(m2) + self.eps), m1_values, m2_values)
613
+ weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
614
+ weight_values,
615
+ update)
616
+ self.param_states.assign_values(weight_values)
617
+ self.m1_states.assign_values(m1_values)
618
+ self.m2_states.assign_values(m2_values)
619
+ self.lr.step_call()
670
620
 
671
621
 
672
622
  class LARS(_WeightDecayOptimizer):
673
- r"""
674
- Layer-wise adaptive rate scaling (LARS) optimizer [1]_.
675
-
676
- Layer-wise Adaptive Rate Scaling, or LARS, is a large batch
677
- optimization technique. There are two notable differences
678
- between LARS and other adaptive algorithms such as `Adam` or `RMSProp`:
679
- first, LARS uses a separate learning rate for each layer and not for
680
- each weight. And second, the magnitude of the update is controlled
681
- with respect to the weight norm for better control of training speed.
623
+ r"""
624
+ Layer-wise adaptive rate scaling (LARS) optimizer [1]_.
682
625
 
683
- .. math::
626
+ Layer-wise Adaptive Rate Scaling, or LARS, is a large batch
627
+ optimization technique. There are two notable differences
628
+ between LARS and other adaptive algorithms such as `Adam` or `RMSProp`:
629
+ first, LARS uses a separate learning rate for each layer and not for
630
+ each weight. And second, the magnitude of the update is controlled
631
+ with respect to the weight norm for better control of training speed.
684
632
 
685
- m_{t} = \beta_{1}m_{t-1} + \left(1-\beta_{1}\right)\left(g_{t} + \lambda{x_{t}}\right) \\
686
- x_{t+1}^{\left(i\right)} = x_{t}^{\left(i\right)} - \eta_{t}\frac{\phi\left(|| x_{t}^{\left(i\right)} ||\right)}{|| m_{t}^{\left(i\right)} || }m_{t}^{\left(i\right)}
633
+ .. math::
687
634
 
688
- Parameters
689
- ----------
690
- lr: float, LearningRateScheduler
691
- learning rate.
692
- momentum: float
693
- coefficient used for the moving average of the gradient.
694
- weight_decay: float
695
- weight decay coefficient.
696
- tc: float
697
- trust coefficient eta ( < 1) for trust ratio computation.
698
- eps: float
699
- epsilon used for trust ratio computation.
700
-
701
- References
702
- ----------
703
- .. [1] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag.
704
- """
705
-
706
- def __init__(
707
- self,
708
- lr: Union[float, LearningRateScheduler, State],
709
- momentum: float = 0.9,
710
- weight_decay: float = 1e-4,
711
- tc: float = 1e-3,
712
- eps: float = 1e-5,
713
- name: Optional[str] = None
714
- ):
715
- super(LARS, self).__init__(lr=lr,
716
- weight_decay=weight_decay,
717
- name=name)
718
- assert self.weight_decay is None, 'LARS does not support weight decay.'
719
-
720
- self.momentum = fcast(momentum)
721
- self.tc = fcast(tc)
722
- self.eps = fcast(eps)
723
- self.momentum_states = visible_state_dict()
724
-
725
- def extra_repr(self) -> str:
726
- return f", momentum={self.momentum}, tc={self.tc}, eps={self.eps}"
727
-
728
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
729
- train_states = dict() if train_states is None else train_states
730
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
731
- for k, v in train_states.items():
732
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
733
- self.weight_states.add_unique_elem(k, v)
734
- self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
735
-
736
- def update(self, grads: dict):
737
- lr = self.lr()
738
- weight_values, grad_values, momentum_values = to_same_dict_tree(self.weight_states, grads, self.momentum_states)
739
-
740
- def _lars_update(pv, gv, mv):
741
- p_norm = jnp.linalg.norm(pv)
742
- g_norm = jnp.linalg.norm(gv)
743
- trust_ratio = self.tc * p_norm / (g_norm + self.weight_decay * p_norm + self.eps)
744
- local_lr = lr * jnp.maximum(jnp.logical_or(p_norm == 0, g_norm == 0), trust_ratio)
745
- mv = self.momentum * mv + local_lr * (gv + self.weight_decay * pv)
746
- return mv
747
-
748
- momentum_values = jax.tree.map(_lars_update, weight_values, grad_values, momentum_values)
749
- weight_values = jax.tree.map(lambda pv, mv: pv - mv, weight_values, momentum_values)
750
- self.weight_states.assign_values(weight_values)
751
- self.momentum_states.assign_values(momentum_values)
752
- self.lr.step_call()
635
+ m_{t} = \beta_{1}m_{t-1} + \left(1-\beta_{1}\right)\left(g_{t} + \lambda{x_{t}}\right) \\
636
+ x_{t+1}^{\left(i\right)} = x_{t}^{\left(i\right)} - \eta_{t}\frac{\phi\left(|| x_{t}^{\left(i\right)} ||\right)}{|| m_{t}^{\left(i\right)} || }m_{t}^{\left(i\right)}
637
+
638
+ Parameters
639
+ ----------
640
+ lr: float, LearningRateScheduler
641
+ learning rate.
642
+ momentum: float
643
+ coefficient used for the moving average of the gradient.
644
+ weight_decay: float
645
+ weight decay coefficient.
646
+ tc: float
647
+ trust coefficient eta ( < 1) for trust ratio computation.
648
+ eps: float
649
+ epsilon used for trust ratio computation.
650
+
651
+ References
652
+ ----------
653
+ .. [1] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag.
654
+ """
655
+
656
+ def __init__(
657
+ self,
658
+ lr: Union[float, LearningRateScheduler, State],
659
+ momentum: float = 0.9,
660
+ weight_decay: float = 1e-4,
661
+ tc: float = 1e-3,
662
+ eps: float = 1e-5,
663
+ ):
664
+ super().__init__(lr=lr, weight_decay=weight_decay)
665
+ assert self.weight_decay is None, 'LARS does not support weight decay.'
666
+
667
+ self.momentum = fcast(momentum)
668
+ self.tc = fcast(tc)
669
+ self.eps = fcast(eps)
670
+ self.momentum_states = StateDictManager()
671
+
672
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
673
+ train_states = dict() if train_states is None else train_states
674
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
675
+ for k, v in train_states.items():
676
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
677
+ if self.param_states.add_unique_value(k, v):
678
+ self.momentum_states[k] = OptimState(u.math.tree_zeros_like(v.value))
679
+
680
+ def update(self, grads: dict):
681
+ lr = self.lr()
682
+ weight_values, grad_values, momentum_values = to_same_dict_tree(self.param_states, grads, self.momentum_states)
683
+
684
+ def _lars_update(pv, gv, mv):
685
+ p_norm = jnp.linalg.norm(pv)
686
+ g_norm = jnp.linalg.norm(gv)
687
+ trust_ratio = self.tc * p_norm / (g_norm + self.weight_decay * p_norm + self.eps)
688
+ local_lr = lr * jnp.maximum(jnp.logical_or(p_norm == 0, g_norm == 0), trust_ratio)
689
+ mv = self.momentum * mv + local_lr * (gv + self.weight_decay * pv)
690
+ return mv
691
+
692
+ momentum_values = jax.tree.map(_lars_update, weight_values, grad_values, momentum_values)
693
+ weight_values = jax.tree.map(lambda pv, mv: pv - mv, weight_values, momentum_values)
694
+ self.param_states.assign_values(weight_values)
695
+ self.momentum_states.assign_values(momentum_values)
696
+ self.lr.step_call()
753
697
 
754
698
 
755
699
  class Adan(_WeightDecayOptimizer):
756
- r"""
757
- Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.
758
-
759
- .. math::
760
-
761
- \begin{equation}
762
- \begin{aligned}
763
- & \mathbf{m}_k=\left(1-\beta_1\right) \mathbf{m}_{k-1}+\beta_1 \mathbf{g}_k \\
764
- & \mathbf{v}_k=\left(1-\beta_2\right) \mathbf{v}_{k-1}+\beta_2\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right) \\
765
- & \mathbf{n}_k=\left(1-\beta_3\right) \mathbf{n}_{k-1}+\beta_3\left[\mathbf{g}_k+\left(1-\beta_2\right)\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right)\right]^2 \\
766
- & \boldsymbol{\eta}_k=\eta /\left(\sqrt{\mathbf{n}_k+\varepsilon}\right) \\
767
- & \boldsymbol{\theta}_{k+1}=\left(1+\lambda_k \eta\right)^{-1}\left[\boldsymbol{\theta}_k-\boldsymbol{\eta}_k \circ\left(\mathbf{m}_k+\left(1-\beta_2\right) \mathbf{v}_k\right)\right] \\
768
- \end{aligned}
769
- \end{equation}
770
-
771
- Parameters
772
- ----------
773
- lr: float, LearningRateScheduler
774
- learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3)
775
- betas : tuple
776
- Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01))
777
- eps : float
778
- The term added to the denominator to improve numerical stability. (default: 1e-8)
779
- weight_decay : float
780
- decoupled weight decay (L2 penalty) (default: 0)
781
- no_prox: bool
782
- how to perform the decoupled weight decay (default: False).
783
- It determines the update rule of parameters with weight decay.
784
- By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper:
700
+ r"""
701
+ Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.
785
702
 
786
703
  .. math::
787
- \boldsymbol{\theta}_{k+1} = ( 1+\lambda \eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}k)\right],
788
704
 
789
- But one also can update the parameter like Adamw:
790
-
791
- .. math::
792
- \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k).
793
-
794
- References
795
- ----------
796
- .. [1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan.
797
- “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
798
- Deep Models.” ArXiv abs/2208.06677 (2022): n. pag.
799
- """
800
-
801
- def __init__(
802
- self,
803
- lr: Union[float, LearningRateScheduler, State] = 1e-3,
804
- betas: Tuple[float, float, float] = (0.02, 0.08, 0.01),
805
- eps: float = 1e-8,
806
- weight_decay: float = 0.02,
807
- no_prox: bool = False,
808
- name: Optional[str] = None,
809
- ):
810
- super(Adan, self).__init__(lr=lr, weight_decay=weight_decay, name=name)
811
-
812
- assert len(betas) == 3
813
- if eps < 0.:
814
- raise ValueError("Invalid epsilon value: {}".format(eps))
815
- if not 0.0 <= betas[0] < 1.0:
816
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
817
- if not 0.0 <= betas[1] < 1.0:
818
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
819
- if not 0.0 <= betas[2] < 1.0:
820
- raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
821
-
822
- self.betas = fcast(jnp.asarray(betas))
823
- self.eps = fcast(eps)
824
- self.no_prox = no_prox
825
- self.exp_avg_states = visible_state_dict()
826
- self.exp_avg_sq_states = visible_state_dict()
827
- self.exp_avg_diff_states = visible_state_dict()
828
- self.pre_grad_states = visible_state_dict()
829
-
830
- def extra_repr(self) -> str:
831
- return f", betas={self.betas}, eps={self.eps}, weight_decay={self.weight_decay}, no_prox={self.no_prox}"
832
-
833
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
834
- train_states = dict() if train_states is None else train_states
835
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
836
- for k, v in train_states.items():
837
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
838
- self.weight_states.add_unique_elem(k, v)
839
- self.exp_avg_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
840
- self.exp_avg_sq_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
841
- self.exp_avg_diff_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
842
- self.pre_grad_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
843
-
844
- def update(self, grads: dict):
845
- lr = self.lr()
846
- step = self.lr.last_epoch.value + 1
847
- correct_m = 1 / (1 - (1 - self.betas[0]) ** (step + 1))
848
- correct_v = 1 / (1 - (1 - self.betas[1]) ** (step + 1))
849
- correct_n = 1 / (1 - (1 - self.betas[2]) ** (step + 1))
850
- m_values, n_values, v_values, pre_g_values, weight_values, grad_values = to_same_dict_tree(
851
- self.exp_avg_states, self.exp_avg_diff_states, self.exp_avg_sq_states, self.pre_grad_states,
852
- self.weight_states, grads)
853
-
854
- def _adan_update(m, n, v, pre_g, g, p):
855
- m = m * (1 - self.betas[0]) + self.betas[0] * g
856
- gd = g - pre_g
857
- v = v * (1 - self.betas[1]) + self.betas[1] * gd
858
- n = n * (1 - self.betas[2]) + self.betas[2] * (g + (1 - self.betas[1]) * gd) ** 2
859
- weighted_step_size = lr / (jnp.sqrt(n * correct_n) + self.eps)
860
- if self.no_prox:
861
- p = (p * (1 - self.weight_decay * lr) -
862
- weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v))
863
- else:
864
- p = ((p - weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v)) /
865
- (1 + self.weight_decay * lr))
866
- return m, n, v, p
867
-
868
- m_values, n_values, v_values, weight_values = jax.tree.map(
869
- _adan_update, m_values, n_values, v_values, pre_g_values, grad_values, weight_values)
870
- self.exp_avg_states.assign_values(m_values)
871
- self.exp_avg_diff_states.assign_values(n_values)
872
- self.exp_avg_sq_states.assign_values(v_values)
873
- self.weight_states.assign_values(weight_values)
874
- self.lr.step_call()
705
+ \begin{equation}
706
+ \begin{aligned}
707
+ & \mathbf{m}_k=\left(1-\beta_1\right) \mathbf{m}_{k-1}+\beta_1 \mathbf{g}_k \\
708
+ & \mathbf{v}_k=\left(1-\beta_2\right) \mathbf{v}_{k-1}+\beta_2\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right) \\
709
+ & \mathbf{n}_k=\left(1-\beta_3\right) \mathbf{n}_{k-1}+\beta_3\left[\mathbf{g}_k+\left(1-\beta_2\right)\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right)\right]^2 \\
710
+ & \boldsymbol{\eta}_k=\eta /\left(\sqrt{\mathbf{n}_k+\varepsilon}\right) \\
711
+ & \boldsymbol{\theta}_{k+1}=\left(1+\lambda_k \eta\right)^{-1}\left[\boldsymbol{\theta}_k-\boldsymbol{\eta}_k \circ\left(\mathbf{m}_k+\left(1-\beta_2\right) \mathbf{v}_k\right)\right] \\
712
+ \end{aligned}
713
+ \end{equation}
714
+
715
+ Parameters
716
+ ----------
717
+ lr: float, LearningRateScheduler
718
+ learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3)
719
+ betas : tuple
720
+ Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01))
721
+ eps : float
722
+ The term added to the denominator to improve numerical stability. (default: 1e-8)
723
+ weight_decay : float
724
+ decoupled weight decay (L2 penalty) (default: 0)
725
+ no_prox: bool
726
+ how to perform the decoupled weight decay (default: False).
727
+ It determines the update rule of parameters with weight decay.
728
+ By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper:
729
+
730
+ .. math::
731
+ \boldsymbol{\theta}_{k+1} = ( 1+\lambda \eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}k)\right],
732
+
733
+ But one also can update the parameter like Adamw:
734
+
735
+ .. math::
736
+ \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k).
737
+
738
+ References
739
+ ----------
740
+ .. [1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan.
741
+ “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing
742
+ Deep Models.” ArXiv abs/2208.06677 (2022): n. pag.
743
+ """
744
+
745
+ def __init__(
746
+ self,
747
+ lr: Union[float, LearningRateScheduler, State] = 1e-3,
748
+ betas: Tuple[float, float, float] = (0.02, 0.08, 0.01),
749
+ eps: float = 1e-8,
750
+ weight_decay: float = 0.02,
751
+ no_prox: bool = False,
752
+ ):
753
+ super().__init__(lr=lr, weight_decay=weight_decay)
754
+
755
+ assert len(betas) == 3
756
+ if eps < 0.:
757
+ raise ValueError("Invalid epsilon value: {}".format(eps))
758
+ if not 0.0 <= betas[0] < 1.0:
759
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
760
+ if not 0.0 <= betas[1] < 1.0:
761
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
762
+ if not 0.0 <= betas[2] < 1.0:
763
+ raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2]))
764
+
765
+ self.betas = fcast(jnp.asarray(betas))
766
+ self.eps = fcast(eps)
767
+ self.no_prox = no_prox
768
+ self.exp_avg_states = StateDictManager()
769
+ self.exp_avg_sq_states = StateDictManager()
770
+ self.exp_avg_diff_states = StateDictManager()
771
+ self.pre_grad_states = StateDictManager()
772
+
773
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
774
+ train_states = dict() if train_states is None else train_states
775
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
776
+ for k, v in train_states.items():
777
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
778
+ if self.param_states.add_unique_value(k, v):
779
+ self.exp_avg_states[k] = OptimState(u.math.tree_zeros_like(v.value))
780
+ self.exp_avg_sq_states[k] = OptimState(u.math.tree_zeros_like(v.value))
781
+ self.exp_avg_diff_states[k] = OptimState(u.math.tree_zeros_like(v.value))
782
+ self.pre_grad_states[k] = OptimState(u.math.tree_zeros_like(v.value))
783
+
784
+ def update(self, grads: dict):
785
+ lr = self.lr()
786
+ step = self.lr.last_epoch.value + 1
787
+ correct_m = 1 / (1 - (1 - self.betas[0]) ** (step + 1))
788
+ correct_v = 1 / (1 - (1 - self.betas[1]) ** (step + 1))
789
+ correct_n = 1 / (1 - (1 - self.betas[2]) ** (step + 1))
790
+ m_values, n_values, v_values, pre_g_values, weight_values, grad_values = to_same_dict_tree(
791
+ self.exp_avg_states, self.exp_avg_diff_states, self.exp_avg_sq_states, self.pre_grad_states,
792
+ self.param_states, grads)
793
+
794
+ def _adan_update(m, n, v, pre_g, g, p):
795
+ m = m * (1 - self.betas[0]) + self.betas[0] * g
796
+ gd = g - pre_g
797
+ v = v * (1 - self.betas[1]) + self.betas[1] * gd
798
+ n = n * (1 - self.betas[2]) + self.betas[2] * (g + (1 - self.betas[1]) * gd) ** 2
799
+ weighted_step_size = lr / (jnp.sqrt(n * correct_n) + self.eps)
800
+ if self.no_prox:
801
+ p = (p * (1 - self.weight_decay * lr) -
802
+ weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v))
803
+ else:
804
+ p = ((p - weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v)) /
805
+ (1 + self.weight_decay * lr))
806
+ return m, n, v, p
807
+
808
+ m_values, n_values, v_values, weight_values = jax.tree.map(
809
+ _adan_update, m_values, n_values, v_values, pre_g_values, grad_values, weight_values)
810
+ self.exp_avg_states.assign_values(m_values)
811
+ self.exp_avg_diff_states.assign_values(n_values)
812
+ self.exp_avg_sq_states.assign_values(v_values)
813
+ self.param_states.assign_values(weight_values)
814
+ self.lr.step_call()
875
815
 
876
816
 
877
817
  class AdamW(_WeightDecayOptimizer):
878
- r"""
879
- Adam with weight decay regularization [1]_.
880
-
881
- AdamW uses weight decay to regularize learning towards small weights, as
882
- this leads to better generalization. In SGD you can also use L2 regularization
883
- to implement this as an additive loss term, however L2 regularization
884
- does not behave as intended for adaptive gradient algorithms such as Adam.
885
-
886
- .. math::
887
-
888
- \begin{aligned}
889
- &\rule{110mm}{0.4pt} \\
890
- &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
891
- \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
892
- \: \epsilon \text{ (epsilon)} \\
893
- &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
894
- \: \textit{maximize} \\
895
- &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
896
- \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
897
- &\rule{110mm}{0.4pt} \\
898
- &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
899
-
900
- &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
901
- &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
902
- &\hspace{5mm}\textbf{else} \\
903
- &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
904
- &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
905
- &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
906
- &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
907
- &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
908
- &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
909
- &\hspace{5mm}\textbf{if} \: amsgrad \\
910
- &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
911
- \widehat{v_t}) \\
912
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
913
- \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
914
- &\hspace{5mm}\textbf{else} \\
915
- &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
916
- \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
917
- &\rule{110mm}{0.4pt} \\[-1.ex]
918
- &\bf{return} \: \theta_t \\[-1.ex]
919
- &\rule{110mm}{0.4pt} \\[-1.ex]
920
- \end{aligned}
921
-
922
-
923
- Parameters
924
- ----------
925
- lr: float, LearningRateScheduler
926
- learning rate.
927
- beta1: optional, float
928
- A positive scalar value for beta_1, the exponential decay rate
929
- for the first moment estimates. Generally close to 1.
930
- beta2: optional, float
931
- A positive scalar value for beta_2, the exponential decay rate
932
- for the second moment estimates. Generally close to 1.
933
- eps: optional, float
934
- A positive scalar value for epsilon, a small constant for
935
- numerical stability.
936
- weight_decay: float
937
- Strength of the weight decay regularization. Note that this
938
- weight decay is multiplied with the learning rate.
939
- amsgrad: bool
940
- whether to use the AMSGrad variant of this algorithm
941
- from the paper `On the Convergence of Adam and Beyond`.
942
- name : optional, str
943
- The optimizer name.
944
-
945
- References
946
- ----------
947
- .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019).
948
-
949
- """
950
-
951
- def __init__(
952
- self,
953
- lr: Union[float, LearningRateScheduler, State],
954
- beta1: float = 0.9,
955
- beta2: float = 0.999,
956
- eps: float = 1e-8,
957
- weight_decay: float = 1e-2,
958
- amsgrad: bool = False,
959
- name: Optional[str] = None,
960
- ):
961
- super(AdamW, self).__init__(lr=lr,
962
- weight_decay=weight_decay,
963
- name=name)
964
-
965
- if eps < 0.:
966
- raise ValueError("Invalid epsilon value: {}".format(eps))
967
- if not 0.0 <= beta1 < 1.0:
968
- raise ValueError("Invalid beta parameter at index 0: {}".format(beta1))
969
- if not 0.0 <= beta2 < 1.0:
970
- raise ValueError("Invalid beta parameter at index 1: {}".format(beta2))
971
- if weight_decay < 0.:
972
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
973
-
974
- self.beta1 = fcast(beta1)
975
- self.beta2 = fcast(beta2)
976
- self.eps = fcast(eps)
977
- self.amsgrad = amsgrad
978
- self.m1_states = visible_state_dict()
979
- self.m2_states = visible_state_dict()
980
- if self.amsgrad:
981
- self.vmax_states = visible_state_dict()
982
-
983
- def extra_repr(self) -> str:
984
- return (f", beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}"
985
- f", weight_decay={self.weight_decay}, amsgrad={self.amsgrad}")
986
-
987
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
988
- train_states = dict() if train_states is None else train_states
989
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
990
- for k, v in train_states.items():
991
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
992
- self.weight_states.add_unique_elem(k, v)
993
- self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
994
- self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
995
- if self.amsgrad:
996
- self.vmax_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
997
-
998
- def update(self, grads: dict):
999
- lr_old = self.lr()
1000
- step = self.lr.last_epoch.value + 2
1001
- bias_correction1 = 1 - self.beta1 ** step
1002
- bias_correction2 = 1 - self.beta2 ** step
1003
- lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1
1004
-
1005
- def _adamw_update(p, m, v, g, vmax=None):
1006
- if self.weight_decay != 0:
1007
- p *= (1 - lr_old * self.weight_decay)
1008
- m = self.beta1 * m + (1 - self.beta1) * g
1009
- v = self.beta2 * v + (1 - self.beta2) * g ** 2
1010
- if self.amsgrad:
1011
- vmax = jnp.maximum(vmax, v)
1012
- denom = jnp.sqrt(vmax) + self.eps
1013
- return p - lr * m / denom, m, v, vmax
1014
- else:
1015
- denom = jnp.sqrt(v.value) + self.eps
1016
- return p - lr * m / denom, m, v
1017
-
1018
- if self.amsgrad:
1019
- weight_values, m1_values, m2_values, vmax_values = to_same_dict_tree(
1020
- self.weight_states, self.m1_states, self.m2_states, self.vmax_states)
1021
- weight_values, m1_values, m2_values, vmax_values = jax.tree.map(
1022
- _adamw_update, weight_values, m1_values, m2_values, grads, vmax_values)
1023
- self.vmax_states.assign_values(vmax_values)
1024
- else:
1025
- weight_values, m1_values, m2_values = to_same_dict_tree(self.weight_states, self.m1_states, self.m2_states)
1026
- weight_values, m1_values, m2_values = jax.tree.map(
1027
- _adamw_update, weight_values, m1_values, m2_values, grads)
1028
- self.weight_states.assign_values(weight_values)
1029
- self.m1_states.assign_values(m1_values)
1030
- self.m2_states.assign_values(m2_values)
1031
- self.lr.step_call()
818
+ r"""
819
+ Adam with weight decay regularization [1]_.
1032
820
 
821
+ AdamW uses weight decay to regularize learning towards small weights, as
822
+ this leads to better generalization. In SGD you can also use L2 regularization
823
+ to implement this as an additive loss term, however L2 regularization
824
+ does not behave as intended for adaptive gradient algorithms such as Adam.
1033
825
 
1034
- class SM3(_WeightDecayOptimizer):
1035
- """
1036
- SM3 algorithm [1]_.
1037
-
1038
- The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method'
1039
- (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar
1040
- to Adam and Adagrad with greatly reduced memory usage for history tensors.
1041
- For an `n x m` matrix, Adam and Adagrad use `O(nm)` memory for history
1042
- tensors, while SM3 uses `O(n+m)` due to the chosen cover. In general, a tensor
1043
- of shape `(n_1, n_2, ..., n_k)` optimized using Adam will use `O(prod n_i)`
1044
- memory for storage tensors, while the optimization using SM3 will use
1045
- `O(sum n_i)` memory. Despite storing fewer parameters, this optimization
1046
- algorithm manages to be comparably effective.
1047
-
1048
- This advantage drastically shrinks when `momentum > 0`. The momentum is
1049
- tracked using a tensor of the same shape as the tensor being optimized. With
1050
- momentum, SM3 will use just over half as much memory as Adam, and a bit more
1051
- than Adagrad.
1052
-
1053
- Parameters
1054
- ----------
1055
- lr: float, LearningRateScheduler
1056
- learning rate.
1057
- momentum: float
1058
- coefficient used to scale prior updates
1059
- before adding. This drastically increases memory usage if
1060
- `momentum > 0.0`. (default: 0.0)
1061
- beta: float
1062
- coefficient used for exponential moving averages (default: 0.0)
1063
- eps: float
1064
- Term added to square-root in denominator to
1065
- improve numerical stability (default: 1e-30).
1066
-
1067
- References
1068
- ----------
1069
- .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019).
1070
-
1071
- """
1072
-
1073
- def __init__(
1074
- self,
1075
- lr: Union[float, LearningRateScheduler, State],
1076
- beta: float = 0.,
1077
- momentum: float = 0.,
1078
- eps: float = 1e-30,
1079
- weight_decay: Optional[float] = None,
1080
- name: Optional[str] = None,
1081
- ):
1082
- super(SM3, self).__init__(lr=lr,
1083
- weight_decay=weight_decay,
1084
- name=name)
1085
-
1086
- if not 0.0 <= momentum < 1.0:
1087
- raise ValueError("Invalid momentum: {0}".format(momentum))
1088
- if not 0.0 <= beta < 1.0:
1089
- raise ValueError("Invalid beta: {0}".format(beta))
1090
- if not 0.0 <= eps:
1091
- raise ValueError("Invalid eps: {0}".format(eps))
1092
-
1093
- self.eps = fcast(eps)
1094
- self.beta = fcast(beta)
1095
- self.momentum = fcast(momentum)
1096
- self.memory_states = visible_state_dict()
1097
-
1098
- def extra_repr(self) -> str:
1099
- return f", beta={self.beta}, momentum={self.momentum}, eps={self.eps}"
1100
-
1101
- def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
1102
- train_states = dict() if train_states is None else train_states
1103
- assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
1104
- for k, v in train_states.items():
1105
- assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
1106
- self.weight_states.add_unique_elem(k, v)
1107
- rank, ndim, dtype = v.value.shape, v.value.ndim, v.value.dtype
1108
- for i in range(ndim):
1109
- shape = [1] * ndim
1110
- shape[i] = rank[i]
1111
- self.memory_states[f'{k}_m{i}'] = State(jnp.zeros(shape, dtype=dtype))
1112
- if self.momentum > 0.:
1113
- self.memory_states[f'{k}_mbuffer'] = State(jnp.zeros_like(v.value))
1114
-
1115
- def update(self, grads: dict):
1116
- lr = self.lr()
1117
-
1118
- for k, p in self.weight_states.items():
1119
- g = grads[k]
1120
- ndim = p.ndim
1121
- update = self.memory_states[f'{k}_m0'].value
1122
- for i in range(1, ndim):
1123
- update = jnp.minimum(update, self.memory_states[f'{k}_m{i}'].value)
1124
- if self.beta > 0.:
1125
- update *= self.beta
1126
- update += g * g * (1 - self.beta)
1127
- # Computes max along all dimensions except the given dim.
1128
- # If tensor is a scalar, it returns tensor.
1129
- for i in range(ndim):
1130
- result = update
1131
- for j in range(ndim):
1132
- if i != j:
1133
- result = jnp.maximum(result, axis=j, keepdim=True)
1134
- acc = self.memory_states[f'{k}_m{i}'].value
1135
- if self.beta > 0.:
1136
- acc.value = jnp.maximum(acc, result)
826
+ .. math::
827
+
828
+ \begin{aligned}
829
+ &\rule{110mm}{0.4pt} \\
830
+ &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
831
+ \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
832
+ \: \epsilon \text{ (epsilon)} \\
833
+ &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
834
+ \: \textit{maximize} \\
835
+ &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
836
+ \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
837
+ &\rule{110mm}{0.4pt} \\
838
+ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
839
+
840
+ &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
841
+ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
842
+ &\hspace{5mm}\textbf{else} \\
843
+ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
844
+ &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
845
+ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
846
+ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
847
+ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
848
+ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
849
+ &\hspace{5mm}\textbf{if} \: amsgrad \\
850
+ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
851
+ \widehat{v_t}) \\
852
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
853
+ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
854
+ &\hspace{5mm}\textbf{else} \\
855
+ &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
856
+ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
857
+ &\rule{110mm}{0.4pt} \\[-1.ex]
858
+ &\bf{return} \: \theta_t \\[-1.ex]
859
+ &\rule{110mm}{0.4pt} \\[-1.ex]
860
+ \end{aligned}
861
+
862
+
863
+ Parameters
864
+ ----------
865
+ lr: float, LearningRateScheduler
866
+ learning rate.
867
+ beta1: optional, float
868
+ A positive scalar value for beta_1, the exponential decay rate
869
+ for the first moment estimates. Generally close to 1.
870
+ beta2: optional, float
871
+ A positive scalar value for beta_2, the exponential decay rate
872
+ for the second moment estimates. Generally close to 1.
873
+ eps: optional, float
874
+ A positive scalar value for epsilon, a small constant for
875
+ numerical stability.
876
+ weight_decay: float
877
+ Strength of the weight decay regularization. Note that this
878
+ weight decay is multiplied with the learning rate.
879
+ amsgrad: bool
880
+ whether to use the AMSGrad variant of this algorithm
881
+ from the paper `On the Convergence of Adam and Beyond`.
882
+
883
+ References
884
+ ----------
885
+ .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019).
886
+
887
+ """
888
+
889
+ def __init__(
890
+ self,
891
+ lr: Union[float, LearningRateScheduler, State],
892
+ beta1: float = 0.9,
893
+ beta2: float = 0.999,
894
+ eps: float = 1e-8,
895
+ weight_decay: float = 1e-2,
896
+ amsgrad: bool = False,
897
+ ):
898
+ super().__init__(lr=lr, weight_decay=weight_decay)
899
+
900
+ if eps < 0.:
901
+ raise ValueError("Invalid epsilon value: {}".format(eps))
902
+ if not 0.0 <= beta1 < 1.0:
903
+ raise ValueError("Invalid beta parameter at index 0: {}".format(beta1))
904
+ if not 0.0 <= beta2 < 1.0:
905
+ raise ValueError("Invalid beta parameter at index 1: {}".format(beta2))
906
+ if weight_decay < 0.:
907
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
908
+
909
+ self.beta1 = fcast(beta1)
910
+ self.beta2 = fcast(beta2)
911
+ self.eps = fcast(eps)
912
+ self.amsgrad = amsgrad
913
+ self.m1_states = StateDictManager()
914
+ self.m2_states = StateDictManager()
915
+ if self.amsgrad:
916
+ self.vmax_states = StateDictManager()
917
+
918
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
919
+ train_states = dict() if train_states is None else train_states
920
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
921
+ for k, v in train_states.items():
922
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
923
+ if self.param_states.add_unique_value(k, v):
924
+ self.m1_states[k] = OptimState(u.math.tree_zeros_like(v.value))
925
+ self.m2_states[k] = OptimState(u.math.tree_zeros_like(v.value))
926
+ if self.amsgrad:
927
+ self.vmax_states[k] = OptimState(u.math.tree_zeros_like(v.value))
928
+
929
+ def update(self, grads: dict):
930
+ lr_old = self.lr()
931
+ step = self.lr.last_epoch.value + 2
932
+ bias_correction1 = 1 - self.beta1 ** step
933
+ bias_correction2 = 1 - self.beta2 ** step
934
+ lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1
935
+
936
+ def _adamw_update(p, m, v, g, vmax=None):
937
+ if self.weight_decay != 0:
938
+ p *= (1 - lr_old * self.weight_decay)
939
+ m = self.beta1 * m + (1 - self.beta1) * g
940
+ v = self.beta2 * v + (1 - self.beta2) * g ** 2
941
+ if self.amsgrad:
942
+ vmax = jnp.maximum(vmax, v)
943
+ denom = jnp.sqrt(vmax) + self.eps
944
+ return p - lr * m / denom, m, v, vmax
945
+ else:
946
+ denom = jnp.sqrt(v.value) + self.eps
947
+ return p - lr * m / denom, m, v
948
+
949
+ if self.amsgrad:
950
+ weight_values, m1_values, m2_values, vmax_values = to_same_dict_tree(
951
+ self.param_states, self.m1_states, self.m2_states, self.vmax_states)
952
+ weight_values, m1_values, m2_values, vmax_values = jax.tree.map(
953
+ _adamw_update, weight_values, m1_values, m2_values, grads, vmax_values)
954
+ self.vmax_states.assign_values(vmax_values)
1137
955
  else:
1138
- # No need to compare - nu_max is bigger because of grad ** 2
1139
- acc.value = result
1140
- update = g / jnp.sqrt(update + self.eps)
1141
- if self.momentum > 0.:
1142
- m_buffer = self.memory_states[f'{k}_mbuffer'].value
1143
- update = update * (1. - self.momentum) + m_buffer * self.momentum
1144
- m_buffer.value = update
1145
- if self.weight_decay is None:
1146
- p.value -= lr * update
1147
- else:
1148
- p.value = (1 - self.weight_decay) * p - lr * update
1149
- self.lr.step_call()
956
+ weight_values, m1_values, m2_values = to_same_dict_tree(self.param_states, self.m1_states, self.m2_states)
957
+ weight_values, m1_values, m2_values = jax.tree.map(
958
+ _adamw_update, weight_values, m1_values, m2_values, grads)
959
+ self.param_states.assign_values(weight_values)
960
+ self.m1_states.assign_values(m1_values)
961
+ self.m2_states.assign_values(m2_values)
962
+ self.lr.step_call()
963
+
964
+
965
+ class SM3(_WeightDecayOptimizer):
966
+ """
967
+ SM3 algorithm [1]_.
968
+
969
+ The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method'
970
+ (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar
971
+ to Adam and Adagrad with greatly reduced memory usage for history tensors.
972
+ For an `n x m` matrix, Adam and Adagrad use `O(nm)` memory for history
973
+ tensors, while SM3 uses `O(n+m)` due to the chosen cover. In general, a tensor
974
+ of shape `(n_1, n_2, ..., n_k)` optimized using Adam will use `O(prod n_i)`
975
+ memory for storage tensors, while the optimization using SM3 will use
976
+ `O(sum n_i)` memory. Despite storing fewer parameters, this optimization
977
+ algorithm manages to be comparably effective.
978
+
979
+ This advantage drastically shrinks when `momentum > 0`. The momentum is
980
+ tracked using a tensor of the same shape as the tensor being optimized. With
981
+ momentum, SM3 will use just over half as much memory as Adam, and a bit more
982
+ than Adagrad.
983
+
984
+ Parameters
985
+ ----------
986
+ lr: float, LearningRateScheduler
987
+ learning rate.
988
+ momentum: float
989
+ coefficient used to scale prior updates
990
+ before adding. This drastically increases memory usage if
991
+ `momentum > 0.0`. (default: 0.0)
992
+ beta: float
993
+ coefficient used for exponential moving averages (default: 0.0)
994
+ eps: float
995
+ Term added to square-root in denominator to
996
+ improve numerical stability (default: 1e-30).
997
+
998
+ References
999
+ ----------
1000
+ .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019).
1001
+
1002
+ """
1003
+
1004
+ def __init__(
1005
+ self,
1006
+ lr: Union[float, LearningRateScheduler, State],
1007
+ beta: float = 0.,
1008
+ momentum: float = 0.,
1009
+ eps: float = 1e-30,
1010
+ weight_decay: Optional[float] = None,
1011
+ ):
1012
+ super().__init__(lr=lr, weight_decay=weight_decay)
1013
+
1014
+ if not 0.0 <= momentum < 1.0:
1015
+ raise ValueError("Invalid momentum: {0}".format(momentum))
1016
+ if not 0.0 <= beta < 1.0:
1017
+ raise ValueError("Invalid beta: {0}".format(beta))
1018
+ if not 0.0 <= eps:
1019
+ raise ValueError("Invalid eps: {0}".format(eps))
1020
+
1021
+ self.eps = fcast(eps)
1022
+ self.beta = fcast(beta)
1023
+ self.momentum = fcast(momentum)
1024
+ self.memory_states = StateDictManager()
1025
+
1026
+ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = None):
1027
+ train_states = dict() if train_states is None else train_states
1028
+ assert isinstance(train_states, dict), '"states" must be a dict of brainstate.State.'
1029
+ for k, v in train_states.items():
1030
+ assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
1031
+ if self.param_states.add_unique_value(k, v):
1032
+ rank, ndim, dtype = v.value.shape, v.value.ndim, v.value.dtype
1033
+ for i in range(ndim):
1034
+ shape = [1] * ndim
1035
+ shape[i] = rank[i]
1036
+ self.memory_states[f'{k}_m{i}'] = State(jnp.zeros(shape, dtype=dtype))
1037
+ if self.momentum > 0.:
1038
+ self.memory_states[f'{k}_mbuffer'] = State(jnp.zeros_like(v.value))
1039
+
1040
+ def update(self, grads: dict):
1041
+ lr = self.lr()
1042
+
1043
+ for k, p in self.param_states.items():
1044
+ g = grads[k]
1045
+ ndim = p.ndim
1046
+ update = self.memory_states[f'{k}_m0'].value
1047
+ for i in range(1, ndim):
1048
+ update = jnp.minimum(update, self.memory_states[f'{k}_m{i}'].value)
1049
+ if self.beta > 0.:
1050
+ update *= self.beta
1051
+ update += g * g * (1 - self.beta)
1052
+ # Computes max along all dimensions except the given dim.
1053
+ # If tensor is a scalar, it returns tensor.
1054
+ for i in range(ndim):
1055
+ result = update
1056
+ for j in range(ndim):
1057
+ if i != j:
1058
+ result = jnp.maximum(result, axis=j, keepdim=True)
1059
+ acc = self.memory_states[f'{k}_m{i}'].value
1060
+ if self.beta > 0.:
1061
+ acc.value = jnp.maximum(acc, result)
1062
+ else:
1063
+ # No need to compare - nu_max is bigger because of grad ** 2
1064
+ acc.value = result
1065
+ update = g / jnp.sqrt(update + self.eps)
1066
+ if self.momentum > 0.:
1067
+ m_buffer = self.memory_states[f'{k}_mbuffer'].value
1068
+ update = update * (1. - self.momentum) + m_buffer * self.momentum
1069
+ m_buffer.value = update
1070
+ if self.weight_decay is None:
1071
+ p.value -= lr * update
1072
+ else:
1073
+ p.value = (1 - self.weight_decay) * p - lr * update
1074
+ self.lr.step_call()