heavyball 1.4.0__py3-none-any.whl → 1.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/chainable.py CHANGED
@@ -477,7 +477,7 @@ class ChainOpt(utils.StatefulOptimizer):
477
477
  else:
478
478
  chain(self.state_, group, g, p, *self.fns)
479
479
 
480
- group['lr'] = None
480
+ group['lr'] = group['prev_lr']
481
481
  group['step'] = None
482
482
 
483
483
 
heavyball/utils.py CHANGED
@@ -195,10 +195,8 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
195
195
 
196
196
  @decorator_knowngood
197
197
  def _compilable_exp_avg_(state, grad, beta):
198
- for s, g in zip(state, grad):
199
- lerped = s.lerp(g, 1 - beta)
200
- copy_stochastic_(s, lerped)
201
- copy_stochastic_(g, lerped)
198
+ lerped = _lerp32(state, grad, beta)
199
+ copy_stochastic_list_(grad, lerped)
202
200
 
203
201
 
204
202
  def scale_by_exp_avg_(state, grad, beta):
@@ -746,6 +744,7 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
746
744
  def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
747
745
  ea32 = list(map(promote, state))
748
746
  grad = list(map(promote, grad))
747
+ beta = promote(beta)
749
748
 
750
749
  ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
751
750
  copy_stochastic_list_(state, ea32)
@@ -1032,26 +1031,19 @@ def psgd_balance_Q(Q_in):
1032
1031
 
1033
1032
 
1034
1033
  def psgd_calc_A_and_conjB(exprA, G, Q):
1035
- V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
1036
1034
  eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
1037
1035
  eps *= G.norm() / G.numel()
1038
- G = G + V * eps
1036
+ G = G + torch.randn_like(G) * eps
1039
1037
  md = min_dtype(Q + [G])
1040
1038
  A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
1041
1039
  order = G.dim()
1042
- p = list(range(order))
1043
- conjB = torch.permute(V, p[1:] + p[:1]).to(promote(G.dtype))
1040
+ conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
1044
1041
  Q = [promote(q) for q in Q]
1045
1042
  for i, q in enumerate(Q):
1046
1043
  if q.dim() <= 1:
1047
1044
  conjB /= q
1048
1045
  else:
1049
- unsqueeze = conjB.dim() <= 1
1050
- if unsqueeze:
1051
- conjB = conjB.unsqueeze(0)
1052
- conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False)
1053
- if unsqueeze:
1054
- conjB = conjB.squeeze(0)
1046
+ conjB = torch.linalg.solve_triangular(q, conjB.reshape(-1, q.size(0)), upper=True, left=False).reshape_as(conjB)
1055
1047
  if i < order - 1:
1056
1048
  conjB = torch.transpose(conjB, i, order - 1)
1057
1049
  return A, conjB
@@ -1092,7 +1084,7 @@ def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line):
1092
1084
  else:
1093
1085
  torch.triu(term1, out=term1)
1094
1086
  term1 /= torch.where(norm > 0, psgd_lb(term2, norm), norm).clamp_(tiny_bf16)
1095
- term1 = torch.mm(term1, q)
1087
+ term1 = torch.mm(term1, q.to(term1.dtype))
1096
1088
  if store_triu_as_line:
1097
1089
  term1 = triu_to_line([term1])[0][1]
1098
1090
  o = o[1]
heavyball-1.4.3.dist-info/METADATA ADDED
@@ -0,0 +1,934 @@
1
+ Metadata-Version: 2.1
2
+ Name: heavyball
3
+ Version: 1.4.3
4
+ Summary: Efficient optimizers
5
+ Home-page: https://github.com/clashluke/heavyball
6
+ Author: Lucas Nestler
7
+ Author-email: github.heavyball@nestler.sh
8
+ License: BSD
9
+ Classifier: Development Status :: 5 - Production/Stable
10
+ Classifier: License :: OSI Approved :: BSD License
11
+ Classifier: Programming Language :: Python
12
+ Classifier: Programming Language :: Python :: 3.7
13
+ Classifier: Programming Language :: Python :: 3.8
14
+ Classifier: Programming Language :: Python :: 3.9
15
+ Classifier: Topic :: Software Development :: Libraries
16
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
17
+ Classifier: Intended Audience :: Developers
18
+ Requires-Python: >=3.7
19
+ Description-Content-Type: text/markdown
20
+ License-File: LICENSE
21
+ Requires-Dist: opt-einsum
22
+ Requires-Dist: torch
23
+ Requires-Dist: numpy
24
+
25
+ # `heavyball`: Efficient Optimizers
26
+
27
+ * [Public API](#Public-API)
28
+ - [Foreach Optimizers](#Foreach-Optimizers)
29
+ - [`heavyball.utils`](#heavyball.utils)
30
+ - [Example Usage](#Example-Usage)
31
+
32
+ * [`heavyball.chainable`](#heavyball.chainable)
33
+ - [Core Concept](#Core-Concept)
34
+ - [`FunctionTransform` and Guards](#FunctionTransform-and-Guards)
35
+ - [Chaining Transformations](#Chaining-Transformations)
36
+ - [Building Optimizers](#Building-Optimizers)
37
+ - [Creating New Transformations](#Creating-New-Transformations)
38
+
39
+ * [Optimizer Recommendations](#Optimizer-Recommendations)
40
+ - [Choosing the Right Optimizer](#Choosing-the-Right-Optimizer)
41
+
42
+ ---
43
+
44
+ The `heavyball` library provides a collection of efficient optimizers designed for deep learning. It leverages
45
+ techniques like preconditioning, momentum, and adaptive learning rates to accelerate training and improve convergence.
46
+ The library's core strength lies in its `chainable` API, which allows for flexible composition of optimizers, enabling
47
+ users to build custom optimization strategies.
48
+
49
+ ## Public API
50
+
51
+ The `heavyball` library exposes the following optimizers through its main namespace:
52
+
53
+ ### Foreach Optimizers
54
+
55
+ These optimizers are designed to be efficient by operating on batches of parameters simultaneously using `foreach`
56
+ operations whenever possible.
57
+
58
+ #### `ForeachAdamW`
59
+
60
+ ```python
61
+ class ForeachAdamW(C.BaseOpt):
62
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
63
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
64
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
65
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
66
+ # ...
67
+ ```
68
+
69
+ A foreach implementation of the AdamW optimizer. It incorporates weight decay into the update rule and uses adaptive
70
+ learning rates based on the first and second moments of the gradients.
71
+
72
+ **Key Parameters:**
73
+
74
+ * **`lr`**: Learning rate.
75
+ * **`betas`**: Coefficients used for computing running averages of the gradient and its square.
76
+ * **`eps`**: A small constant for numerical stability.
77
+ * **`weight_decay`**: Weight decay coefficient.
78
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
79
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
80
+ * **`storage_dtype`**: The floating-point type to be used for internal state. `"float32"` or `"bfloat16"`.
81
+ * **`mars`**: Enables/disables Mars correction.
82
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
83
+ direction to the gradients.
84
+ * **`mars_gamma`**: Mars correction coefficient.
85
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
86
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
87
+ * **`palm`**: Enables/disables PaLM's beta2 schedule.
88
+ * **`beta2_scale`**: If the PaLM schedule is enabled, `beta2 = 1 - step ** -beta2_scale`.
89
+
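+ As an illustrative sketch (the keyword names are taken from the signature above; the values are placeholders, not
+ tuned recommendations), a typical construction with bfloat16 state and cautious updates might look like this:
+
+ ```python
+ import torch
+
+ import heavyball
+
+ model = torch.nn.Linear(16, 4)  # any torch.nn.Module works here
+
+ optimizer = heavyball.ForeachAdamW(
+     model.parameters(),
+     lr=2.5e-3,
+     betas=(0.9, 0.99),
+     weight_decay=1e-2,
+     warmup_steps=100,          # linear learning-rate warmup over the first 100 steps
+     storage_dtype='bfloat16',  # keep the optimizer state in bfloat16
+     caution=True,              # drop update components that oppose the gradient
+ )
+ ```
+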
90
+ #### `ForeachRMSprop`
91
+
92
+ ```python
93
+ class ForeachRMSprop(C.BaseOpt):
94
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
95
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
96
+ caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
97
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
98
+ # ...
99
+ ```
100
+
101
+ A foreach implementation of a debiased RMSprop optimizer (Note: this is different from `torch.optim.RMSprop`). It uses
102
+ adaptive learning rates based on the second moment of the gradients.
103
+
104
+ **Key Parameters:**
105
+
106
+ * **`lr`**: Learning rate.
107
+ * **`betas`**: Coefficients used for computing running averages of the squared gradient.
108
+ * **`eps`**: A small constant for numerical stability.
109
+ * **`weight_decay`**: Weight decay coefficient.
110
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
111
+ * **`r`**: Schedule-Free coefficient that controls dependence of the learning rate on step count.
112
+ * **`weight_lr_power`**: Schedule-Free coefficient that controls the sensitivity of `r` to the learning rate.
113
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
114
+ * **`storage_dtype`**: The floating-point type to be used for internal state. `"float32"` or `"bfloat16"`.
115
+ * **`mars`**: Enables/disables Mars correction.
116
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
117
+ direction to the gradients.
118
+ * **`mars_gamma`**: Mars correction coefficient.
119
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
120
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
121
+ * **`palm`**: Enables/disables PaLM's beta2 schedule.
122
+ * **`beta2_scale`**: If the PaLM schedule is enabled, `beta2 = 1 - step ** -beta2_scale`.
123
+
124
+ #### `ForeachSFAdamW`
125
+
126
+ ```python
127
+ class ForeachSFAdamW(C.ScheduleFree):
128
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
129
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
130
+ caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
131
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
132
+ # ...
133
+ ```
134
+
135
+ A foreach implementation of the Schedule-Free AdamW optimizer. It combines the benefits of AdamW with the Schedule-Free
136
+ approach, which dynamically adjusts the learning rate based on the current state of optimization.
137
+
138
+ **Key Parameters:**
139
+
140
+ * **`lr`**: Base learning rate. The effective learning rate at each step depends on `lr`, `r`, and `weight_lr_power`.
141
+ * **`betas`**: Coefficients used for computing running averages of the gradient and its square.
142
+ * **`eps`**: A small constant for numerical stability.
143
+ * **`weight_decay`**: Weight decay coefficient.
144
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
145
+ * **`r`**: Schedule-Free coefficient that controls dependence of the learning rate on step count.
146
+ * **`weight_lr_power`**: Schedule-Free coefficient that controls the sensitivity of `r` to the learning rate.
147
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
148
+ * **`storage_dtype`**: The floating-point type to be used for internal state. `"float32"` or `"bfloat16"`.
149
+ * **`mars`**: Enables/disables Mars correction.
150
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
151
+ direction to the gradients.
152
+ * **`mars_gamma`**: Mars correction coefficient.
153
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
154
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
155
+ * **`palm`**: Enables/disables PaLM's beta2 schedule.
156
+ * **`beta2_scale`**: If the PaLM schedule is enabled, `beta2 = 1 - step ** -beta2_scale`.
157
+
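+ Because the schedule is folded into the optimizer itself, it is normally used without an external learning-rate
+ scheduler. A minimal, illustrative construction (placeholder values):
+
+ ```python
+ import torch
+
+ import heavyball
+
+ model = torch.nn.Linear(16, 4)
+
+ # No torch.optim.lr_scheduler is attached; warmup is handled via warmup_steps.
+ optimizer = heavyball.ForeachSFAdamW(model.parameters(), lr=2.5e-3, betas=(0.9, 0.99),
+                                      weight_decay=1e-2, warmup_steps=100)
+ ```
+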
158
+ #### `PaLMForeachSFAdamW`
159
+
160
+ ```python
161
+ class PaLMForeachSFAdamW(ForeachSFAdamW):
162
+ palm: bool = True
163
+ ```
164
+
165
+ A specialized version of `ForeachSFAdamW` with PaLM's beta2 schedule enabled by default.
166
+
167
+ #### `ForeachADOPT`
168
+
169
+ ```python
170
+ class ForeachADOPT(C.BaseOpt):
171
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
172
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
173
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
174
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
175
+ # ...
176
+ ```
177
+
178
+ A foreach implementation of the ADOPT optimizer, which uses a debiased estimate of the second moment of the gradients.
179
+
180
+ **Key Parameters:**
181
+
182
+ * **`lr`**: Learning rate.
183
+ * **`betas`**: Coefficients used for computing running averages of the gradient and its square.
184
+ * **`eps`**: A small constant for numerical stability.
185
+ * **`weight_decay`**: Weight decay coefficient.
186
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
187
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
188
+ * **`storage_dtype`**: The floating-point type to be used for internal state. `"float32"` or `"bfloat16"`.
189
+ * **`mars`**: Enables/disables Mars correction.
190
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
191
+ direction to the gradients.
192
+ * **`mars_gamma`**: Mars correction coefficient.
193
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
194
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
195
+ * **`palm`**: Enables/disables PaLM's beta2 schedule.
196
+ * **`beta2_scale`**: If the PaLM schedule is enabled, `beta2 = 1 - step ** -beta2_scale`.
197
+
198
+ #### `ForeachMuon`
199
+
200
+ ```python
201
+ class ForeachMuon(C.BaseOpt):
202
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
203
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
204
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
205
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
206
+ nesterov: bool = True):
207
+ # ...
208
+ ```
209
+
210
+ A foreach implementation of the Muon optimizer, incorporating orthogonal updates via the `orthogonalize_update`
211
+ transformation.
212
+
213
+ **Key Parameters:**
214
+
215
+ * **`lr`**: Learning rate.
216
+ * **`betas`**: Coefficients used for computing running averages of the gradient and its square.
217
+ * **`eps`**: A small constant for numerical stability.
218
+ * **`weight_decay`**: Weight decay coefficient.
219
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
220
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
221
+ * **`storage_dtype`**: The floating-point type to be used for internal state. `"float32"` or `"bfloat16"`.
222
+ * **`mars`**: Enables/disables Mars correction.
223
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
224
+ direction to the gradients.
225
+ * **`mars_gamma`**: Mars correction coefficient.
226
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
227
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
228
+ * **`palm`**: Enables/disables PaLM's beta2 schedule.
229
+ * **`beta2_scale`**: If the PaLM schedule is enabled, `beta2 = 1 - step ** -beta2_scale`.
230
+ * **`nesterov`**: Enables/disables Nesterov momentum.
231
+
232
+ #### `ForeachLaProp`
233
+
234
+ ```python
235
+ class ForeachLaProp(C.BaseOpt):
236
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
237
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
238
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
239
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
240
+ # ...
241
+ ```
242
+
243
+ A foreach implementation of the LaProp optimizer.
244
+
245
+ **Key Parameters:**
246
+
247
+ * **`lr`**: Learning rate.
248
+ * **`betas`**: Coefficients used for computing running averages of the gradient and its square.
249
+ * **`eps`**: A small constant for numerical stability.
250
+ * **`weight_decay`**: Weight decay coefficient.
251
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
252
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
253
+ * **`storage_dtype`**: The floating-point type to be used for internal state. `"float32"` or `"bfloat16"`.
254
+ * **`mars`**: Enables/disables Mars correction.
255
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
256
+ direction to the gradients.
257
+ * **`mars_gamma`**: Mars correction coefficient.
258
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
259
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
260
+ * **`palm`**: Enables/disables PaLM's beta2 schedule.
261
+ * **`beta2_scale`**: If the PaLM schedule is enabled, `beta2 = 1 - step ** -beta2_scale`.
262
+
263
+ #### `MuonLaProp`
264
+
265
+ ```python
266
+ class MuonLaProp(C.BaseOpt):
267
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
268
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
269
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
270
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
271
+ # ...
272
+ ```
273
+
274
+ A variant of LaProp that incorporates orthogonal updates via the `orthogonalize_update` transformation.
275
+
276
+ **Key Parameters:**
277
+
278
+ * **`lr`**: Learning rate.
279
+ * **`betas`**: Coefficients used for computing running averages of the gradient and its square.
280
+ * **`eps`**: A small constant for numerical stability.
281
+ * **`weight_decay`**: Weight decay coefficient.
282
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
283
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
284
+ * **`storage_dtype`**: The floating-point type to be used for internal state. `"float32"` or `"bfloat16"`.
285
+ * **`mars`**: Enables/disables Mars correction.
286
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
287
+ direction to the gradients.
288
+ * **`mars_gamma`**: Mars correction coefficient.
289
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
290
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
291
+ * **`palm`**: Enables/disables PaLM's beta2 schedule.
292
+ * **`beta2_scale`**: If the PaLM schedule is enabled, `beta2 = 1 - step ** -beta2_scale`.
293
+
294
+ #### `ForeachSOAP`
295
+
296
+ ```python
297
+ class ForeachSOAP(C.BaseOpt):
298
+ use_precond_schedule: bool = False
299
+
300
+ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
301
+ weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
302
+ merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
303
+ data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
304
+ split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
305
+ mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
306
+ beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
307
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
308
+ # ...
309
+ ```
310
+
311
+ A foreach implementation of the SOAP (Second-Order Adaptive Preconditioner) optimizer. It uses a preconditioner based on
312
+ the second-order statistics of the gradients to accelerate convergence.
313
+
314
+ **Key Parameters:**
315
+
316
+ * **`lr`**: Learning rate.
317
+ * **`betas`**: Coefficients used for computing running averages of the gradient.
318
+ * **`shampoo_beta`**: Coefficient used for computing running average of the preconditioner.
319
+ * **`eps`**: A small constant for numerical stability.
320
+ * **`weight_decay`**: Weight decay coefficient.
321
+ * **`precondition_frequency`**: Frequency of preconditioner updates. If using `use_precond_schedule`, this parameter is
322
+ ignored.
323
+ * **`max_precond_dim`**: Maximum dimension of the preconditioner.
324
+ * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
325
+ * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
326
+ * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
327
+ * **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
328
+ * **`correct_bias`**: Enables/disables bias correction for the running averages.
329
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
330
+ * **`split`**: Whether to split large dimensions when forming the preconditioner.
331
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
332
+ * **`mars`**: Enables/disables Mars correction.
333
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
334
+ direction to the gradients.
335
+ * **`mars_gamma`**: Mars correction coefficient.
336
+ * **`palm`**: Enables/disables PaLM's beta2 schedule.
337
+ * **`precond_scheduler`**: A tuple `(power, log_base)` specifying the preconditioner update schedule, where the update
338
+ probability is `1 / (step ** power * log_base)`. This parameter is only used if `use_precond_schedule` is `True`.
339
+ * **`beta2_scale`**: If the PaLM schedule is enabled, `beta2 = 1 - step ** -beta2_scale`.
340
+ * **`use_precond_schedule`**: Whether to use a dynamic preconditioner update schedule instead of a fixed frequency.
341
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
342
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
343
+
344
+ #### `PaLMForeachSOAP`
345
+
346
+ ```python
347
+ class PaLMForeachSOAP(ForeachSOAP):
348
+ use_precond_schedule: bool = False
349
+ palm: bool = True
350
+ ```
351
+
352
+ A specialized version of `ForeachSOAP` with PaLM's beta2 schedule enabled by default.
353
+
354
+ #### `PrecondScheduleForeachSOAP`
355
+
356
+ ```python
357
+ class PrecondScheduleForeachSOAP(ForeachSOAP):
358
+ use_precond_schedule: bool = True
359
+ ```
360
+
361
+ A specialized version of `ForeachSOAP` that uses a dynamic preconditioner update schedule.
362
+
363
+ #### `PrecondSchedulePaLMForeachSOAP`
364
+
365
+ ```python
366
+ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
367
+ use_precond_schedule: bool = True
368
+ palm: bool = True
369
+ ```
370
+
371
+ A specialized version of `ForeachSOAP` with both PaLM-specific modifications and a dynamic preconditioner update
372
+ schedule enabled by default.
373
+
374
+ #### `ForeachPSGDKron`
375
+
376
+ ```python
377
+ class ForeachPSGDKron(C.BaseOpt):
378
+ delayed: bool = False
379
+ cached: bool = False
380
+ exp_avg_input: bool = True
381
+
382
+ def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
383
+ max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
384
+ momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
385
+ split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
386
+ stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
387
+ caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
388
+ cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
389
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
390
+ # expert parameters
391
+ precond_init_scale=1.0, precond_lr=0.1):
392
+ # ...
393
+ ```
394
+
395
+ A foreach implementation of the PSGD (Preconditioned Stochastic Gradient Descent) optimizer with Kronecker-factored
396
+ preconditioners.
397
+
398
+ **Key Parameters:**
399
+
400
+ * **`lr`**: Learning rate.
401
+ * **`beta`**: Coefficient used for computing running average of the gradient.
402
+ * **`weight_decay`**: Weight decay coefficient.
403
+ * **`preconditioner_update_probability`**: Probability of updating the preconditioner at each step. If `None`, a default
404
+ schedule is used.
405
+ * **`max_size_triangular`**: Maximum size of triangular matrices used in the preconditioner.
406
+ * **`min_ndim_triangular`**: Minimum number of dimensions for a tensor to be considered for triangular preconditioner.
407
+ * **`memory_save_mode`**: Memory saving mode for the preconditioner. Can be `None`, `"one_diag"`, or `"all_diag"`.
408
+ * **`momentum_into_precond_update`**: Whether to use momentum in the preconditioner update.
409
+ * **`warmup_steps`**: Number of steps for linear learning rate warmup.
410
+ * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
411
+ * **`split`**: Whether to split large dimensions when forming the preconditioner.
412
+ * **`store_triu_as_line`**: Whether to store the upper triangular part of the preconditioner as a 1D vector.
413
+ * **`foreach`**: Enables/disables the use of `foreach` operations.
414
+ * **`q_dtype`**: The floating-point type to be used for the preconditioner. `"float32"` or `"bfloat16"`.
415
+ * **`stochastic_schedule`**: Whether to use a stochastic schedule for updating the preconditioner.
416
+ * **`storage_dtype`**: The floating-point type to be used for internal state. `"float32"` or `"bfloat16"`.
417
+ * **`mars`**: Enables/disables Mars correction.
418
+ * **`caution`**: Enables/disables the use of a cautious update rule, avoiding updates that point in the opposite
419
+ direction to the gradients.
420
+ * **`mars_gamma`**: Mars correction coefficient.
421
+ * **`delayed`**: Enables/disables delayed preconditioner updates.
422
+ * **`cached`**: Enables/disables caching of preconditioner-related computations.
423
+ * **`exp_avg_input`**: Whether to apply `exp_avg` to the input before calculating the preconditioner.
424
+ * **`gradient_clipping`**: Gradient clipping function or method. See `heavyball.utils` for available options.
425
+ * **`update_clipping`**: Update clipping function or method. See `heavyball.utils` for available options.
426
+ * **`precond_init_scale`**: Initial scale of the preconditioner.
427
+ * **`precond_lr`**: Learning rate for preconditioner updates.
428
+
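+ For illustration only (placeholder values, not tuned recommendations), a Kronecker-factored PSGD instance with a
+ fixed preconditioner update probability might be configured as follows:
+
+ ```python
+ import torch
+
+ import heavyball
+
+ model = torch.nn.Linear(16, 4)
+
+ optimizer = heavyball.ForeachPSGDKron(
+     model.parameters(),
+     lr=1e-3,
+     beta=0.9,
+     weight_decay=0.0,
+     preconditioner_update_probability=0.1,  # update the preconditioner on roughly 10% of steps
+     max_size_triangular=2048,               # cap on the size of triangular preconditioner factors
+     store_triu_as_line=True,                # store triangular factors as flat vectors to save memory
+     q_dtype='float32',
+ )
+ ```
+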
429
+ #### `ForeachPurePSGD`
430
+
431
+ ```python
432
+ class ForeachPurePSGD(ForeachPSGDKron):
433
+ exp_avg_input: bool = False
434
+ ```
435
+
436
+ A specialized version of `ForeachPSGDKron` that does not apply `exp_avg` to the input before calculating the
437
+ preconditioner.
438
+
439
+ #### `ForeachCachedDelayedPSGDKron`
440
+
441
+ ```python
442
+ class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
443
+ delayed: bool = True
444
+ cached: bool = True
445
+ ```
446
+
447
+ A specialized version of `ForeachPSGDKron` with both delayed preconditioner updates and caching enabled by default.
448
+
449
+ #### `ForeachCachedPSGDKron`
450
+
451
+ ```python
452
+ class ForeachCachedPSGDKron(ForeachPSGDKron):
453
+ cached: bool = True
454
+ ```
455
+
456
+ A specialized version of `ForeachPSGDKron` with caching enabled by default.
457
+
458
+ #### `ForeachDelayedPSGD`
459
+
460
+ ```python
461
+ class ForeachDelayedPSGD(ForeachPSGDKron):
462
+ delayed: bool = True
463
+ ```
464
+
465
+ A specialized version of `ForeachPSGDKron` with delayed preconditioner updates enabled by default.
466
+
467
+ ## `heavyball.utils`
468
+
469
+ The `heavyball.utils` module provides several important functions and settings that users may find useful:
470
+
471
+ ### Settings
472
+
473
+ * **`compile_mode`**: (defaults to `"max-autotune-no-cudagraphs"`) Controls the compilation mode used by
474
+ `torch.compile`. Setting this to `"default"` or `"max-autotune-no-cudagraphs"` improves performance at the cost of
475
+   increased compile time. Setting it to `None` disables compilation.
476
+ * **`dynamic`**: (defaults to `False`) Enables/disables dynamic shapes during compilation. Enabling this reduces
477
+ compilation time but may lead to slower execution.
478
+ * **`zeroth_power_mode`**: (defaults to `"qr"`) Controls the method used for computing the zeroth power of a matrix
479
+   (orthogonalization) in certain preconditioners. Options include:
480
+ * `"qr"`: Uses QR decomposition.
481
+ * `"svd"`: Uses singular value decomposition.
482
+ * `"newtonschulz"`: Uses Newton-Schulz iteration.
483
+
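+ These settings are module-level attributes of `heavyball.utils`, so they can be overridden before training. A minimal
+ sketch (assuming the values are read when the optimizer is later used):
+
+ ```python
+ import heavyball.utils
+
+ heavyball.utils.compile_mode = 'default'            # or None to disable torch.compile
+ heavyball.utils.dynamic = True                      # allow dynamic shapes during compilation
+ heavyball.utils.zeroth_power_mode = 'newtonschulz'  # Newton-Schulz instead of QR for orthogonalization
+ ```
+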
484
+ ### Gradient/Update Clipping
485
+
486
+ The following functions are used for gradient and update clipping. They can be passed to the `gradient_clipping` or
487
+ `update_clipping` arguments of the optimizers:
488
+
489
+ * **`l2_clip_`**: Clips the gradient/update by its L2 norm.
490
+ * **`rmsnorm_clip_`**: Clips the gradient/update by its RMS norm.
491
+ * **`trust_region_clip_`**: Clips the gradient/update using a trust region method.
492
+ * **`mu_law_compress`**: Compresses the gradient/update using the µ-law algorithm.
493
+ * **`a_law_compress`**: Compresses the gradient/update using the A-law algorithm.
494
+ * **`identity`**: Does not modify the gradient/update (no clipping).
495
+
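+ For example (an illustrative sketch; per the `C.str_or_fn` annotations, the arguments accept one of these functions
+ directly), gradients can be RMS-clipped while the final update is kept inside a trust region:
+
+ ```python
+ import torch
+
+ import heavyball
+ from heavyball import utils
+
+ model = torch.nn.Linear(16, 4)
+
+ optimizer = heavyball.ForeachAdamW(model.parameters(), lr=2.5e-3,
+                                    gradient_clipping=utils.rmsnorm_clip_,
+                                    update_clipping=utils.trust_region_clip_)
+ ```
+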
496
+ ### Other Utilities
497
+
498
+ * **`set_torch`**: Sets recommended PyTorch settings for performance, including enabling cuDNN benchmark mode, disabling
499
+   deterministic algorithms, setting the precision of float32 matrix multiplications, and enabling opt-einsum with the
500
+   "auto-hq" strategy.
501
+ * **`clean`**: Clears the CUDA cache.
502
+ * **`hook_optimizer_into_model`**: Hooks an optimizer into a model's `post_accumulate_grad_hook`.
503
+ * **`fused_hook`**: Hooks an optimizer into a model's `post_accumulate_grad_hook`, fusing multiple parameter updates
504
+ into a single step.
505
+ * **`disable_caution_scaling`**: Disables the scaling factor applied when `caution` is enabled in optimizers.
506
+
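+ For example, `set_torch` is typically called once, before the model and optimizer are created:
+
+ ```python
+ import heavyball.utils
+
+ heavyball.utils.set_torch()  # apply the recommended global PyTorch settings
+ ```
+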
507
+ ## Example Usage
508
+
509
+ ```python
510
+ import torch
511
+ from torch import nn
512
+ import heavyball
513
+
514
+ # Define a simple model
515
+ model = nn.Linear(10, 2)
516
+
517
+ # Create an optimizer
518
+ optimizer = heavyball.ForeachAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
519
+ # alternative:
520
+ optimizer = heavyball.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
521
+
522
+ # Generate some dummy data
523
+ input = torch.randn(1, 10)
524
+ target = torch.randn(1, 2)
525
+
526
+ # Training loop
527
+ for _ in range(100):
528
+ # Forward pass
529
+ output = model(input)
530
+ loss = (output - target).sum()
531
+
532
+ # Backward pass
533
+ loss.backward()
534
+
535
+ # Optimizer step
536
+ optimizer.step()
537
+
538
+ # optional: zero gradients; optimizer.step() already does this, which is different from torch.optim
539
+ optimizer.zero_grad()
540
+ ```
541
+
542
+ This example demonstrates how to create an `AdamW` optimizer and use it to train a simple linear model. You can easily
543
+ replace `AdamW` with any other optimizer from the `heavyball` library and customize its behavior using the various
544
+ available parameters and settings.
545
+
546
+ By using `heavyball`'s optimizers and understanding the options in `heavyball.utils`, users can achieve better
547
+ performance, control over training, and easier experimentation with advanced optimization techniques.
548
+
549
+
550
+ ---
551
+
552
+ # `heavyball.chainable`: A Composable Optimizer API
553
+
554
+ The `heavyball.chainable` module provides a powerful and flexible way to build optimizers through function composition,
555
+ similar to Optax. It allows you to chain together a sequence of transformations to create custom optimization algorithms
556
+ tailored to your specific needs. This modular approach makes it easy to experiment with different optimization
557
+ strategies and build complex optimizers from simple, reusable components.
558
+
559
+ ## Core Concept
560
+
561
+ At the heart of `heavyball.chainable` lies the concept of gradient transformations. A gradient transformation is simply
562
+ a function that takes a state dictionary, a group dictionary, an update tensor, a gradient tensor, and a parameter
563
+ tensor as input, and returns a new (or modified) update tensor. These transformations can be chained together to form an
564
+ optimization algorithm.
565
+
566
+ The state dictionary stores any persistent state needed by the transformation, such as momentum buffers or
567
+ preconditioners. The group dictionary contains hyperparameters specific to a group of parameters. The update tensor is
568
+ the current update being processed, the gradient tensor is the gradient of the loss with respect to the parameter, and
569
+ the parameter tensor is the parameter itself.
570
+
571
+ ### Function Signature
572
+
573
+ A typical gradient transformation function has the following signature:
574
+
575
+ ```python
576
+
577
+ def my_transformation(state: dict, group: dict, update: List[torch.Tensor], grad: List[torch.Tensor],
578
+                       param: List[torch.Tensor]) -> List[torch.Tensor]:
579
+ # ... transformation logic ...
580
+ return update
581
+ ```
582
+
583
+ or
584
+
585
+ ```python
586
+ @C.no_state_no_foreach
587
+ def my_transformation(group: dict, update: torch.Tensor, grad: torch.Tensor, param: torch.Tensor, *args,
588
+ **kwargs) -> torch.Tensor:
589
+ # ... transformation logic ...
590
+ return update
591
+ ```
592
+
593
+ Note that the second version has no state and processes updates one by one, while the first version processes updates
594
+ in parallel.
595
+
596
+ These functions modify the `update` in place or return a new tensor.
597
+
598
+ ### Example: Scaling by the learning rate
599
+
600
+ ```python
601
+ import torch
+ from heavyball import chainable as C
602
+
603
+
604
+ @C.no_state_no_foreach
605
+ def scale_by_learning_rate(group: dict, update: torch.Tensor, grad: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
606
+ return update * group["lr"]
607
+ ```
608
+
609
+ ## `FunctionTransform` and Guards
610
+
611
+ To make it easier to create gradient transformations, `heavyball.chainable` provides the `FunctionTransform` class and a
612
+ set of "guard" decorators.
613
+
614
+ ### `FunctionTransform`
615
+
616
+ `FunctionTransform` is a base class for gradient transformations that provides a common interface and helper methods. It
617
+ takes a function `fn` as input and stores it along with its name.
618
+
619
+ ```python
620
+ class FunctionTransform:
621
+ def __init__(self, fn):
622
+ self.fn = fn
623
+ self.fn_name = self.get_fn().__name__
624
+
625
+ def __call__(self, state, group, update, grad, param, *args, **kwargs):
626
+ raise NotImplementedError
627
+
628
+ def get_fn(self):
629
+ if hasattr(self.fn, 'get_fn'):
630
+ return self.fn.get_fn()
631
+ return self.fn
632
+
633
+ def val_name(self, name):
634
+ return f"{self.fn_name}_{name}"
635
+ ```
636
+
637
+ ### Guards
638
+
639
+ Guards are decorators that help manage the state dictionary and ensure that transformations are applied correctly. They
640
+ handle common tasks like initializing state variables and preventing redundant computations.
641
+
642
+ #### `zero_guard`
643
+
644
+ The `zero_guard` decorator ensures that a specific variable in the state dictionary is initialized to zero if it doesn't
645
+ exist.
646
+
647
+ ```python
648
+ @C.zero_guard("momentum")
649
+ def my_transformation(state, group, update, grad, param, momentum):
650
+ # ... momentum will be initialized to zero if it doesn't exist in state ...
651
+ return update
652
+ ```
653
+
654
+ #### `copy_guard`
655
+
656
+ The `copy_guard` decorator creates a copy of a specified input (update, grad, or param) and stores it in the state
657
+ dictionary.
658
+
659
+ ```python
660
+ @C.copy_guard(0, "update_copy") # 0 refers to the 'update' argument
661
+ def my_transformation(state, group, update, grad, param, update_copy):
662
+ # ... update_copy will be a copy of the update tensor ...
663
+ return update
664
+ ```
665
+
666
+ #### `general_guard`
667
+
668
+ The `general_guard` decorator provides a more flexible way to manage state. It allows you to specify a custom
669
+ initialization function that is called if a specific variable is not found in the state.
670
+
671
+ ```python
672
+ def init_preconditioner(state, group, update, grad, param, **kwargs):
673
+     # ... initialize preconditioner ...
676
+
677
+ @C.general_guard("precond", init_fn=init_preconditioner)
678
+ def my_transformation(state, group, update, grad, param, precond):
679
+ # ... precond will be initialized using init_preconditioner if it doesn't exist ...
680
+ return update
681
+ ```
682
+
683
+ #### `no_state`
684
+
685
+ The `no_state` decorator indicates that a transformation does not use or modify any state.
686
+
687
+ #### `no_state_no_foreach`
688
+
689
+ The `no_state_no_foreach` decorator indicates that a transformation does not use or modify any state and also does not
690
+ support `foreach` implementations.
691
+
692
+ ## Chaining Transformations
693
+
694
+ The power of `heavyball.chainable` comes from its ability to chain transformations together. This is achieved through
695
+ the `chain` function.
696
+
697
+ ```python
698
+ def chain(state: Union[callable, dict], group, grad, param, *fns):
699
+ update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
700
+ skip_update = False
701
+ for fn in fns:
702
+ try:
703
+ update = fn(state, group, update, grad, param)
704
+ except SkipUpdate:
705
+ skip_update = True
706
+ continue
707
+ if update is None:
708
+ break
709
+ if not skip_update and update is not None:
710
+ utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
711
+ ```
712
+
713
+ The `chain` function takes a state dictionary, a group dictionary, a gradient tensor, a parameter tensor, and a sequence
714
+ of gradient transformations as input. It applies each transformation in order, passing the output of one transformation
715
+ as the input to the next.
716
+
717
+ ## Building Optimizers
718
+
719
+ The `ChainOpt` class provides a convenient way to build optimizers from chained transformations.
720
+
721
+ ```python
722
+ class ChainOpt(utils.StatefulOptimizer):
723
+ # ...
724
+ def __init__(self, params, defaults, foreach: bool, *fns):
725
+ # ...
726
+ self.fns = tuple(fns)
727
+
728
+ def _step(self, group):
729
+ # ...
730
+ if not group['foreach'] or len(p) == 1:
731
+ for param, grad in zip(p, g):
732
+ chain(self.state_, group, [grad], [param], *self.fns)
733
+ else:
734
+ chain(self.state_, group, g, p, *self.fns)
735
+ # ...
736
+ ```
737
+
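+ As a hypothetical sketch of direct use (the exact set of keys required in `defaults` depends on the chosen
+ transformations and on the final parameter update, so treat the dictionary below as illustrative), the
+ grokfast-style chain from the project README can be assembled like this:
+
+ ```python
+ import torch
+
+ from heavyball import chainable as C
+
+ model = torch.nn.Linear(16, 4)
+
+ defaults = dict(lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.0, warmup_steps=0,
+                 caution=False, mars=False, mars_gamma=0.0025, storage_dtype='float32')
+
+ # exp_avg applies an extra EMA to the gradient before the Adam update ("grokfast"-style).
+ grokfast = C.ChainOpt(model.parameters(), defaults, True, C.exp_avg, C.scale_by_adam)
+ ```
+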
738
+ ### BaseOpt
739
+
740
+ The `BaseOpt` class extends `ChainOpt` and provides additional features like gradient clipping, update clipping, and
741
+ optional PaLM beta2 schedule.
742
+
743
+ ```python
744
+ class BaseOpt(ChainOpt):
745
+ # ...
746
+ def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
747
+ palm: bool = use_default, *fns, compile_step: bool = use_default, promote: bool = use_default):
748
+ # ...
749
+ ```
750
+
751
+ ### `ScheduleFree`
752
+
753
+ The `ScheduleFree` class provides a convenient interface for using the `update_by_schedule_free` transformation.
754
+
755
+ ### Predefined Transformations
756
+
757
+ `heavyball.chainable` provides a number of predefined gradient transformations, including:
758
+
759
+ * `exp_avg`: Calculates the exponential moving average of the gradients.
760
+ * `scale_by_exp_avg_sq`: Scales the updates by the inverse square root of the exponential moving average of squared
761
+ gradients.
762
+ * `scale_by_adam`: Scales the updates using the Adam algorithm.
763
+ * `update_by_adam`: Updates the parameters using the Adam algorithm.
764
+ * `scale_by_laprop`: Scales the updates using the LaProp algorithm.
765
+ * `update_by_laprop`: Updates the parameters using the LaProp algorithm.
766
+ * `update_by_schedule_free`: Updates the parameters using the Schedule-Free algorithm.
767
+ * `update_by_adopt`: Updates the parameters using the ADOPT algorithm.
768
+ * `scale_by_adopt`: Scales the updates using the ADOPT algorithm.
769
+ * `orthogonalize_update`: Orthogonalizes the update tensor.
770
+ * `nesterov_momentum`: Applies Nesterov momentum to the updates.
771
+ * `heavyball_momentum`: Applies heavy-ball momentum to the updates.
772
+ * `scale_by_soap`: Scales the updates using the SOAP preconditioner.
773
+ * `scale_by_psgd`: Scales the updates using the PSGD preconditioner.
774
+ * `scale_by_delayed_psgd`: Scales the updates using the delayed PSGD preconditioner.
775
+ * `update_by_psgd`: Updates the parameters using the PSGD preconditioner.
776
+ * `update_by_delayed_psgd`: Updates the parameters using the delayed PSGD preconditioner.
777
+ * `palm_beta2`: Modifies the beta2 parameter for PaLM optimizers.
778
+
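+ These transformations are the pieces the shipped optimizers are assembled from. As a hedged sketch that follows the
+ `BaseOpt.__init__` signature quoted above (the defaults dictionary and argument values are illustrative), a
+ LaProp-style optimizer could be declared like this:
+
+ ```python
+ from heavyball import chainable as C
+
+
+ class MyLaProp(C.BaseOpt):
+     def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.0,
+                  warmup_steps=0, foreach=True):
+         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+                         warmup_steps=warmup_steps, caution=False, mars=False,
+                         mars_gamma=0.0025, storage_dtype='float32')
+         # gradient_clipping, update_clipping and palm are left at their defaults here.
+         super().__init__(params, defaults, foreach, C.use_default, C.use_default,
+                          C.use_default, C.update_by_laprop)
+ ```
+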
779
+ ## Creating New Transformations
780
+
781
+ You can easily create new gradient transformations by following the function signature and using the provided guards and
782
+ `FunctionTransform` class.
783
+
784
+ ### Example: Clipping gradients by norm
785
+
786
+ ```python
787
+ from typing import List
+
+ import torch
+ from heavyball import chainable as C
788
+ from heavyball import utils
789
+
790
+
791
+ @C.no_state
792
+ def clip_by_global_norm(group: dict, update: List[torch.Tensor], grad: List[torch.Tensor],
793
+                         param: List[torch.Tensor], max_norm: float) -> List[torch.Tensor]:
794
+ """Clips the gradient by its global norm."""
795
+ total_norm = torch.norm(torch.stack([torch.norm(g) for g in grad]))
796
+ clip_coef = max_norm / (total_norm + 1e-6)
797
+ if clip_coef < 1:
798
+ return [u * clip_coef for u in update]
799
+ return update
800
+ ```
801
+
802
+ ### Example: L2-Normalization of updates
803
+
804
+ ```python
805
+ import torch
+ from heavyball import chainable as C
806
+ from heavyball import utils
807
+
808
+
809
+ @C.no_state_no_foreach
810
+ def l2_normalize_updates(group: dict, update: torch.Tensor, grad: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
811
+ """L2-normalizes the updates."""
812
+ norm = update.norm()
813
+ if norm > 0:
814
+ return update / norm
815
+ return update
816
+ ```
817
+
818
+ ---
819
+
820
+ ## Optimizer Recommendations
821
+
822
+ This hierarchy ranks optimizers from most recommended (top) to least recommended (bottom) for general deep learning
823
+ tasks. However, the best choice always depends on your specific model, dataset, and computational resources.
824
+
825
+ **1. Preconditioned Optimizers (SOAP and PSGD):**
826
+
827
+ - **Recommendation:** **Start here.** These are generally the most powerful and efficient optimizers in `heavyball`.
828
+ - **`ForeachSOAP`** (and its variants: `PaLMForeachSOAP`, `PrecondScheduleForeachSOAP`,
829
+ `PrecondSchedulePaLMForeachSOAP`):
830
+ - **Strengths:**
831
+ - **Adaptive Preconditioning:** SOAP dynamically adapts to the curvature of the loss landscape using
832
+ second-order information, leading to faster convergence, especially in ill-conditioned problems.
833
+ - **Robustness:** Less sensitive to hyperparameter choices compared to Adam.
834
+ - **Strong Empirical Performance:** Often outperforms other optimizers across various tasks and architectures.
835
+ - **Weaknesses:**
836
+ - **Computational Cost:** Higher per-step cost due to preconditioner computation and updates.
837
+ - **Memory Usage:** Can use more memory than simpler optimizers, particularly for large models.
838
+ - **`precondition_frequency` or `precond_scheduler`:** Needs to be tuned, though the default schedule usually
839
+ works well.
840
+ - **When to use:**
841
+ - **Complex models and datasets:** Where optimization is challenging.
842
+ - **When training stability is crucial.**
843
+ - **When you can't retune hyperparameters.**
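+ As an illustrative sketch (placeholder values), a SOAP instance that refreshes its preconditioner every 10 steps
+ could be built like this:
+
+ ```python
+ import torch
+
+ import heavyball
+
+ model = torch.nn.Linear(16, 4)
+
+ optimizer = heavyball.ForeachSOAP(
+     model.parameters(),
+     lr=3e-3,
+     betas=(0.9, 0.95),
+     shampoo_beta=0.95,
+     weight_decay=0.01,
+     precondition_frequency=10,  # refresh the preconditioner every 10 steps
+     max_precond_dim=2048,       # maximum dimension of the preconditioner
+ )
+ ```
+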
844
+ - **Variants:**
845
+ - `PaLMForeachSOAP`: Enables PaLM's beta2 schedule by default.
846
+ - `PrecondScheduleForeachSOAP`: Uses a dynamic schedule for preconditioner updates.
847
+ - `PrecondSchedulePaLMForeachSOAP`: Combines the PaLM schedule with a dynamic preconditioner schedule.
848
+
849
+ - **`ForeachPSGDKron`** (and its variants: `ForeachPurePSGD`, `ForeachCachedDelayedPSGDKron`, `ForeachCachedPSGDKron`,
850
+ `ForeachDelayedPSGD`):
851
+ - **Strengths:**
852
+ - **Preconditioning:** Uses Kronecker-factored approximations to capture curvature information, providing many
853
+ of the benefits of second-order methods at a lower cost than full curvature methods.
854
+ - **Efficiency:** Relatively efficient in terms of computation.
855
+ - **Tunability:** Offers many options for customization.
856
+ - **Convergence:** Tends to converge faster than SOAP.
857
+ - **Weaknesses:**
858
+ - **No baseline:** SOAP can copy Adam's hyperparameters - PSGD requires more tuning.
859
+ - **Complexity:** Has many hyperparameters to tune.
860
+ - **When to use:**
861
+ - **Large models:** Where memory is a constraint.
862
+ - **When `ForeachSOAP` is too computationally expensive.**
863
+ - **When you want potentially the best performance regardless of computational cost.**
864
+ - **Variants:**
865
+ - `ForeachPurePSGD`: Disables exponential averaging of the input when calculating the preconditioner.
866
+ - `ForeachCachedDelayedPSGDKron`: Caches preconditioner-related computations and uses delayed preconditioner
867
+ updates.
868
+ - `ForeachCachedPSGDKron`: Caches preconditioner-related computations.
869
+ - `ForeachDelayedPSGD`: Uses delayed preconditioner updates.
870
+
871
+ **2. Muon:**
872
+
873
+ - **`ForeachMuon`** (and `MuonLaProp`):
874
+ - **Strengths:**
875
+ - **Momentum with Orthogonal Updates:** Combines momentum with orthogonalized updates, which can
876
+ improve stability and exploration.
877
+ - **Good Generalization:** Often leads to better generalization performance compared to Adam.
878
+ - **Weaknesses:**
879
+ - **Performance:** Generally outperformed by SOAP and PSGD.
880
+ - **Computational Cost:** Higher overheads than SOAP and PSGD.
881
+ - **When to use:**
882
+ - **When generalization is a primary concern.**
883
+ - **When you want an optimizer less prone to finding sharp minima.**
884
+
885
+ **3. Adam-Based Optimizers:**
886
+
887
+ - **`ForeachLaProp`**:
888
+ - **Strengths:**
889
+ - **Backward Compatibility:** Can use Adam's hyperparameters, but allows a larger range of betas.
890
+ - **Stability:** More stable than Adam.
891
+ - **Weaknesses:**
892
+ - **Performance:** Generally outperformed by SOAP, PSGD, and Muon.
893
+ - **When to use:**
894
+ - **When you want less risk or better losses than Adam, but can't run advanced methods.**
895
+
896
+ - **`ForeachAdamW`** (and `ForeachSFAdamW`, `PaLMForeachSFAdamW`):
897
+ - **Strengths:**
898
+ - **Widely Used:** A popular and well-established optimizer.
899
+ - **Weaknesses:**
900
+ - **Performance:** Often outperformed by preconditioned optimizers (SOAP, PSGD) and Muon.
901
+ - **Sensitivity to Hyperparameters:** Can be sensitive to the choice of learning rate and beta parameters.
902
+ - **When to use:**
903
+ - **As a strong baseline.**
904
+ - **When you are familiar with Adam and want a robust starting point.**
905
+ - **When computational cost is a major concern (compared to second-order methods).**
906
+ - **Variants:**
907
+ - `ForeachSFAdamW`: A Schedule-Free version of AdamW that dynamically adjusts the learning rate.
908
+ - `PaLMForeachSFAdamW`: A PaLM version of Schedule-Free AdamW.
909
+
910
+ ## Choosing the Right Optimizer
911
+
912
+ 1. **Start with Preconditioning:** Begin with either `ForeachSOAP` or `ForeachPSGDKron`. If computational resources are
913
+ a major constraint, lean towards `ForeachPSGDKron`. If performance is paramount, try `ForeachSOAP` first.
914
+
915
+ 2. **Consider Muon:** If preconditioned optimizers are not feasible or if you want to explore alternatives that
916
+ incorporate momentum and orthogonal updates, try `ForeachMuon`.
917
+
918
+ 3. **Use LaProp or Adam as Baselines:** `ForeachLaProp` can serve as a simple adaptive baseline. `ForeachAdamW` is a
919
+ strong and widely used baseline that you should always compare against.
920
+
921
+ 4. **Experiment and Tune:** The best optimizer ultimately depends on your specific problem. It's crucial to experiment
922
+ with different optimizers and carefully tune their hyperparameters (especially the learning rate).
923
+
924
+ ## Important Notes
925
+
926
+ * **Learning Rate:** The learning rate is the most important hyperparameter. You'll likely need to adjust it when
927
+ switching between optimizers.
928
+ * **Warmup:** Consider using a learning rate warmup, especially for more complex optimizers like SOAP and PSGD.
929
+ * **Weight Decay:** Weight decay can improve generalization for many optimizers, especially AdamW.
930
+ * **`foreach`:** Use `foreach` versions of the optimizers when possible for better performance.
931
+ * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
932
+ `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
933
+ configurations.
934
+
heavyball-1.4.3.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
1
+ heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
2
+ heavyball/chainable.py,sha256=-5ovRa7yD7V41_cgaBJtO5fBrnBemAILl4YKjQmeuns,24183
3
+ heavyball/utils.py,sha256=x0rSU8lko7ACdI9GuTLC0wP6HwIZxwB8f8tukBOR0xA,48129
4
+ heavyball-1.4.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
+ heavyball-1.4.3.dist-info/METADATA,sha256=RM_pOme3dsQL-drKcKD6FJ0qE3SSh4JdPM-kC9vpbeU,43584
6
+ heavyball-1.4.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
+ heavyball-1.4.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
+ heavyball-1.4.3.dist-info/RECORD,,
heavyball-1.4.0.dist-info/METADATA DELETED
@@ -1,136 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: heavyball
3
- Version: 1.4.0
4
- Summary: Efficient optimizers
5
- Home-page: https://github.com/clashluke/heavyball
6
- Author: Lucas Nestler
7
- Author-email: github.heavyball@nestler.sh
8
- License: BSD
9
- Classifier: Development Status :: 5 - Production/Stable
10
- Classifier: License :: OSI Approved :: BSD License
11
- Classifier: Programming Language :: Python
12
- Classifier: Programming Language :: Python :: 3.7
13
- Classifier: Programming Language :: Python :: 3.8
14
- Classifier: Programming Language :: Python :: 3.9
15
- Classifier: Topic :: Software Development :: Libraries
16
- Classifier: Topic :: Software Development :: Libraries :: Python Modules
17
- Classifier: Intended Audience :: Developers
18
- Requires-Python: >=3.7
19
- Description-Content-Type: text/markdown
20
- License-File: LICENSE
21
- Requires-Dist: opt-einsum
22
- Requires-Dist: torch
23
- Requires-Dist: numpy
24
-
25
- # HeavyBall
26
-
27
- > [!IMPORTANT]
28
- > It's recommended to use `heavyball.utils.set_torch()` for faster training and less memory usage.
29
-
30
- A simple package of efficient optimizers
31
-
32
- The goal is not to strive for completeness, full maintenance or abstraction, but instead to provide a simple
33
- largely static alternative to `torch.optim` with more and better optimizers.
34
-
35
- Currently (2024-12-07, 1.0.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
36
- recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).
37
-
38
- ## Features
39
-
40
- * **Optax-like API**: `C = heavyball.chainable; grokfast = C.ChainOpt(p, lr, C.exp_avg, C.scale_by_adam)`
41
- * **Stochastic Rounding**: [FP32 convergence with BF16 parameters](https://github.com/pytorch/pytorch/issues/120376)
42
- * **Inplace EMA**: Same math, but less memory, less compute and higher stability
43
- * **Foreach**: Fast multi-tensor application (turn it off to save memory via `foreach=False`)
44
- * **PaLM Beta2**: Fast initial
45
- convergence, [stable late convergence](https://x.com/_clashluke/status/1820810798693818761)
46
- * **ScheduleFree**: No learning rate schedule, but better convergence
47
- * [**Preconditioner Schedule**](https://github.com/lixilinx/psgd_torch/): Improved loss-per-step in early convergence,
48
- better step-per-second in late convergence (explained below)
49
- * **Memory-efficient storage** PSGD supports `store_triu_as_line` (default: `True`) and `q_dtype` to trade off memory
50
- usage for memory
51
- bandwidth; Other optimizers have `storage_dtype`, supporting lower-precision EMAs at no(?) performance drop via
52
- stochastic rounding
53
-
54
- ## Getting started
55
-
56
- ```bash
57
- pip install heavyball
58
- ```
59
-
60
- ```python
61
- import torch
62
- import heavyball
63
-
64
- # Create a model
65
- model = torch.nn.Linear(16, 1)
66
-
67
- # Create an optimizer
68
- optimizer = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)
69
-
70
- x = torch.randn(128, 16)
71
- y = torch.randn(128, 1)
72
-
73
- for _ in range(1000):
74
- optimizer.zero_grad()
75
- loss = torch.nn.functional.mse_loss(model(x), y)
76
- loss.backward()
77
- optimizer.step()
78
- ```
79
-
80
- ## Optimizers
81
-
82
- | Name | Description | Advantages / Disadvantages |
83
- |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
84
- | **AdamW** | More efficient (speed, memory) [AdamW](https://arxiv.org/abs/1711.05101) | + Faster than AdamW<br>+ Possibly more (numerically) stable
85
- | **LaProp** | More efficient (speed, memory) [LaProp](https://arxiv.org/abs/2002.04839) | + Same cost as AdamW<br>+ Marginally better convergence (better proofs)<br>+ Higher hyperparameter stability<br>- Not a guaranteed win (can be neutral)<br>- No "Slingshot" |
86
- | **ADOPT** | More efficient (speed, memory) [ADOPT](https://arxiv.org/abs/2411.02853) | + Same cost as AdamW<br>+ Rigorous mathematical convergence proofs, even for challenging models (GANs)<br>- Empirically underperforms LaProp<br>- no bf16 |
87
- | **SFAdamW** | More efficient (speed, memory) [ScheduleFree AdamW](https://arxiv.org/abs/2405.15682) | + Same cost as AdamW, but better eval perf<br>+ Full control over hyperparameters |
88
- | **PaLMSFAdamW** | ForeachSFAdamW with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Same cost as AdamW, but better eval perf<br>+ Less control, but faster early and more stable late convergence<br>+ ScheduleFree<br>- slow early convergence |
89
- | **SOAP** | More efficient (speed, memory) [SOAP](https://arxiv.org/abs/2409.11321) | + Faster convergence (loss-at-step)<br>+ Full control over hyperparameters<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be amortized; better loss-at-second) |
90
- | **PaLMSOAP** | ForeachSOAP with [PaLM's beta2 schedule](https://arxiv.org/abs/2204.02311) | + Faster convergence (loss-at-step)<br>+ Less control, but faster early and more stable late convergence<br>- more memory usage<br>- more hyperparameters<br>- higher overhead than AdamW (can be amortized; better loss-at-second) |
91
- | **SFPaLMSOAP** | ScheduleFree PaLMForeachSOAP | + Fast convergence (loss-at-step)<br>+ less memory usage than PaLMForeachSOAP (more than AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be amortized) |
92
- | **PrecondScheduleSFPaLMSOAP** | SFPaLMForeachSOAP with [preconditioner schedule](https://github.com/lixilinx/psgd_torch/), matching the error of PrecondEvery=2 with the cost of PrecondEvery=512 | + Better initial convergence than SFPaLMForeachSOAP<br>+ Significantly faster (sec/it) later<br>+ less memory usage than PaLMForeachSOAP (more than AdamW)<br>- slower initial convergence than PaLMForeachSOAP (but allows higher LRs)<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
93
- | **PrecondSchedulePaLMSOAP** | PrecondScheduleSFPaLMForeachSOAP without schedule-free | + Best initial convergence<br>+ Significantly faster (sec/it) later<br>+ high stability<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
94
- | **PrecondScheduleSOAP** | PrecondScheduleSFPaLMForeachSOAP without PaLM's beta2 schedule | + Better initial convergence<br>+ Significantly faster (sec/it) later<br>- more memory usage than PrecondScheduleSFPaLMForeachSOAP<br>- higher overhead than AdamW (can be amortized), goes to 0 with increasing number of steps |
95
-
96
- ## Precond Schedule
97
-
98
- The default preconditioner schedule (`f`) would yield the following update intervals:
99
-
100
- | Steps | Interval, `f` | Total (schedule) | Total (constant, every 2) | Total (constant, every 16) |
101
- |-----------|---------------|------------------|---------------------------|----------------------------|
102
- | 10 | 1.00005 | 10 | 5 (0.5x) | 0 (0.0x) |
103
- | 100 | 1.026 | 99 | 50 (0.5x) | 6 (0.1x) |
104
- | 1,000 | 2.0 | 738 | 500 (0.7x) | 62 (0.1x) |
105
- | 10,000 | 14.3 | 2,168 | 5,000 (2.3x) | 625 (0.3x) |
106
- | 100,000 | 100.2 | 4,049 | 50,000 (12.3x) | 6,250 (1.5x) |
107
- | 1,000,000 | 513 | 7,245 | 500,000 (69.0x) | 62,500 (8.6x) |
108
-
109
- ## Memory
110
-
111
- Second order optimizers make it difficult to estimate memory usage, as it depends on shapes and hyperparameters. To
112
- estimate your memory usage, you may use `test/test_memory.py` which attempts to ensure there are no regressions.\
113
- Furthermore, you can find real-world memory usage of a 300M parameters video diffusion model below:
114
- ![img.png](assets/memory.png)
115
-
116
- ## PSGD
117
-
118
- HeavyBall offers various configurations of PSGD:
119
-
120
- * "PSGDKron" is the baseline, equivalent to [kron_torch](https://github.com/evanatyourservice/kron_torch/), but with
121
- lower compute and memory
122
- overhead.
123
- * "PurePSGD" has no momentum, further reducing memory and compute
124
- * "DelayedPSGD" implements SOAP/ADOPT-style off-by-one momentum, which has worse initial convergence but higher
125
- stability
126
- ![img.png](assets/delayed_psgd.png)
127
-
128
- ## Utils
129
-
130
- To access `heavyball.utils`, you need to explicitly `import heavyball.utils`.\
131
- It has several handy functions:
132
-
133
- * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
134
- * `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables torch_compile
135
- * `zeroth_power_mode`, a string determining whether to use QR, newtonschulz, svd, or eigh to approximate
136
- the eigenvectors.
heavyball-1.4.0.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
1
- heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
2
- heavyball/chainable.py,sha256=OYlCVe06SjpxUM8tBBJUIOrmU3uMYwVYwPzkaQMwN98,24171
3
- heavyball/utils.py,sha256=djwaSLZOB8B-xD2jJxZfXTJpJrcWp-mWTmKxC2F5Sh0,48330
4
- heavyball-1.4.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
- heavyball-1.4.0.dist-info/METADATA,sha256=Pbb_2JAZevdXVpjkJajuBNrTvrLnpDYBL5NLx1SZxHg,12022
6
- heavyball-1.4.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
- heavyball-1.4.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
- heavyball-1.4.0.dist-info/RECORD,,