heavyball 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -59,6 +59,19 @@ class ForeachADOPT(C.BaseOpt):
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
 
 
+class ForeachMuon(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
+                 nesterov: bool = True):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
+
+
 class ForeachLaProp(C.BaseOpt):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
@@ -200,10 +213,11 @@ PurePSGD = ForeachPurePSGD
 DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
+Muon = ForeachMuon
 
-__all__ = ["PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron", "CachedDelayedPSGDKron",
-           "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT", "PrecondScheduleSOAP",
-           "PrecondSchedulePaLMSOAP", 'RMSprop', #
+__all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+           "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', #
            "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
            "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-           "ForeachRMSprop"]
+           "ForeachRMSprop", "ForeachMuon"]
heavyball/chainable.py CHANGED
@@ -33,6 +33,23 @@ def _guard_in_state(state, key, template_fn):
     return state[key]
 
 
+class FunctionTransform:
+    def __init__(self, fn):
+        self.fn = fn
+        self.fn_name = self.get_fn().__name__
+
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        raise NotImplementedError
+
+    def get_fn(self):
+        if hasattr(self.fn, 'get_fn'):
+            return self.fn.get_fn()
+        return self.fn
+
+    def val_name(self, name):
+        return f"{self.fn_name}_{name}"
+
+
 def _zero_guard(state, key, ref, dtype):
     return _guard_in_state(state, key,
                            lambda: torch.zeros_like(ref, dtype=torch.float32, memory_format=torch.preserve_format))
@@ -43,61 +60,77 @@ def _storage_dtype(group):
     return getattr(torch, dtype)
 
 
-def zero_guard(*names):
-    def _outer(fn):
-        def _fn(state, group, update, grad, param, *args, **kwargs):
-            vars = [[_zero_guard(state(p), name, p, _storage_dtype(group)) for p in param] for name in names]
-            return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+class ZeroGuard(FunctionTransform):
+    def __init__(self, fn, names):
+        super().__init__(fn)
+        self.names = names
 
-        return _fn
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        vars = [[_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param]  #
+                for name in self.names]
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
-    return _outer
 
+class CopyGuard(FunctionTransform):
+    def __init__(self, fn, index, names):
+        super().__init__(fn)
+        self.index = index
+        self.names = names
 
-def copy_guard(index, *names):
-    def _outer(fn):
-        def _fn(state, group, update, grad, param, *args, **kwargs):
-            val = [update, grad, param, *args][index]
-            vars = [[_guard_in_state(state(p), name, lambda: torch.clone(v)) for p, v in zip(param, val)]  #
-                    for name in names]
-            return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        val = [update, grad, param, *args][self.index]
+        vars = [[_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)]  #
+                for name in self.names]
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
-        return _fn
 
-    return _outer
+class GeneralGuard(FunctionTransform):  # We can't guard against reuse in the general case
+    def __init__(self, fn, names, init_fn):
+        super().__init__(fn)
+        self.names = names
+        self.init_fn = init_fn
 
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        vars = []
+        skip_update = False
+        for p, g, u in zip(param, grad, update):
+            st = state(p)
+            skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
+            vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
+        if skip_update:
+            raise SkipUpdate
+        return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
-def general_guard(*names, init_fn):
-    def _outer(fn):
-        def _fn(state, group, update, grad, param, *args, **kwargs):
-            vars = []
-            skip_update = False
-            for p, g, u in zip(param, grad, update):
-                st = state(p)
-                skip_update |= _inplace_guard_(st, names, lambda: init_fn(st, group, u, g, p, **kwargs))
-                vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in names])
-            if skip_update:
-                raise SkipUpdate
-            return fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
-        return _fn
+class NoState(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        return self.fn(group, update, grad, param, *args, **kwargs)
 
-    return _outer
 
+class NoStateNoForeach(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        for a in zip(update, grad, param, *args):
+            return self.fn(group, *a, **kwargs)
 
-def no_state(fn):
-    def _fn(state, *args, **kwargs):
-        return fn(*args, **kwargs)
 
-    return _fn
+def zero_guard(*names):
+    return functools.partial(ZeroGuard, names=names)
 
 
-def no_state_no_foreach(fn):
-    def _fn(state, group, *args, **kwargs):
-        for a in zip(*args):
-            return fn(group, *a, **kwargs)
+def copy_guard(index, *names):
+    return functools.partial(CopyGuard, index=index, names=names)
+
+
+def general_guard(*names, init_fn):
+    return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
 
-    return _fn
+
+def no_state(fn):
+    return NoState(fn)
+
+
+def no_state_no_foreach(fn):
+    return NoStateNoForeach(fn)
 
 
 class SkipUpdate(ValueError):
@@ -116,8 +149,6 @@ def exp_avg(group, update, grad, param, exp_avg):
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
     out = utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
                                      group['eps'])
-    if group['step'] == 1:
-        raise SkipUpdate
     return out
 
 
@@ -232,6 +263,29 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
     raise ValueError("No preconditioner update schedule specified.")
 
 
+@no_state_no_foreach
+def orthogonalize_update(group, update, grad, param):
+    if update.dim() == 1:
+        return update
+    original_shape = update.shape
+    # doing it this way, as tmp and update are not guaranteed to share memory address or layout
+    tmp = update.flatten(1, -1)
+    utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+    return tmp.reshape(original_shape)
+
+
+@zero_guard("momentum")
+@no_state
+def nesterov_momentum(group, updates, grads, params, momentum):
+    return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
+
+
+@zero_guard("momentum")
+@no_state
+def heavyball_momentum(group, updates, grads, params, momentum):
+    return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
@@ -346,7 +400,7 @@ def apply_to_idx(fn, idx):
 
 
 def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = grad
+    update = [torch.clone(g) for g in grad]
     skip_update = False
     for fn in fns:
         try:
@@ -375,7 +429,10 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
             group['lr'] = -group['base_lr']
 
-        p, g = zip(*list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group))))
+        vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+        if not vals:
+            return
+        p, g = zip(*vals)
 
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
heavyball/utils.py CHANGED
@@ -23,7 +23,7 @@ from torch.utils._pytree import tree_map
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
-zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster, 'eigh' is perfect but slow
+zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 
 
@@ -91,11 +91,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
         ckp1 = weight / weight_sum
     except ZeroDivisionError:
         ckp1 = 0
-        ckp1 = 0
 
-    # These operations update y in-place,
-    # without computing x explicitly.
-    lr, ckp1, beta1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0]), scalar_guard(beta1, parameters[0])
+    grad, parameters, z = list_guard(grad, parameters, z)
+    lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
     _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
     return weight_sum
 
@@ -179,28 +177,28 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
 
 
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
-    state, grad, out = list_guard(state), list_guard(grad), list_guard(out)
-    beta2, eps = scalar_guard(beta2, state[0]), scalar_guard(eps, state[0])
+    state, grad, out = list_guard(state, grad, out)
+    beta2, eps = scalar_guard(beta2, eps, state[0])
     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
 
 
 @decorator_knowngood
-def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
-                                     out: List[Optional[Tensor]]):
+def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     torch._foreach_mul_(s32, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
     [d.clamp_(min=eps) for d in denom]
-    out = torch._foreach_div(g32, denom)
+    out = torch._foreach_div_(g32, denom)
     copy_stochastic_list_(state, s32)
-    return stochastic_round_list_(grad, out)
+    copy_stochastic_list_(grad, out)
 
 
-def scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps):
-    grad, exp_avg_sq = list_guard(grad), list_guard(exp_avg_sq)
-    beta2, eps = scalar_guard(beta2, grad[0]), scalar_guard(eps, grad[0])
-    return _compilable_scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps, grad)
+def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
+    grad, exp_avg_sq = list_guard(grad, exp_avg_sq)
+    beta2, eps = scalar_guard(beta2, eps, grad[0])
+    _compilable_scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps)
+    return grad
 
 
 @decorator_knowngood
@@ -219,7 +217,7 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
                                 minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
         return gradients
-    parameters, gradients = list_guard(parameters), list_guard(gradients)
+    parameters, gradients = list_guard(parameters, gradients)
     clip_val = scalar_guard(clip_val, parameters[0])
     return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
 
@@ -254,33 +252,29 @@ def set_torch():
 
 
 @decorator
-def zeropower_via_newtonschulz5(G, init, steps=2, eps=1e-7):
+def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     """
-    Modified from "modded-nanogpt" under the MIT license:
-    Original: https://github.com/KellerJordan/modded-nanogpt/blob/a0dcbfdd9a0617d091d5123cfc354745428e40d3/train_gpt2.py
-
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
     zero even beyond the point where the iteration no longer converges all the way to one everywhere
     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
     performance at all relative to UV^T, where USV^T = G is the SVD.
     """
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.float()
-    init = init / (init.norm() + eps)  # ensure top singular value <= 1
-    X = X / (X.norm() + eps)  # ensure top singular value <= 1
+    X = G.bfloat16()
+    X /= (X.norm() + eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
-        A = X @ X.T  # preconditioner
-        B = A @ init
-        init = X = a * init + b * B + c * A @ B
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
     if G.size(0) > G.size(1):
         X = X.T
-    return X
+    return X.to(G.dtype)
 
 
 def ortho(x):
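The docstring above describes the quintic Newton-Schulz iteration that 1.1.0 switches to (no warm start, bfloat16, 5 steps by default). A self-contained sketch of the same iteration in plain PyTorch, using the coefficients from the diff; the helper name and the orthogonality check below are illustrative, not part of heavyball:

    import torch

    def ns_orthogonalize(g: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
        # Quintic Newton-Schulz iteration with the coefficients used in heavyball 1.1.0.
        a, b, c = (3.4445, -4.7750, 2.0315)
        x = g.bfloat16()
        x /= x.norm() + eps  # keep the top singular value <= 1
        if g.size(0) > g.size(1):
            x = x.T
        for _ in range(steps):
            A = x @ x.T
            B = b * A + c * A @ A
            x = a * x + B @ x
        if g.size(0) > g.size(1):
            x = x.T
        return x.to(g.dtype)

    g = torch.randn(128, 64)
    q = ns_orthogonalize(g)
    # Columns are only approximately orthonormal: singular values land near 1, not exactly 1.
    print(torch.dist(q.float().T @ q.float(), torch.eye(64)))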
@@ -292,6 +286,53 @@ def ortho(x):
     raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
 
 
+@decorator_knowngood
+def _compilable_heavyball_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+@decorator_knowngood
+def _compilable_nesterov_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, g32)
+
+
+def heavyball_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_heavyball_momentum_(state, grad, beta)
+    return grad
+
+
+def nesterov_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_momentum_(state, grad, beta)
+    return grad
+
+
+@decorator_knowngood
+def inplace_orthogonal_(x, mode, out):
+    if mode == 'qr':
+        y = torch.linalg.qr(x).Q
+    elif mode == 'svd':
+        u, s, v = torch.linalg.svd(x)
+        y = u @ v.T
+    elif mode == 'newtonschulz':
+        y = zeropower_via_newtonschulz5(x, 5)
+    else:
+        raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+    set_(out, y)
+
+
 def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
@@ -322,17 +363,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
         est_eig = torch.einsum('ij,ij->j', o, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
         indices.append(sort_idx)
-        if zeroth_power_mode == 'eigh':
-            set_(q, torch.linalg.eigh(m)[1])
-        elif zeroth_power_mode.startswith('newtonschulz'):
-            iterations = zeroth_power_mode[len('newtonschulz'):]
-            if iterations == '':
-                iterations = 10
-            else:
-                iterations = int(iterations)
-            set_(q, zeropower_via_newtonschulz5(m, o[:, sort_idx], iterations))
-        else:
-            set_(q, ortho(tmp[:, sort_idx]))
+        inplace_orthogonal_(tmp[:, sort_idx], q)
 
     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1))  #
                     for i, ind in enumerate(indices))
@@ -407,7 +438,6 @@ def get_beta1(group):
 
 
 def get_beta2(group):
-    beta = None
     if 'beta2_scale' in group:
         step = max(group.get("step", 1), 1)
         return 1 - step ** -group['beta2_scale']
@@ -417,23 +447,36 @@ def get_beta2(group):
 
 
 def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
-    x, y = list_guard(x), list_guard(y)
+    x, y = list_guard(x, y)
     a = scalar_guard(a, x[0])
     _compilable_stochastic_lerp_(x, y, a)
 
 
-def list_guard(x):
-    if isinstance(x, (list, tuple)):
-        return x
-    return [x]
+def list_guard(*xs):
+    out = []
+    for x in xs:
+        if isinstance(x, (list, tuple)):
+            out.append(x)
+        else:
+            out.append([x])
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
-def scalar_guard(x, ref):
-    if isinstance(x, float):
-        return torch.empty((), dtype=torch.float32, device=ref.device).fill_(x)
-    if isinstance(x, int):
-        return torch.empty((), dtype=torch.int64, device=ref.device).fill_(x)
-    return x
+def scalar_guard(*args):
+    *xs, ref = args
+    out = []
+    for x in xs:
+        if isinstance(x, float):
+            out.append(torch.empty((), dtype=torch.float32, device=ref.device).fill_(x))
+        elif isinstance(x, int):
+            out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+        else:
+            out.append(x)
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
 @decorator_knowngood
@@ -446,7 +489,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
 
 
 def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
-    x, y = list_guard(x), list_guard(y)
+    x, y = list_guard(x, y)
     alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
 
@@ -695,14 +738,14 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-    return stochastic_round_list_(exp_avg, u32)
+    copy_stochastic_list_(grad, u32)
 
 
 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step,
-                                                                                                        exp_avg[0])
-    return _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
@@ -725,32 +768,32 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
 
 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
                 beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
-    y, exp_avg, exp_avg_sq, grad = map(list_guard, (y, exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
+    y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
     return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
 @decorator_knowngood
-def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: Tensor,
+def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
                         beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]
+    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
 
     denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
     gp32 = torch._foreach_div(gp32, denom)
     stochastic_lerp_(exp_avg, gp32, 1 - beta1)
 
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    copy_stochastic_list_(grad, exp_avg)
 
 
-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: float, beta2: float,
-            step: int):
-    exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
-    beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
-    _compilable_laprop_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
-    return exp_avg
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
@@ -770,11 +813,11 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
-def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
-                  beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
-    y, exp_avg, exp_avg_sq, grad_projected = map(list_guard, (y, exp_avg, exp_avg_sq, grad_projected))
-    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
-    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step, lr, decay, caution)
+def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
+                  beta2: float, step: int, lr: float, decay: float, caution: bool):
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
 
 
 @decorator_knowngood
@@ -797,8 +840,8 @@ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, l
 
 
 def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-    y, grad, exp_avg_sq, exp_avg = list_guard(y), list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
-    beta1, beta2, step, lr = [scalar_guard(x, y[0]) for x in (beta1, beta2, step, lr)]
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
     _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
 
 
@@ -819,14 +862,14 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-    return update
+    copy_stochastic_list_(grad, update)
 
 
 def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
-    grad, exp_avg_sq, exp_avg = list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
-    beta1, beta2, step = scalar_guard(beta1, grad[0]), scalar_guard(beta2, grad[0]), scalar_guard(step, grad[0])
-    return _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+    return grad
 
 
 def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
@@ -877,7 +920,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:
 
 def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
                   caution: bool = False, grad: List[Tensor] = None):
-    param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+    param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
@@ -983,11 +1026,15 @@ def psgd_balance_Q(Q_in):
 
 
 def psgd_calc_A_and_conjB(exprA, G, Q):
+    V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
+    eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
+    eps *= G.norm() / G.numel()
+    G += V * eps
     md = min_dtype(Q + [G])
     A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
     p = list(range(order))
-    conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+    conjB = torch.permute(V, p[1:] + p[:1]).to(promote(G.dtype))
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
@@ -1134,7 +1181,7 @@ def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
 
 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
     grad = list_guard(grad)
-    lerp, scale = scalar_guard(lerp, grad[0]), scalar_guard(scale, grad[0])
+    lerp, scale = scalar_guard(lerp, scale, grad[0])
     return _compilable_trust_region_clip_(grad, lerp, scale)
 
 
heavyball-1.0.0.dist-info/METADATA → heavyball-1.1.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.0.0
+Version: 1.1.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -132,5 +132,5 @@ It has several handy functions:
 
 * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
 * `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables torch_compile
-* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz{iterations}, or svd or eigh to approximate
-  the eigenvectors. Eigh has the highest precision and cost
+* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz, or svd or eigh to approximate
+  the eigenvectors.
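For the module-level knobs listed above, a short sketch of how they are typically set before building an optimizer (attribute names are taken from heavyball/utils.py in this diff; the chosen values are only examples):

    import heavyball
    import heavyball.utils

    heavyball.utils.compile_mode = None                  # disable torch.compile while debugging
    heavyball.utils.zeroth_power_mode = 'newtonschulz'   # 'qr' (default), 'svd', or 'newtonschulz'
    heavyball.utils.set_torch()                          # enable TF32, opt_einsum, cudnn benchmark, ...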
heavyball-1.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=ZPadmK5Q9hGrBAeNYldMUw8vciy9dCDL3d7Zk9erC3E,12702
+heavyball/chainable.py,sha256=NLyJZS_vuQxzLE3RjQy0isQGCY5xGQBCs1wirT9BsQY,20172
+heavyball/utils.py,sha256=BEbvwWswWMyEa4zozlnLaQIKEMH_k7OUGu84jlcj5t0,46736
+heavyball-1.1.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.1.0.dist-info/METADATA,sha256=hVz8TPY0A9SqkjVpLX6LVXXlBVmAC8my96DHO8jtSb8,12022
+heavyball-1.1.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.1.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.1.0.dist-info/RECORD,,
heavyball-1.0.0.dist-info/RECORD REMOVED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=1QPYBIH8amnk3-_rKe6L9FJ0rkV5wVNRr7Yw9BXjIYI,11636
-heavyball/chainable.py,sha256=cp-tpetPr4CNN9xJ85JSo89JYC5BWUygoE6dnET6tmc,18141
-heavyball/utils.py,sha256=qUoB9EIxl7GUyLkV5a5JAKOD6TvPc1FNsqyUbJ-HY6o,46343
-heavyball-1.0.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.0.0.dist-info/METADATA,sha256=9C2btIxngp26TRCJFU6B8ftkWQt1rfZZC10rkAhaORw,12074
-heavyball-1.0.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.0.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.0.0.dist-info/RECORD,,