heavyball 1.5.2__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -104,23 +104,17 @@ class ForeachSOAP(C.BaseOpt):
         Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
         https://arxiv.org/abs/2409.11321
         https://github.com/nikhilvyas/SOAP
-
-    ScheduleFree:
-        The Road Less Scheduled
-        Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
-        https://arxiv.org/abs/2405.15682
-        https://github.com/facebookresearch/schedule_free
     """
     use_precond_schedule: bool = False
 
     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 0,
-                 split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
-                 mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
-                 beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
+                 correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+                 precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 storage_dtype: str = 'float32', stochastic_schedule: bool = False):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
 
         defaults = locals()
@@ -137,6 +131,54 @@ class ForeachSOAP(C.BaseOpt):
                          C.scale_by_soap)
 
 
+class ForeachSignLaProp(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
+
+
+class ForeachSOLP(C.BaseOpt):
+    """
+    ForeachSOLP
+
+    Sources:
+        Baseline SOAP:
+            SOAP: Improving and Stabilizing Shampoo using Adam
+            Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+            https://arxiv.org/abs/2409.11321
+            https://github.com/nikhilvyas/SOAP
+    """
+    use_precond_schedule: bool = False
+
+    def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+                 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
+                 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+                 correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+                 mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+                 precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+                 storage_dtype: str = 'float32', stochastic_schedule: bool = False):
+        use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+
+        if use_precond_schedule:
+            del defaults['precondition_frequency']
+            self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
+        else:
+            del defaults['precond_scheduler']
+            self.precond_schedule = 1 / defaults.pop("precondition_frequency")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,  #
+                         functools.partial(C.scale_by_soap, inner='laprop'))
+
+
 class PaLMForeachSOAP(ForeachSOAP):
     use_precond_schedule: bool = False
     palm: bool = True
@@ -163,6 +205,18 @@ class OrthoLaProp(C.BaseOpt):
                          C.orthogonalize_grad_to_param, C.scale_by_laprop)
 
 
+class LaPropOrtho(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+                         C.orthogonalize_grad_to_param)
+
+
 class ForeachPSGDKron(C.BaseOpt):
     """
     Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -178,10 +232,10 @@ class ForeachPSGDKron(C.BaseOpt):
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
                  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-                 stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
+                 stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
                  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
-                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
+                 gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
         defaults = locals()
@@ -238,10 +292,12 @@ DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 Muon = ForeachMuon
+SignLaProp = ForeachSignLaProp
 
 __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
            "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
-           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp',  #
-           "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
-           "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-           "ForeachRMSprop", "ForeachMuon", 'ForeachCachedNewtonPSGD']
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp'  #
+           "ForeachAdamW", "ForeachSFAdamW",
+           "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
+           "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
+           'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
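
The 1.6.0 wheel exports the new ForeachSignLaProp / SignLaProp optimizer (sign of the LaProp update, grafted back to the update's norm) alongside OrthoLaProp and LaPropOrtho. A minimal usage sketch, assuming the 1.6.0 wheel is installed; the model and hyperparameters below are illustrative, not taken from the package:

import torch
import torch.nn as nn
import heavyball

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
# New in 1.6.0: sign-based LaProp; SignLaProp is an alias for ForeachSignLaProp.
opt = heavyball.ForeachSignLaProp(model.parameters(), lr=2.5e-3, betas=(0.9, 0.99))

x, y = torch.randn(64, 16), torch.randn(64, 1)
for _ in range(10):
    opt.zero_grad()
    loss = (model(x) - y).square().mean()
    loss.backward()
    opt.step()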
heavyball/chainable.py CHANGED
@@ -1,6 +1,6 @@
 import functools
 import random
-from typing import Optional, Union, Literal
+from typing import Optional, Union, Literal, List
 
 import torch
 
@@ -127,7 +127,7 @@ def zero_guard(*names):
 
 
 def copy_guard(index, *names):
-    return functools.partial(CopyGuard, index=index, names=names, )
+    return functools.partial(CopyGuard, index=index, names=names)
 
 
 def general_guard(*names, init_fn, skip_first: bool = True):
@@ -152,6 +152,22 @@ def exp_avg(group, update, grad, param, exp_avg):
     return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
+@zero_guard('exp_avg')
+@no_state
+def weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                               group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
+@zero_guard('exp_avg')
+@no_state
+def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
+    utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+                                  group['weight_decay_to_ema'] * group['lr'])
+    return update
+
+
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
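
The two transforms added above implement weight decay toward an exponential moving average instead of toward zero: an EMA is tracked and the decayed tensor is then nudged using either a linear interpolation toward the EMA or a fixed-size signed step (the L1 variant). A rough standalone illustration of the underlying idea; the function and variable names here are ours, and the shipped helpers (utils.weight_decay_to_ema_, shown later in this diff) operate in-place on tensor lists with stochastic rounding:

import torch

def decay_to_ema(p: torch.Tensor, ema: torch.Tensor, ema_beta: float, strength: float) -> None:
    # Track an EMA, then pull `p` toward that EMA instead of toward zero.
    ema.lerp_(p, 1 - ema_beta)   # ema <- ema_beta * ema + (1 - ema_beta) * p
    p.lerp_(ema, strength)       # p   <- (1 - strength) * p + strength * ema

p, ema = torch.randn(4, 4), torch.zeros(4, 4)
decay_to_ema(p, ema, ema_beta=0.99, strength=0.01 * 1e-3)  # strength ~ weight_decay_to_ema * lr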
@@ -206,14 +222,15 @@ def update_by_schedule_free(group, update, grad, param, z):
 @no_state
 def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group['step'] == 1:
-        utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
         raise SkipUpdate
 
     if group['step'] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
-        utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+                                   group['eps'])
         raise SkipUpdate
 
     utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
@@ -225,21 +242,22 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
 @no_state
 def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
     if group['step'] == 1:
-        utils.exp_avg_sq_(exp_avg_sq, update, 0, 1)
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
         raise SkipUpdate
 
     if group['step'] == 2:
         update = utils.promote(update)
         easq = utils.promote(exp_avg_sq)
         [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
-        utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
+        utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+                                   group['eps'])
         raise SkipUpdate
 
     return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)
 
 
-def _init_soap(state, group, update, grad, param):
-    utils.init_preconditioner(grad, state, utils.get_beta2(group), group['max_precond_dim'], group['precondition_1d'])
+def _init_soap(state, group, update, grad, param, inner: str = ''):
+    utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])
 
 
 def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
@@ -295,6 +313,25 @@ def nesterov_momentum(group, updates, grads, params, momentum):
     return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
 
 
+@zero_guard('momentum')
+@no_state
+def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grokfast
+    return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+
+
+def _store_std(state, group, update, grad, param):
+    state['init_std'] = torch.std(grad, dim=0)
+
+
+@general_guard("init_std", init_fn=_store_std)
+@no_state
+def mup_approx(group, updates, grads, params, init_std):
+    _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
+    _updates, _init_std = zip(*_updates)
+    utils.stochastic_multiply_(_updates, _init_std)
+    return updates
+
+
 @zero_guard("momentum")
 @no_state
 def heavyball_momentum(group, updates, grads, params, momentum):
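
The new nesterov_ema transform amplifies the slow component of the gradient by adding an EMA of past gradients back onto the current one, which is the Grokfast idea referenced in the comment; mup_approx separately rescales multi-dimensional updates by a per-column standard deviation captured at initialization. A standalone sketch of the EMA amplification only — the function name is ours, while the shipped code routes through utils.nesterov_ema with stochastic-rounding helpers:

import torch

def grokfast_style_ema(grad: torch.Tensor, ema: torch.Tensor, beta: float) -> torch.Tensor:
    # ema <- beta * ema + (1 - beta) * grad, then return grad + ema (amplified update)
    ema.lerp_(grad, 1 - beta)
    return grad + ema

g = torch.randn(8)
state = torch.zeros_like(g)
update = grokfast_style_ema(g, state, beta=0.9)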
@@ -308,15 +345,16 @@ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
 def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
-    update = utils.promote(update)
+    update = utils.promote(update)  # Promote to highest precision if needed
 
     grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
     fn = _optim_fns[inner]
-    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'])
+    precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
+                 group['eps'])
     precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]
 
-    for u, q, gg, eas in zip(update, Q, GG, exp_avg_sq):
-        utils.update_preconditioner(u, q, gg, eas, group['max_precond_dim'], group['precondition_1d'],
+    for u, q, gg, ea in zip(update, Q, GG, exp_avg):
+        utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
                                     utils.beta_debias(group['shampoo_beta'], group['step']),
                                     group['is_preconditioning'])
     return precond
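
For orientation, SOAP's project(..., back=False) rotates the update into the eigenbasis of the accumulated outer products, the inner optimizer (Adam or LaProp) runs in that basis, and project(..., back=True) rotates the result back. A 2D illustration of the round trip — the transpose convention here follows the SOAP paper and is not copied from heavyball's einsum implementation:

import torch

g = torch.randn(8, 4)
GG_left, GG_right = g @ g.T, g.T @ g           # accumulated L and R statistics (one step here)
Q_l = torch.linalg.eigh(GG_left).eigenvectors  # per-dimension eigenbases
Q_r = torch.linalg.eigh(GG_right).eigenvectors

projected = Q_l.T @ g @ Q_r                    # into the eigenbasis (back=False)
restored = Q_l @ projected @ Q_r.T             # back to the original space (back=True)
assert torch.allclose(restored, g, atol=1e-5)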
@@ -414,6 +452,11 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
     raise SkipUpdate
 
 
+@no_state
+def sign(group, update, grad, param, graft: bool = True):
+    return utils.sign_(update, graft)
+
+
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
 @no_state_no_foreach
 def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
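
The new sign transform (used by ForeachSignLaProp) replaces each update with its elementwise sign and, by default, grafts the original norm back on so the overall step size is preserved. A standalone sketch of the same arithmetic; sign_with_graft is an illustrative name, while the shipped version (utils.sign_, shown later in this diff) works in-place on tensor lists:

import torch

def sign_with_graft(update: torch.Tensor, graft: bool = True) -> torch.Tensor:
    signed = update.sign()
    if graft:
        # rescale so the signed update keeps the original update's norm
        signed = signed * (update.norm() / signed.norm().clamp(min=1e-6))
    return signed

u = torch.randn(3, 3)
print(sign_with_graft(u).norm(), u.norm())  # the two norms match when grafting is on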
@@ -439,8 +482,7 @@ def apply_to_idx(fn, idx):
     return _fn
 
 
-def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+def _inner_chain(state, group, update, grad, param, *fns):
     skip_update = False
     for fn in fns:
         try:
@@ -450,10 +492,30 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
             continue
         if update is None:
             break
+    return update, skip_update
+
+
+def chain(state: Union[callable, dict], group, grad, param, *fns):
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+    update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
     if not skip_update and update is not None:
         utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)
 
 
+def create_branch(branches: List[List[callable]], merge_fn: callable):
+    def _branch(state, group, update, grad, param):
+        outputs = []
+        for branch in branches:
+            branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+            branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+            if skip_update:
+                raise ValueError("Branches should not skip updates")
+            outputs.append(branch_update)
+        return merge_fn(outputs)
+
+    return _branch
+
+
 class ChainOpt(utils.StatefulOptimizer):
     promote: bool = False
 
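create_branch turns several sub-chains into one transform: the incoming update is cloned per branch, each branch runs through _inner_chain, and the per-branch results are combined by merge_fn. A hypothetical composition — the transforms named below exist in this module, but this particular pairing and the averaging merge are only an illustration, not shipped behaviour:

import heavyball.chainable as C

def average_merge(outputs):
    # `outputs` holds one list of update tensors per branch; average them elementwise.
    return [sum(tensors) / len(tensors) for tensors in zip(*outputs)]

# One branch applies LaProp scaling, the other a sign step; results are averaged.
branch = C.create_branch([[C.scale_by_laprop], [C.sign]], average_merge)
# `branch(state, group, update, grad, param)` can then be used like any other chainable function.
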
heavyball/utils.py CHANGED
@@ -1,24 +1,40 @@
1
+ import copy
1
2
  import functools
2
3
  import gc
4
+ import inspect
3
5
  import math
4
6
  import random
5
7
  import string
8
+ import sys
9
+ import time
6
10
  import warnings
11
+ from datetime import datetime
7
12
  from typing import List, Optional, Tuple, Callable, Union
13
+ from unittest.mock import patch
8
14
 
15
+ import hyperopt
9
16
  import numpy as np
10
17
  import torch
11
18
  from torch import Tensor
19
+ from torch._dynamo import config
12
20
  from torch._dynamo.exc import TorchDynamoException
13
21
  from torch.backends import cudnn, opt_einsum
14
22
  from torch.utils._pytree import tree_map
15
23
 
24
+ config.cache_size_limit = 2 ** 16
25
+
26
+ np.warnings = warnings
27
+
16
28
  compile_mode = "max-autotune-no-cudagraphs"
17
29
  dynamic = False
18
30
  compile_mode_recommended_to_none = None
19
31
  zeroth_power_mode = 'qr' # 'qr' is baseline, 'newtonschulz' converges better and faster
20
32
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny
21
33
 
34
+ base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
35
+ 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
36
+ 'weight_decay': 1e-4}
37
+
22
38
 
23
39
  def decorator(func):
24
40
  compiled = None
@@ -51,7 +67,7 @@ def decorator_knowngood(func: Callable):
51
67
  return _fn
52
68
 
53
69
 
54
- einsum_base = string.ascii_lowercase + string.ascii_uppercase
70
+ einsum_base = string.ascii_lowercase
55
71
 
56
72
 
57
73
  @decorator_knowngood
@@ -103,29 +119,29 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
103
119
  but we want to merge conv kernels into fan-in or at least merge the kernel
104
120
  so, [128, 64, 3, 3] should result in [128, 576] or [128, 64, 9] instead of [73728] or [8192, 3, 3] the baseline
105
121
  would've done
122
+
123
+ By @francois-rozet (commit: 68cde41eaf7e73b4c46eacb6a944865dcc081f1d), re-commited due to faulty merge
106
124
  """
107
- shape = grad.shape
108
125
  new_shape = []
109
-
110
- curr_shape = 1
111
-
112
- for sh in shape[1:][::-1]:
113
- temp_shape = curr_shape * sh
114
- if temp_shape > max_precond_dim:
115
- if curr_shape > 1:
116
- new_shape.append(curr_shape)
117
- curr_shape = sh
126
+ cum_size = 1
127
+
128
+ for s in grad.shape[1:][::-1]:
129
+ temp_size = cum_size * s
130
+ if temp_size > max_precond_dim:
131
+ if cum_size > 1:
132
+ new_shape.append(cum_size)
133
+ cum_size = s
118
134
  else:
119
- new_shape.append(sh)
120
- curr_shape = 1
135
+ new_shape.append(s)
136
+ cum_size = 1
121
137
  else:
122
- curr_shape = temp_shape
123
- new_shape = [*shape[:1], *new_shape[::-1]]
138
+ cum_size = temp_size
124
139
 
125
- if curr_shape > 1 or len(new_shape) == 0:
126
- new_shape.append(curr_shape)
140
+ if cum_size > 1:
141
+ new_shape.append(cum_size)
127
142
 
128
- new_grad = grad.reshape(new_shape) # needs to be .reshape() due to channels_last
143
+ new_shape = [grad.shape[0], *new_shape[::-1]]
144
+ new_grad = grad.reshape(new_shape)
129
145
  if not split:
130
146
  return new_grad
131
147
 
@@ -153,12 +169,11 @@ def beta_debias(beta, step):
153
169
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
154
170
  out: List[Optional[Tensor]]):
155
171
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
156
- s32 = torch._foreach_mul(s32, beta2)
157
- s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
158
- denom = torch._foreach_sqrt(s32)
159
- denom = [d.clamp(min=eps) for d in denom]
172
+ s32 = [s * beta2 + g * g * (1 - beta2) for s, g in zip(s32, g32)]
160
173
  copy_stochastic_list_(state, s32)
161
174
 
175
+ denom = [d.sqrt().clamp(min=eps) for d in s32]
176
+
162
177
  if out[0] is None:
163
178
  return denom
164
179
 
@@ -174,7 +189,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
174
189
 
175
190
  @decorator_knowngood
176
191
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
177
- g32 = promote(grad)
192
+ g32 = list(map(promote, grad))
178
193
  denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
179
194
  out = torch._foreach_div(g32, denom)
180
195
  copy_stochastic_list_(grad, out)
@@ -189,7 +204,7 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
189
204
 
190
205
  @decorator_knowngood
191
206
  def _compilable_exp_avg_(state, grad, beta):
192
- lerped = _lerp32(state, grad, beta)
207
+ lerped = _lerp(state, grad, beta)
193
208
  copy_stochastic_list_(grad, lerped)
194
209
 
195
210
 
@@ -240,29 +255,31 @@ def clean():
240
255
  gc.collect()
241
256
 
242
257
 
243
- def set_torch():
258
+ def _ignore_warning(msg):
259
+ warnings.filterwarnings('ignore', f'.*{msg}.*')
260
+
261
+
262
+ def set_torch(benchmark_limit: int = 32):
244
263
  cudnn.benchmark = True
245
264
  cudnn.deterministic = False
265
+ cudnn.benchmark_limit = benchmark_limit
246
266
  torch.use_deterministic_algorithms(False)
247
267
  torch.set_float32_matmul_precision("high") # highest: FP32, high: TF32, medium: bf16
248
268
  opt_einsum.enabled = True
249
- opt_einsum.strategy = "auto-hq"
269
+ opt_einsum.strategy = "dp"
270
+
271
+ # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
272
+ _ignore_warning(
273
+ 'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak')
274
+ _ignore_warning(
275
+ 'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak')
250
276
 
251
277
 
252
278
  @decorator
253
279
  def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
254
- """
255
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
256
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
257
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
258
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
259
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
260
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
261
- performance at all relative to UV^T, where USV^T = G is the SVD.
262
- """
263
280
  assert len(G.shape) == 2
264
281
  a, b, c = (3.4445, -4.7750, 2.0315)
265
- X = G.bfloat16()
282
+ X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype) # Preserve float64 if present
266
283
  X /= (X.norm() + eps) # ensure top singular value <= 1
267
284
  if G.size(0) > G.size(1):
268
285
  X = X.T
@@ -317,12 +334,24 @@ def nesterov_momentum(state, grad, beta):
317
334
  return grad
318
335
 
319
336
 
337
+ @decorator_knowngood
338
+ def _compilable_nesterov_ema_(state, grad, beta):
339
+ ema32 = _lerp(state, grad, beta)
340
+ stochastic_add_(grad, ema32, 1)
341
+
342
+
343
+ def nesterov_ema(state, grad, beta):
344
+ state, grad = list_guard(state, grad)
345
+ beta = scalar_guard(beta, state[0])
346
+ _compilable_nesterov_ema_(state, grad, beta)
347
+ return grad
348
+
349
+
350
+ @decorator_knowngood
320
351
  def _compilable_grafting(magnitude, direction):
321
352
  return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))
322
353
 
323
354
 
324
- # mode in ("newtonschulz", "qr", "svd")
325
- # scale_mode in ("none", "scale", "graft")
326
355
  @decorator_knowngood
327
356
  def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
328
357
  if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
@@ -350,74 +379,82 @@ def _compilable_scatter_set(target, source, index):
350
379
  target[:] = source.contiguous()[index].reshape_as(target)
351
380
 
352
381
 
353
- def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
382
+ @decorator_knowngood
383
+ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
354
384
  """
355
385
  Computes the eigenbases of the preconditioner using one round of power iteration
356
- followed by torch.linalg.qr decomposition.
386
+ followed by torch.linalg.qr decomposition, and updates exp_avg in-place from old to new eigenspace.
387
+
388
+ :param GG: List of accumulated gradient outer products.
389
+ :param Q: List of current eigenbases (updated in-place to Q_new).
390
+ :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
357
391
  """
358
- matrix = []
359
- orth_matrix = []
360
- for m, o in zip(GG, Q):
361
- if len(m) == 0:
362
- matrix.append([])
363
- orth_matrix.append([])
364
- continue
365
- if m.data.dtype != torch.float:
366
- matrix.append(promote(m.data))
367
- orth_matrix.append(promote(o.data))
368
- else:
369
- matrix.append(promote(m.data))
370
- orth_matrix.append(promote(o.data))
392
+ if isinstance(Q, list) and not Q:
393
+ return
394
+
395
+ if exp_avg is not None and exp_avg.dim() != len(Q):
396
+ raise ValueError(f"exp_avg dim {exp_avg.dim()} does not match Q length {len(Q)}")
371
397
 
372
- indices = []
398
+ new_qs = []
373
399
 
374
- for ind, (m, o, q) in enumerate(zip(matrix, orth_matrix, Q)):
400
+ for m, q in zip(GG, Q):
375
401
  if len(m) == 0:
376
- indices.append(None)
377
402
  continue
378
403
 
379
- tmp = m @ o
380
- est_eig = torch.einsum('ij,ij->j', o, tmp)
404
+ m = promote(m.data)
405
+ q_old = promote(q.data)
406
+
407
+ tmp = m @ q_old
408
+ est_eig = torch.einsum('ij,ij->j', q_old, tmp)
381
409
  sort_idx = torch.argsort(est_eig, descending=True)
382
- indices.append(sort_idx)
383
- inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")
384
410
 
385
- indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
386
- for i, ind in enumerate(indices))
387
- _compilable_scatter_set(exp_avg_sq, exp_avg_sq, indices)
411
+ tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
412
+ new_qs.append(tmp)
413
+
414
+ if exp_avg is None:
415
+ for q, q_new in zip(Q, new_qs):
416
+ copy_stochastic_(q, q_new)
417
+ return
418
+
419
+ assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
420
+ in_str = einsum_base[:exp_avg.dim()]
421
+ out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
422
+
423
+ from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if len(m) > 0])
424
+ if not from_shampoo:
425
+ return
426
+
427
+ to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if len(m) > 0])
428
+ out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
429
+
430
+ subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
431
+ exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q], *[q for q in new_qs])
432
+ copy_stochastic_(exp_avg, exp_avg_new)
433
+
434
+ for q, q_new in zip(Q, new_qs):
435
+ copy_stochastic_(q, q_new)
388
436
 
389
437
 
390
438
  def get_orthogonal_matrix(mat):
391
439
  """
392
440
  Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
393
441
  """
394
- matrix = []
395
- for m in mat:
396
- if len(m) == 0:
397
- matrix.append([])
398
- continue
399
- if m.data.dtype != torch.float:
400
- float_data = False
401
- original_type = m.data.dtype
402
- original_device = m.data.device
403
- matrix.append(promote(m.data))
404
- else:
405
- float_data = True
406
- matrix.append(m.data)
407
442
 
408
443
  final = []
409
- for m in matrix:
444
+ for m in mat:
410
445
  if len(m) == 0:
411
446
  final.append([])
412
447
  continue
413
448
 
449
+ m = promote(m.data)
450
+
414
451
  device, dtype = m.device, m.dtype
415
452
  for modifier in (None, torch.double, 'cpu'):
416
453
  if modifier is not None:
417
454
  m = m.to(modifier)
418
455
  try:
419
- Q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device))[1].to(device=device,
420
- dtype=dtype)
456
+ eigval, eigvec = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
457
+ eigvec = eigvec.to(device=device, dtype=dtype)
421
458
  break
422
459
  except torch.OutOfMemoryError:
423
460
  pass
@@ -427,9 +464,9 @@ def get_orthogonal_matrix(mat):
427
464
  else:
428
465
  raise RuntimeError("Failed to compute eigenvalues.")
429
466
 
430
- Q = torch.flip(Q, [1])
467
+ eigvec = torch.flip(eigvec, [1])
431
468
 
432
- final.append(Q)
469
+ final.append(eigvec)
433
470
 
434
471
  return final
435
472
 
@@ -454,7 +491,7 @@ def get_beta1(group):
 
 
 def get_beta2(group):
-    if 'beta2_scale' in group:
+    if 'palm' in group and group['palm'] is True and 'beta2_scale' in group:
         step = max(group.get("step", 1), 1)
         return 1 - step ** -group['beta2_scale']
     if 'betas' in group:
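
With this change the PaLM-style schedule beta2(t) = 1 - t ** (-beta2_scale) is only used when palm is explicitly enabled; otherwise the fixed betas[1] applies. For reference, the schedule with the default beta2_scale=0.8 looks like this (values rounded):

for step in (1, 10, 100, 1000, 10000):
    print(step, 1 - step ** -0.8)
# 1 -> 0.0, 10 -> 0.842, 100 -> 0.975, 1000 -> 0.996, 10000 -> 0.9994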
@@ -509,21 +546,46 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
509
546
  _compilable_stochastic_add_(x, y, alpha)
510
547
 
511
548
 
549
+ @decorator_knowngood
550
+ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
551
+ for x_, y_ in zip(x, y):
552
+ x32 = promote(x_)
553
+ y32 = promote(y_)
554
+ copy_stochastic_(x_, x32 * y32)
555
+
556
+
557
+ def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
558
+ x, y = list_guard(x, y)
559
+ _compilable_stochastic_multiply_(x, y)
560
+
561
+
512
562
  @decorator
513
- def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
563
+ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
564
+ """
565
+ Simplified by @francois-rozet in commit 704ccc4bab52429f945df421647ec82c54cdd65f
566
+ Re-commited due to faulty merge
567
+ """
514
568
  if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
515
569
  return
516
570
 
517
- for idx, sh in enumerate(grad.shape):
518
- if sh > max_precond_dim:
571
+ for idx, m in enumerate(GG):
572
+ if not isinstance(m, Tensor):
519
573
  continue
520
574
  b = einsum_base[idx]
521
575
  g0 = einsum_base[:grad.dim()]
522
576
  g1 = g0.replace(b, b.upper())
523
577
  outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
524
- GG[idx].lerp_(outer_product, 1 - beta)
578
+ m.lerp_(outer_product, 1 - beta)
579
+
580
+
581
+ def tree_apply(fn):
582
+ def _fn(*args):
583
+ return tree_map(fn, *args)
584
+
585
+ return _fn
525
586
 
526
587
 
588
+ @tree_apply
527
589
  def promote(x):
528
590
  if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
529
591
  return torch.float32
@@ -540,45 +602,38 @@ def min_dtype(xs: List[Tensor]):
540
602
  return torch.float32
541
603
 
542
604
 
543
- def update_preconditioner(grad, Q, GG, exp_avg_sq, max_precond_dim, precondition_1d, beta, update_precond):
605
+ def update_preconditioner(grad, Q, GG, exp_avg, max_precond_dim, precondition_1d, beta, update_precond):
544
606
  """
545
607
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
546
608
  """
547
- compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
609
+ update_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
548
610
  if update_precond:
549
- get_orthogonal_matrix_QR(GG, Q, exp_avg_sq)
611
+ get_orthogonal_matrix_QR(GG, Q, exp_avg)
550
612
 
551
613
 
552
- def init_preconditioner(grad, state, beta, max_precond_dim=10000, precondition_1d=False):
614
+ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
553
615
  """
554
616
  Initializes the preconditioner matrices (L and R in the paper).
555
617
  """
556
618
  state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
557
- if grad.dim() == 1:
558
- if precondition_1d or grad.shape[0] > max_precond_dim:
559
- state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
560
- else:
561
- state['GG'].append([])
562
-
563
- else:
619
+ if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
564
620
  for sh in grad.shape:
565
621
  if sh > max_precond_dim:
566
- state['GG'].append([])
622
+ state['GG'].append(None)
567
623
  else:
568
624
  state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
625
+ else:
626
+ state['GG'].append(None)
569
627
 
570
- compute_ggt(grad, state['GG'], max_precond_dim, precondition_1d, beta)
628
+ update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0)
571
629
  state['Q'] = get_orthogonal_matrix(state['GG'])
572
630
 
573
631
 
574
632
  @decorator
575
633
  def project(grad, Q, back: bool):
576
634
  """
577
-
578
635
  :param grad:
579
636
  :param Q:
580
- :param merge_dims:
581
- :param max_precond_dim:
582
637
  :param back: whether to project to Shampoo eigenbases or back to original space
583
638
  :return:
584
639
  """
@@ -591,12 +646,40 @@ def project(grad, Q, back: bool):
591
646
  return grad
592
647
 
593
648
 
649
+ def modify_closure(closure):
650
+ """
651
+ Modifies the closure function to use create_graph=True in backward().
652
+
653
+ Args:
654
+ closure: The closure function passed to the optimizer.
655
+
656
+ Returns:
657
+ The return value of the modified closure.
658
+ """
659
+
660
+ def patched_backward(self, *args, **kwargs):
661
+ kwargs['create_graph'] = True
662
+ return original_backward(self, *args, **kwargs)
663
+
664
+ original_backward = torch.Tensor.backward
665
+
666
+ with patch.object(torch.Tensor, 'backward', patched_backward):
667
+ return closure()
668
+
669
+
594
670
  class StatefulOptimizer(torch.optim.Optimizer):
671
+ """
672
+ finite_differences saves memory, but needs more compute. (Alternative is true HVP)
673
+ Both `True` and `False` have some edge cases they don't support, so experiment with it.
674
+ The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
675
+ Further notice that both methods have different numerics outputs
676
+ """
595
677
  ema_decay: float = 0.001
596
678
  compile_step: bool = False
597
679
  hessian_approx: bool = False
598
680
  precond_schedule: Union[Callable, float, None] = None
599
681
  stochastic_schedule: bool = False
682
+ finite_differences: bool = False
600
683
 
601
684
  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
602
685
  super().__init__(params, {**defaults, 'foreach': foreach})
@@ -709,38 +792,68 @@ class StatefulOptimizer(torch.optim.Optimizer):
709
792
  set_(self.state_(p)['param_ema'], p.data)
710
793
  set_(p.data, ema_clone)
711
794
 
712
- def step(self, closure: Optional[Callable] = None):
713
- if self.precond_schedule is None:
714
- self._is_preconditioning = False
715
- else:
716
- self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
795
+ def _handle_closure(self, closure):
717
796
  hessian_approx = self.hessian_approx and self._is_preconditioning
797
+
718
798
  if closure is None:
719
799
  if hessian_approx:
720
800
  raise ValueError("Hessian approximation requires a closure.")
721
- loss = None
722
- else:
801
+ return None
802
+
803
+ if not hessian_approx:
723
804
  with torch.enable_grad():
724
805
  loss = closure()
725
- if hessian_approx:
726
- grads = []
727
- for group in self.param_groups:
728
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
729
- grads.append(g)
730
- p.vector = torch.randn_like(p)
731
- p.orig = p.data.clone()
732
- stochastic_add_(p.data, p.vector, tiny_bf16)
733
-
734
- with torch.enable_grad():
735
- closure()
736
-
737
- for group in self.param_groups:
738
- for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
739
- p.grad = grads.pop(0)
740
- stochastic_add_(g, p.grad, -1)
741
- p.hessian_vector = g
742
- p.data.copy_(p.orig)
743
- del p.orig
806
+ return loss
807
+
808
+ if self.finite_differences:
809
+ with torch.enable_grad():
810
+ loss = closure() # closure without retain_graph=True
811
+
812
+ grads = []
813
+ for group in self.param_groups:
814
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
815
+ grads.append(g)
816
+ p.vector = torch.randn_like(p)
817
+ p.orig = p.data.clone()
818
+ stochastic_add_(p.data, p.vector, tiny_bf16)
819
+ else:
820
+ with torch.enable_grad():
821
+ loss = modify_closure(closure)
822
+
823
+ if self.finite_differences:
824
+ with torch.enable_grad():
825
+ closure()
826
+
827
+ for group in self.param_groups:
828
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
829
+ p.grad = grads.pop(0)
830
+ stochastic_add_(g, p.grad, -1)
831
+ p.hessian_vector = g
832
+ p.data.copy_(p.orig)
833
+ del p.orig
834
+ else:
835
+ for group in self.param_groups:
836
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
837
+ p.grad = g
838
+ params, grads = zip(*[x for group in self.param_groups for x in
839
+ self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
840
+ vs = [torch.randn_like(p) for p in params]
841
+ with torch.enable_grad():
842
+ hvs = torch.autograd.grad(grads, params, vs)
843
+
844
+ for p, g, v, hv in zip(params, grads, vs, hvs):
845
+ p.hessian_vector = hv
846
+ p.grad = g
847
+ p.vector = v
848
+
849
+ return loss
850
+
851
+ def step(self, closure: Optional[Callable] = None):
852
+ if self.precond_schedule is None:
853
+ self._is_preconditioning = False
854
+ else:
855
+ self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
856
+ loss = self._handle_closure(closure)
744
857
 
745
858
  # we assume that parameters are constant and that there are no excessive recompiles
746
859
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
@@ -748,7 +861,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
748
861
  group['is_preconditioning'] = self._is_preconditioning
749
862
  self._step(group)
750
863
  if self.use_ema:
751
- self.ema_update(group)
864
+ self.ema_update()
752
865
 
753
866
  return loss
754
867
 
@@ -758,12 +871,12 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
758
871
  copy_stochastic_(t, s)
759
872
 
760
873
 
761
- def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
874
+ @decorator_knowngood
875
+ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
762
876
  ea32 = list(map(promote, state))
763
877
  grad = list(map(promote, grad))
764
878
  beta = promote(beta)
765
-
766
- ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
879
+ ea32 = [e * beta + g * (1 - beta) for e, g in zip(ea32, grad)]
767
880
  copy_stochastic_list_(state, ea32)
768
881
  return ea32
769
882
 
@@ -775,15 +888,14 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
775
888
  beta2 = beta_debias(beta2, step)
776
889
 
777
890
  g32 = list(map(promote, grad))
778
-
779
- exp_avg32 = _lerp32(exp_avg, g32, beta1)
780
- denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
781
- u32 = torch._foreach_div(exp_avg32, denom)
891
+ exp_avg32 = _lerp(exp_avg, g32, beta1)
892
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
893
+ u32 = [ea / d for ea, d in zip(exp_avg32, denom)]
782
894
  copy_stochastic_list_(grad, u32)
783
895
 
784
896
 
785
897
  def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
786
- eps: float):
898
+ eps: float = 1e-8):
787
899
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
788
900
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
789
901
  _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -798,9 +910,8 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
798
910
  beta2 = beta_debias(beta2, step)
799
911
 
800
912
  u32, g32 = [list(map(promote, x)) for x in [update, grad]]
801
-
802
- exp_avg32 = _lerp32(exp_avg, u32, beta1)
803
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
913
+ exp_avg32 = _lerp(exp_avg, u32, beta1)
914
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
804
915
  u32 = torch._foreach_div(exp_avg32, denom)
805
916
  _compilable_update_(y, u32, decay, lr, caution, g32)
806
917
 
@@ -810,51 +921,50 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
810
921
  caution: bool):
811
922
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
812
923
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
813
- return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
924
+ _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
814
925
 
815
926
 
816
927
  @decorator_knowngood
817
928
  def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
818
- beta2: Tensor, step: Tensor):
929
+ beta2: Tensor, step: Tensor, eps: Tensor):
819
930
  beta1 = beta_debias(beta1, step)
820
931
  beta2 = beta_debias(beta2, step)
821
932
 
822
933
  gp32 = list(map(promote, grad))
823
-
824
- denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, 1e-8)
934
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, gp32, beta2, eps, [None])
825
935
  gp32 = torch._foreach_div(gp32, denom)
826
- gp32 = _lerp32(exp_avg, gp32, beta1)
827
-
936
+ gp32 = _lerp(exp_avg, gp32, beta1)
828
937
  copy_stochastic_list_(grad, gp32)
829
938
 
830
939
 
831
- def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
940
+ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
941
+ eps: float = 1e-8):
832
942
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
833
- beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
834
- _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
943
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
944
+ _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
835
945
  return grad
836
946
 
837
947
 
838
948
  @decorator_knowngood
839
949
  def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
840
950
  grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
841
- caution: bool):
951
+ caution: bool, eps: Tensor):
842
952
  beta1 = beta_debias(beta1, step)
843
953
  beta2 = beta_debias(beta2, step)
844
954
 
845
955
  u32, gp32 = [list(map(promote, x)) for x in [update, grad]]
846
-
847
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
956
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
848
957
  u32 = torch._foreach_div(u32, denom)
849
- u32 = _lerp32(exp_avg, u32, beta1)
958
+ u32 = _lerp(exp_avg, u32, beta1)
850
959
  _compilable_update_(y, u32, decay, lr, caution, gp32)
851
960
 
852
961
 
853
962
  def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
854
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
963
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
964
+ eps: float = 1e-8):
855
965
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
856
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
857
- _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)
966
+ beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
967
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)
858
968
 
859
969
 
860
970
  @decorator_knowngood
@@ -891,7 +1001,7 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
891
1001
  copy_stochastic_list_(exp_avg, exp_avg32)
892
1002
 
893
1003
  beta2 = beta_debias(beta2, step + 1)
894
- exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
1004
+ exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
895
1005
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
896
1006
 
897
1007
  copy_stochastic_list_(grad, update)
@@ -970,12 +1080,15 @@ def get_soap_precond_schedule(precond_scheduler):
970
1080
  return _inner
971
1081
 
972
1082
 
1083
+ def _max_idx(x: List[int]):
1084
+ return len(x) - 1 - np.argmax(x[::-1]) # we want to start counting from the back, as torch is fan-out/fan-in
1085
+
1086
+
973
1087
  def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
974
1088
  """For a scalar or tensor t, we initialize its preconditioner Q and
975
1089
  reusable einsum expressions for updating Q and preconditioning gradient.
976
1090
  """
977
1091
  letters = string.ascii_lowercase + string.ascii_uppercase
978
-
979
1092
  dtype = dtype if dtype is not None else t.dtype
980
1093
  shape = t.shape
981
1094
 
@@ -992,17 +1105,20 @@
 
     scale = scale ** (1 / len(shape))
 
+    dim_diag = [False for _ in shape]
     if memory_save_mode is None:
-        dim_diag = [False for _ in shape]
+        pass
     elif memory_save_mode == "one_diag":
-        rev_sorted_dims = np.argsort(shape)[::-1]
-        dim_diag = [False for _ in shape]
-        dim_diag[rev_sorted_dims[0]] = True
+        dim_diag[_max_idx(shape)] = True
+    elif memory_save_mode == "smart_one_diag":
+        sorted_shape = sorted(shape)
+        if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
+            dim_diag[_max_idx(shape)] = True
     elif memory_save_mode == "all_diag":
         dim_diag = [True for _ in shape]
     else:
         raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
-                         "[None, 'one_diag', 'all_diag']")
+                         "[None, 'one_diag', 'all_diag', 'smart_one_diag']")
 
     Q = []
     piece1A, piece2A, piece3A = ([], "", "")
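
The new smart_one_diag mode only falls back to a diagonal preconditioner when one dimension is strictly the largest; square-ish tensors keep full Kronecker factors. A small illustration of the selection logic, re-implemented outside the library (pick_diag_dims is our name; _max_idx mirrors the helper added earlier in this diff):

def _max_idx(shape):
    # index of the largest dim, preferring the last occurrence (torch's fan-out/fan-in layout)
    return len(shape) - 1 - shape[::-1].index(max(shape))

def pick_diag_dims(shape, mode):
    dim_diag = [False] * len(shape)
    if mode == "one_diag":
        dim_diag[_max_idx(shape)] = True
    elif mode == "smart_one_diag":
        s = sorted(shape)
        if len(shape) >= 2 and s[-1] > s[-2]:  # only when one dim is strictly largest
            dim_diag[_max_idx(shape)] = True
    elif mode == "all_diag":
        dim_diag = [True] * len(shape)
    return dim_diag

print(pick_diag_dims((50257, 768), "smart_one_diag"))  # [True, False] -> the huge vocab dim goes diagonal
print(pick_diag_dims((768, 768), "smart_one_diag"))    # [False, False] -> square weight keeps both factors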
@@ -1016,11 +1132,9 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1016
1132
  piece1A.append(letters[i])
1017
1133
  piece2A = piece2A + letters[i]
1018
1134
  piece3A = piece3A + letters[i]
1019
-
1020
1135
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1021
1136
  subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
1022
1137
  exprGs.append(subscripts)
1023
-
1024
1138
  piece1P.append(letters[i + 13])
1025
1139
  piece2P.append(letters[i + 13])
1026
1140
  piece3P = piece3P + letters[i + 13]
@@ -1028,16 +1142,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
1028
1142
  else:
1029
1143
  # use triangular matrix as preconditioner for this dim
1030
1144
  Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
1031
-
1032
1145
  piece1A.append(letters[i] + letters[i + 13])
1033
1146
  piece2A = piece2A + letters[i + 13]
1034
1147
  piece3A = piece3A + letters[i]
1035
-
1036
1148
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
1037
1149
  piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
1038
1150
  subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26])
1039
1151
  exprGs.append(subscripts)
1040
-
1041
1152
  a, b, c = (letters[i], letters[i + 13], letters[i + 26])
1042
1153
  piece1P.append(a + b)
1043
1154
  piece2P.append(a + c)
@@ -1058,7 +1169,7 @@ def psgd_balance_Q(Q_in):
1058
1169
 
1059
1170
 
1060
1171
  def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
1061
- eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
1172
+ eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
1062
1173
  eps *= G.norm() / G.numel()
1063
1174
  G = G + torch.randn_like(G) * eps
1064
1175
  md = min_dtype(Q + [G])
@@ -1084,9 +1195,7 @@ def psgd_lb(A, max_abs):
1084
1195
  A /= max_abs
1085
1196
  a0 = torch.einsum('ij,ij->j', A, A)
1086
1197
  i = torch.argmax(a0)
1087
-
1088
1198
  x = torch.index_select(A, 1, i).flatten().contiguous()
1089
-
1090
1199
  x = torch.einsum('i,ij->j', x, A)
1091
1200
  x /= x.norm()
1092
1201
  x = torch.einsum('j,kj->k', x, A)
@@ -1099,15 +1208,12 @@ def psgd_lb(A, max_abs):
1099
1208
  def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
1100
1209
  """Update Kronecker product preconditioner Q with pair (V, G)."""
1101
1210
  exprA, exprGs, _ = exprs
1102
-
1103
1211
  A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)
1104
1212
 
1105
1213
  for q, exprG, o in zip(Q, exprGs, oq):
1106
1214
  term1 = promote(torch.einsum(exprG, A, A))
1107
1215
  term2 = promote(torch.einsum(exprG, conjB, conjB))
1108
-
1109
1216
  term1, term2 = term1 - term2, term1 + term2
1110
-
1111
1217
  term1 *= precond_lr
1112
1218
  norm = term2.norm(float('inf'))
1113
1219
  if q.dim() < 2:
@@ -1221,6 +1327,48 @@ def identity(x):
     return x
 
 
+@decorator_knowngood
+def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp(ema, p, ema_decay)
+    _lerp(p, ema32, 1 - weight_decay)
+
+
+def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    ema32 = _lerp(ema, p, ema_decay)
+    for p_, e_ in zip(p, ema32):
+        p32 = promote(p_)
+        p32 = p32 + (p32 - e_).sign() * weight_decay
+        copy_stochastic_(p_, p32)
+
+
+def l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+    p, ema = list_guard(p, ema)
+    ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+    _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+@decorator_knowngood
+def _compilable_sign_(grad: List[Tensor], graft: bool):
+    for g_ in grad:
+        gs = g_.sign()
+        if graft:
+            gs = _compilable_grafting(g_, gs)
+        copy_stochastic_(g_, gs)
+
+
+def sign_(grad: List[Tensor], graft: bool = True):
+    grad = list_guard(grad)
+    _compilable_sign_(grad, graft)
+    return grad
+
+
 @decorator_knowngood
 def _compilable_trust_region_clip_(grad, lerp, scale):
     # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
@@ -1372,7 +1520,6 @@ def _compilable_orthogonalization(weight: List[Tensor], grad: List[Tensor], eps:
 
     if graft:
         out = _compilable_grafting(g, out)
-
     copy_stochastic_(g, out)
 
 
@@ -1,9 +1,9 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.5.2
+Version: 1.6.0
 Summary: Efficient optimizers
-Home-page: https://github.com/clashluke/heavyball
-Author: Lucas Nestler
+Home-page: https://github.com/HomebrewML/HeavyBall
+Author: HeavyBall Authors
 Author-email: github.heavyball@nestler.sh
 License: BSD
 Classifier: Development Status :: 5 - Production/Stable
@@ -300,7 +300,7 @@ class ForeachSOAP(C.BaseOpt):
     def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
                  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048,  #
                  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-                 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
+                 correct_bias: bool = True, warmup_steps: int = 1,
                  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
                  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -324,7 +324,6 @@ the second-order statistics of the gradients to accelerate convergence.
 * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
 * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
 * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
-* **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
 * **`correct_bias`**: Enables/disables bias correction for the running averages.
 * **`warmup_steps`**: Number of steps for linear learning rate warmup.
 * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -931,4 +930,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
 * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
   `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
   configurations.
-
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
+heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
+heavyball/utils.py,sha256=Nk0q_sfv47F-QC9Wwi5KCt-C_71OhuzM98XHlYGvl24,55905
+heavyball-1.6.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.6.0.dist-info/METADATA,sha256=5suezTlZCOBwCgHeFgkLaywYwjAWN1SPg6yhvAv1WgE,43441
+heavyball-1.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.6.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.6.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- heavyball/__init__.py,sha256=f0wWIjsibgA4_YwkPP8HFD7-snggYsAOFc84W0WnNMA,12049
2
- heavyball/chainable.py,sha256=ygeQU-t3RT0Q1BWrEQ_0b4SlXYy8aGDt0DCZAfbiNiw,25040
3
- heavyball/utils.py,sha256=D7ENwrIex_dgFiUHezymmsIdruoQ4_hYztIolCXo2KE,50636
4
- heavyball-1.5.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
- heavyball-1.5.2.dist-info/METADATA,sha256=n_2fW7Wcz_btxBRWFibTe8wnM10B2su100bJzW0bfZY,43584
6
- heavyball-1.5.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
- heavyball-1.5.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
- heavyball-1.5.2.dist-info/RECORD,,