heavyball 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -59,6 +59,19 @@ class ForeachADOPT(C.BaseOpt):
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
 
 
+class ForeachMuon(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
+                 nesterov: bool = True):
+        defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
+                        lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
+                        mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
+                         C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
+
+
 class ForeachLaProp(C.BaseOpt):
     def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
@@ -200,10 +213,11 @@ PurePSGD = ForeachPurePSGD
 DelayedPSGD = ForeachDelayedPSGD
 CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
+Muon = ForeachMuon
 
-__all__ = ["PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron", "CachedDelayedPSGDKron",
-           "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT", "PrecondScheduleSOAP",
-           "PrecondSchedulePaLMSOAP", 'RMSprop', #
+__all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+           "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', #
            "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
            "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
-           "ForeachRMSprop"]
+           "ForeachRMSprop", "ForeachMuon"]
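The hunks above add a Muon-style optimizer and export it both as `ForeachMuon` and under the `Muon` alias. A minimal usage sketch follows; the model, data, and hyperparameter choices are illustrative assumptions, only the constructor signature comes from the diff.

import torch
import heavyball

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
# `nesterov=True` selects C.nesterov_momentum (else C.heavyball_momentum); updates of
# 2D+ parameters are then orthogonalized via C.orthogonalize_update.
opt = heavyball.Muon(model.parameters(), lr=0.0025, betas=(0.9, 0.99), nesterov=True)

for _ in range(10):
    x = torch.randn(32, 64)
    loss = model(x).square().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()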
heavyball/chainable.py CHANGED
@@ -33,6 +33,23 @@ def _guard_in_state(state, key, template_fn):
     return state[key]
 
 
+class FunctionTransform:
+    def __init__(self, fn):
+        self.fn = fn
+        self.fn_name = self.get_fn().__name__
+
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        raise NotImplementedError
+
+    def get_fn(self):
+        if hasattr(self.fn, 'get_fn'):
+            return self.fn.get_fn()
+        return self.fn
+
+    def val_name(self, name):
+        return f"{self.fn_name}_{name}"
+
+
 def _zero_guard(state, key, ref, dtype):
     return _guard_in_state(state, key,
                            lambda: torch.zeros_like(ref, dtype=torch.float32, memory_format=torch.preserve_format))
@@ -43,61 +60,77 @@ def _storage_dtype(group):
     return getattr(torch, dtype)
 
 
-def zero_guard(*names):
-    def _outer(fn):
-        def _fn(state, group, update, grad, param, *args, **kwargs):
-            vars = [[_zero_guard(state(p), name, p, _storage_dtype(group)) for p in param] for name in names]
-            return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+class ZeroGuard(FunctionTransform):
+    def __init__(self, fn, names):
+        super().__init__(fn)
+        self.names = names
 
-        return _fn
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        vars = [[_zero_guard(state(p), self.val_name(name), p, _storage_dtype(group)) for p in param]  #
+                for name in self.names]
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
-    return _outer
 
+class CopyGuard(FunctionTransform):
+    def __init__(self, fn, index, names):
+        super().__init__(fn)
+        self.index = index
+        self.names = names
 
-def copy_guard(index, *names):
-    def _outer(fn):
-        def _fn(state, group, update, grad, param, *args, **kwargs):
-            val = [update, grad, param, *args][index]
-            vars = [[_guard_in_state(state(p), name, lambda: torch.clone(v)) for p, v in zip(param, val)]  #
-                    for name in names]
-            return fn(state, group, update, grad, param, *args, *vars, **kwargs)
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        val = [update, grad, param, *args][self.index]
+        vars = [[_guard_in_state(state(p), self.val_name(name), lambda: torch.clone(v)) for p, v in zip(param, val)]  #
+                for name in self.names]
+        return self.fn(state, group, update, grad, param, *args, *vars, **kwargs)
 
-        return _fn
 
-    return _outer
+class GeneralGuard(FunctionTransform):  # We can't guard against reuse in the general case
+    def __init__(self, fn, names, init_fn):
+        super().__init__(fn)
+        self.names = names
+        self.init_fn = init_fn
 
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        vars = []
+        skip_update = False
+        for p, g, u in zip(param, grad, update):
+            st = state(p)
+            skip_update |= _inplace_guard_(st, self.names, lambda: self.init_fn(st, group, u, g, p, **kwargs))
+            vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in self.names])
+        if skip_update:
+            raise SkipUpdate
+        return self.fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
-def general_guard(*names, init_fn):
-    def _outer(fn):
-        def _fn(state, group, update, grad, param, *args, **kwargs):
-            vars = []
-            skip_update = False
-            for p, g, u in zip(param, grad, update):
-                st = state(p)
-                skip_update |= _inplace_guard_(st, names, lambda: init_fn(st, group, u, g, p, **kwargs))
-                vars.append([st[name] if isinstance(name, str) else st.get(name[0], name[1]) for name in names])
-            if skip_update:
-                raise SkipUpdate
-            return fn(state, group, update, grad, param, *args, *zip(*vars), **kwargs)
 
-        return _fn
+class NoState(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        return self.fn(group, update, grad, param, *args, **kwargs)
 
-    return _outer
 
+class NoStateNoForeach(FunctionTransform):
+    def __call__(self, state, group, update, grad, param, *args, **kwargs):
+        for a in zip(update, grad, param, *args):
+            return self.fn(group, *a, **kwargs)
 
-def no_state(fn):
-    def _fn(state, *args, **kwargs):
-        return fn(*args, **kwargs)
 
-    return _fn
+def zero_guard(*names):
+    return functools.partial(ZeroGuard, names=names)
 
 
-def no_state_no_foreach(fn):
-    def _fn(state, group, *args, **kwargs):
-        for a in zip(*args):
-            return fn(group, *a, **kwargs)
+def copy_guard(index, *names):
+    return functools.partial(CopyGuard, index=index, names=names)
+
+
+def general_guard(*names, init_fn):
+    return functools.partial(GeneralGuard, names=names, init_fn=init_fn)
 
-    return _fn
+
+def no_state(fn):
+    return NoState(fn)
+
+
+def no_state_no_foreach(fn):
+    return NoStateNoForeach(fn)
 
 
 class SkipUpdate(ValueError):
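The guard decorators above are now thin factories over `FunctionTransform` subclasses instead of nested closures. A small sketch of what a decorated transform expands to; `my_exp_avg` is an illustrative name, while the decorators and classes are the ones from the diff.

import heavyball.chainable as C
import heavyball.utils as utils


@C.zero_guard("exp_avg")   # -> functools.partial(ZeroGuard, names=("exp_avg",))
@C.no_state                # -> NoState(my_exp_avg)
def my_exp_avg(group, update, grad, param, exp_avg):
    # update/grad/param/exp_avg arrive as matched lists of tensors
    return utils.scale_by_exp_avg_(exp_avg, update, utils.get_beta1(group))


# The stack is equivalent to wrapping by hand:
#     my_exp_avg = C.ZeroGuard(C.NoState(my_exp_avg), names=("exp_avg",))
# FunctionTransform.val_name() prefixes buffer keys with the wrapped function's name,
# so the state entry becomes "my_exp_avg_exp_avg" rather than a bare "exp_avg".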
@@ -107,18 +140,14 @@ class SkipUpdate(ValueError):
 @zero_guard("exp_avg")
 @no_state
 def exp_avg(group, update, grad, param, exp_avg):
-    utils.stochastic_lerp_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
-    return exp_avg
+    return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
 
 
 @zero_guard("exp_avg_sq")
 @no_state
 def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
-    out = utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
-                                     group['eps'])
-    if group['step'] == 1:
-        raise SkipUpdate
-    return out
+    return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
+                                      group['eps'])
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
@@ -232,6 +261,29 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
     raise ValueError("No preconditioner update schedule specified.")
 
 
+@no_state_no_foreach
+def orthogonalize_update(group, update, grad, param):
+    if update.dim() == 1:
+        return update
+    original_shape = update.shape
+    # doing it this way, as tmp and update are not guaranteed to share memory address or layout
+    tmp = update.flatten(1, -1)
+    utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+    return tmp.reshape(original_shape)
+
+
+@zero_guard("momentum")
+@no_state
+def nesterov_momentum(group, updates, grads, params, momentum):
+    return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))
+
+
+@zero_guard("momentum")
+@no_state
+def heavyball_momentum(group, updates, grads, params, momentum):
+    return utils.heavyball_momentum(momentum, updates, utils.get_beta1(group))
+
+
 @zero_guard("exp_avg", "exp_avg_sq")
 @general_guard("Q", "GG", init_fn=_init_soap)
 @no_state
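Together with `ForeachMuon` in `__init__.py`, these transforms spell out the Muon-style update: accumulate (Nesterov) momentum, then orthogonalize the two-dimensional update. A rough single-matrix sketch of the composed step, purely illustrative: heavyball runs this through `chain()` with per-parameter state and picks QR, Newton-Schulz, or SVD based on `zeroth_power_mode` and the matrix shape, whereas the sketch uses SVD for brevity.

import torch


def muon_step_sketch(weight, grad, momentum_buf, lr=0.0025, beta=0.9):
    # nesterov_momentum: m <- beta * m + g, update <- g + beta * m
    momentum_buf.mul_(beta).add_(grad)
    update = grad + beta * momentum_buf
    # orthogonalize_update: flatten to 2D and replace the update with an
    # (approximately) orthogonal matrix of the same shape
    if update.dim() > 1:
        mat = update.flatten(1, -1)
        u, _, vh = torch.linalg.svd(mat, full_matrices=False)
        update = (u @ vh).reshape(update.shape)
    weight.add_(update, alpha=-lr)
    return weight


w, g, m = torch.randn(64, 128), torch.randn(64, 128), torch.zeros(64, 128)
muon_step_sketch(w, g, m)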
@@ -296,9 +348,12 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
 @no_state_no_foreach
 def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                   prob: Optional[callable] = None):
+    old = update
+    update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
-    return _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+    out = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+    return torch.as_strided(out, old.shape, old.stride())
 
 
 @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@@ -346,7 +401,7 @@ def apply_to_idx(fn, idx):
 
 
 def chain(state: Union[callable, dict], group, grad, param, *fns):
-    update = grad
+    update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
     skip_update = False
     for fn in fns:
         try:
@@ -375,7 +430,10 @@ class ChainOpt(utils.StatefulOptimizer):
         else:
            group['lr'] = -group['base_lr']
 
-        p, g = zip(*list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group))))
+        vals = list(self.split_p_and_g_in_group(group, should_promote=False, beta1=utils.get_beta1(group)))
+        if not vals:
+            return
+        p, g = zip(*vals)
 
         if not group['foreach'] or len(p) == 1:
             for param, grad in zip(p, g):
heavyball/utils.py CHANGED
@@ -23,7 +23,7 @@ from torch.utils._pytree import tree_map
 compile_mode = "max-autotune-no-cudagraphs"
 dynamic = False
 compile_mode_recommended_to_none = None
-zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster, 'eigh' is perfect but slow
+zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
 tiny_bf16 = torch.finfo(torch.bfloat16).tiny
 
 
@@ -91,11 +91,9 @@ def schedule_free_(lr: float, weight_lr_power: float, weight_sum: float, beta1:
         ckp1 = weight / weight_sum
     except ZeroDivisionError:
         ckp1 = 0
-    ckp1 = 0
 
-    # These operations update y in-place,
-    # without computing x explicitly.
-    lr, ckp1, beta1 = scalar_guard(lr, parameters[0]), scalar_guard(ckp1, parameters[0]), scalar_guard(beta1, parameters[0])
+    grad, parameters, z = list_guard(grad, parameters, z)
+    lr, ckp1, beta1 = scalar_guard(lr, ckp1, beta1, grad[0])
     _compilable_schedule_free_(parameters, z, ckp1, grad, lr, beta1, decay)
     return weight_sum
 
@@ -179,28 +177,43 @@ def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tens
 
 
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
-    state, grad, out = list_guard(state), list_guard(grad), list_guard(out)
-    beta2, eps = scalar_guard(beta2, state[0]), scalar_guard(eps, state[0])
+    state, grad, out = list_guard(state, grad, out)
+    beta2, eps = scalar_guard(beta2, eps, state[0])
     return _compilable_exp_avg_sq_(state, grad, beta2, eps, out)
 
 
 @decorator_knowngood
-def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
-                                     out: List[Optional[Tensor]]):
+def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
     torch._foreach_mul_(s32, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
     [d.clamp_(min=eps) for d in denom]
-    out = torch._foreach_div(g32, denom)
+    out = torch._foreach_div_(g32, denom)
     copy_stochastic_list_(state, s32)
-    return stochastic_round_list_(grad, out)
+    copy_stochastic_list_(grad, out)
 
 
-def scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps):
-    grad, exp_avg_sq = list_guard(grad), list_guard(exp_avg_sq)
-    beta2, eps = scalar_guard(beta2, grad[0]), scalar_guard(eps, grad[0])
-    return _compilable_scale_by_exp_avg_sq_(grad, exp_avg_sq, beta2, eps, grad)
+def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
+    grad, exp_avg_sq = list_guard(grad, exp_avg_sq)
+    beta2, eps = scalar_guard(beta2, eps, grad[0])
+    _compilable_scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps)
+    return grad
+
+
+@decorator_knowngood
+def _compilable_exp_avg_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    [s.lerp_(g, beta) for s, g in zip(s32, g32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+def scale_by_exp_avg_(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_exp_avg_(state, grad, beta)
+    return grad
 
 
 @decorator_knowngood
@@ -219,7 +232,7 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
                                 minimum: float = 1e-3, eps: float = 1e-8):
     if clip_val <= 0:
         return gradients
-    parameters, gradients = list_guard(parameters), list_guard(gradients)
+    parameters, gradients = list_guard(parameters, gradients)
     clip_val = scalar_guard(clip_val, parameters[0])
     return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
 
@@ -254,33 +267,29 @@ def set_torch():
 
 
 @decorator
-def zeropower_via_newtonschulz5(G, init, steps=2, eps=1e-7):
+def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
     """
-    Modified from "modded-nanogpt" under the MIT license:
-    Original: https://github.com/KellerJordan/modded-nanogpt/blob/a0dcbfdd9a0617d091d5123cfc354745428e40d3/train_gpt2.py
-
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
     zero even beyond the point where the iteration no longer converges all the way to one everywhere
     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
     """
     assert len(G.shape) == 2
     a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.float()
-    init = init / (init.norm() + eps)  # ensure top singular value <= 1
-    X = X / (X.norm() + eps)  # ensure top singular value <= 1
+    X = G.bfloat16()
+    X /= (X.norm() + eps)  # ensure top singular value <= 1
     if G.size(0) > G.size(1):
         X = X.T
     for _ in range(steps):
-        A = X @ X.T  # preconditioner
-        B = A @ init
-        init = X = a * init + b * B + c * A @ B
+        A = X @ X.T
+        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
     if G.size(0) > G.size(1):
         X = X.T
-    return X
+    return X.to(G.dtype)
 
 
 def ortho(x):
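The kernel above now runs the quintic iteration in bfloat16, defaults to five steps, and no longer threads an `init` matrix through the loop. A standalone restatement (float32 for readability; coefficients copied from the diff) that checks how close the result is to an orthogonal factor:

import torch


def newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G / (G.norm() + eps)          # ensure the top singular value is <= 1
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


G = torch.randn(128, 64)
X = newtonschulz5(G)
# The docstring promises roughly U S' V^T with S' ~ Uniform(0.5, 1.5), so the
# singular values of X should cluster near 1 regardless of the spectrum of G.
print(torch.linalg.svdvals(X).aminmax())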
@@ -292,6 +301,53 @@ def ortho(x):
     raise NotImplementedError(f"Unknown zeroth_power_mode: {zeroth_power_mode}")
 
 
+@decorator_knowngood
+def _compilable_heavyball_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, s32)
+
+
+@decorator_knowngood
+def _compilable_nesterov_momentum_(state, grad, beta):
+    s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+    torch._foreach_mul_(s32, beta)
+    torch._foreach_add_(s32, g32)
+    [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
+    copy_stochastic_list_(state, s32)
+    copy_stochastic_list_(grad, g32)
+
+
+def heavyball_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_heavyball_momentum_(state, grad, beta)
+    return grad
+
+
+def nesterov_momentum(state, grad, beta):
+    state, grad = list_guard(state, grad)
+    beta = scalar_guard(beta, state[0])
+    _compilable_nesterov_momentum_(state, grad, beta)
+    return grad
+
+
+@decorator_knowngood
+def inplace_orthogonal_(x, mode, out):
+    if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
+        y = zeropower_via_newtonschulz5(x, 5)
+    elif mode == 'qr':
+        y = torch.linalg.qr(x).Q
+    elif mode == 'svd':
+        u, s, v = torch.linalg.svd(x)
+        y = u @ v.T
+    else:
+        raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+    set_(out, y)
+
+
 def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
@@ -322,17 +378,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
         est_eig = torch.einsum('ij,ij->j', o, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
         indices.append(sort_idx)
-        if zeroth_power_mode == 'eigh':
-            set_(q, torch.linalg.eigh(m)[1])
-        elif zeroth_power_mode.startswith('newtonschulz'):
-            iterations = zeroth_power_mode[len('newtonschulz'):]
-            if iterations == '':
-                iterations = 10
-            else:
-                iterations = int(iterations)
-            set_(q, zeropower_via_newtonschulz5(m, o[:, sort_idx], iterations))
-        else:
-            set_(q, ortho(tmp[:, sort_idx]))
+        inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q)
 
     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1))  #
                     for i, ind in enumerate(indices))
@@ -407,7 +453,6 @@ def get_beta1(group):
 
 
 def get_beta2(group):
-    beta = None
     if 'beta2_scale' in group:
         step = max(group.get("step", 1), 1)
         return 1 - step ** -group['beta2_scale']
@@ -417,23 +462,36 @@ def get_beta2(group):
 
 
 def stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[float, int, Tensor]):
-    x, y = list_guard(x), list_guard(y)
+    x, y = list_guard(x, y)
     a = scalar_guard(a, x[0])
     _compilable_stochastic_lerp_(x, y, a)
 
 
-def list_guard(x):
-    if isinstance(x, (list, tuple)):
-        return x
-    return [x]
+def list_guard(*xs):
+    out = []
+    for x in xs:
+        if isinstance(x, (list, tuple)):
+            out.append(x)
+        else:
+            out.append([x])
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
-def scalar_guard(x, ref):
-    if isinstance(x, float):
-        return torch.empty((), dtype=torch.float32, device=ref.device).fill_(x)
-    if isinstance(x, int):
-        return torch.empty((), dtype=torch.int64, device=ref.device).fill_(x)
-    return x
+def scalar_guard(*args):
+    *xs, ref = args
+    out = []
+    for x in xs:
+        if isinstance(x, float):
+            out.append(torch.empty((), dtype=torch.float32, device=ref.device).fill_(x))
+        elif isinstance(x, int):
+            out.append(torch.empty((), dtype=torch.int64, device=ref.device).fill_(x))
+        else:
+            out.append(x)
+    if len(xs) == 1:
+        return out[0]
+    return out
 
 
 @decorator_knowngood
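Both helpers are now variadic, normalizing several arguments in one call and unwrapping the result only when a single argument is passed. A torch-free restatement of `list_guard` to show the calling convention (the in-package version is above; `scalar_guard` behaves analogously, with the last positional argument acting as the reference tensor):

from typing import Any, List, Union


def list_guard(*xs: Any) -> Union[list, List[list]]:
    out = [x if isinstance(x, (list, tuple)) else [x] for x in xs]
    return out[0] if len(xs) == 1 else out


single = list_guard(1.0)            # one argument  -> one list: [1.0]
a, b = list_guard(1.0, [2.0, 3.0])  # several       -> list of normalized lists
assert single == [1.0] and a == [1.0] and b == [2.0, 3.0]
# e.g. beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])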
@@ -446,7 +504,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
 
 
 def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, Tensor]):
-    x, y = list_guard(x), list_guard(y)
+    x, y = list_guard(x, y)
     alpha = scalar_guard(alpha, x[0])
     _compilable_stochastic_add_(x, y, alpha)
 
@@ -695,14 +753,14 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-    return stochastic_round_list_(exp_avg, u32)
+    copy_stochastic_list_(grad, u32)
 
 
 def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
     exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step,
-                                                                                                        exp_avg[0])
-    return _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
@@ -725,32 +783,32 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
 
 def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
                 beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
-    y, exp_avg, exp_avg_sq, grad = map(list_guard, (y, exp_avg, exp_avg_sq, grad))
-    beta1, beta2, step, lr = [scalar_guard (x, y[0]) for x in (beta1, beta2, step, lr)]
+    y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
     return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
 @decorator_knowngood
-def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: Tensor,
+def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
                         beta2: Tensor, step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]
+    gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
 
     denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
     gp32 = torch._foreach_div(gp32, denom)
     stochastic_lerp_(exp_avg, gp32, 1 - beta1)
 
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+    copy_stochastic_list_(grad, exp_avg)
 
 
-def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor], beta1: float, beta2: float,
-            step: int):
-    exp_avg, exp_avg_sq, grad_projected = list_guard(exp_avg), list_guard(exp_avg_sq), list_guard(grad_projected)
-    beta1, beta, step = scalar_guard(beta1, exp_avg[0]), scalar_guard(beta2, exp_avg[0]), scalar_guard(step, exp_avg[0])
-    _compilable_laprop_(exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step)
-    return exp_avg
+def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
+    _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+    return grad
 
 
 @decorator_knowngood
@@ -770,11 +828,11 @@ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
-def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad_projected: List[Tensor],
-                  beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
-    y, exp_avg, exp_avg_sq, grad_projected = map(list_guard, (y, exp_avg, exp_avg_sq, grad_projected))
-    beta1, beta2, step, lr = [scalar_guard (x, y[0]) for x in (beta1, beta2, step, lr)]
-    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad_projected, beta1, beta2, step, lr, decay, caution)
+def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
                  beta2: float, step: int, lr: float, decay: float, caution: bool):
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
 
 
 @decorator_knowngood
@@ -797,8 +855,8 @@ def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, l
 
 
 def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-    y, grad, exp_avg_sq, exp_avg = list_guard(y), list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
-    beta1, beta2, step, lr = [scalar_guard (x, y[0]) for x in (beta1, beta2, step, lr)]
+    exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
     _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
 
 
@@ -819,14 +877,14 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
 
     copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-
-    return update
+    copy_stochastic_list_(grad, update)
 
 
 def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
-    grad, exp_avg_sq, exp_avg = list_guard(grad), list_guard(exp_avg_sq), list_guard(exp_avg)
-    beta1, beta2, step = scalar_guard(beta1, grad[0]), scalar_guard(beta2, grad[0]), scalar_guard(step, grad[0])
-    return _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad, y)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+    _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
+    return grad
 
 
 def stochastic_round_list_(ref: List[Tensor], source: List[Tensor]):
@@ -877,7 +935,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:
 
 def update_param_(param: List[Tensor], update: List[Tensor], lr: float, decay: float, add_fn: callable = None,
                   caution: bool = False, grad: List[Tensor] = None):
-    param, update, grad = list_guard(param), list_guard(update), list_guard(grad)
+    param, update, grad = list_guard(param, update, grad)
     lr = scalar_guard(lr, param[0])
     if not caution:
         grad = [None] * len(param)
@@ -983,11 +1041,15 @@ def psgd_balance_Q(Q_in):
 
 
 def psgd_calc_A_and_conjB(exprA, G, Q):
+    V = torch.randn(G.shape, dtype=G.dtype, device=G.device)
+    eps = scalar_guard(math.sqrt(torch.finfo(torch.float32).eps), G)
+    eps *= G.norm() / G.numel()
+    G += V * eps
     md = min_dtype(Q + [G])
    A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
     order = G.dim()
     p = list(range(order))
-    conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+    conjB = torch.permute(V, p[1:] + p[:1]).to(promote(G.dtype))
     Q = [promote(q) for q in Q]
     for i, q in enumerate(Q):
         if q.dim() <= 1:
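The PSGD probe is no longer an independent random tensor: a Gaussian perturbation `V` is added to `G`, scaled by `sqrt(eps_float32) * ||G|| / numel(G)`, and the same `V` (axes rotated) seeds `conjB`. A standalone sketch of just that scaling, with an illustrative gradient tensor:

import math

import torch

G = torch.randn(32, 16)
V = torch.randn_like(G)
eps = math.sqrt(torch.finfo(torch.float32).eps) * G.norm() / G.numel()
G_noised = G + V * eps                      # tiny, gradient-scaled perturbation
conjB_seed = V.permute(1, 0).contiguous()   # same V with axes rotated, as for conjB
print(float(eps), G_noised.shape, conjB_seed.shape)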
@@ -1134,7 +1196,7 @@ def _compilable_trust_region_clip_(grad, lerp: float = 0.9, scale: float = 1.5):
 
 
 def trust_region_clip_(grad, lerp=0.9, scale=1.5):
     grad = list_guard(grad)
-    lerp, scale = scalar_guard(lerp, grad[0]), scalar_guard(scale, grad[0])
+    lerp, scale = scalar_guard(lerp, scale, grad[0])
     return _compilable_trust_region_clip_(grad, lerp, scale)
 
heavyball-1.0.0.dist-info/METADATA → heavyball-1.1.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 1.0.0
+Version: 1.1.1
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -132,5 +132,5 @@ It has several handy functions:
 
 * `set_torch()` sets pytorch optimization settings (TF32, opt_einsum, benchmark, ...)
 * `compile_mode`, a string passed as-is to `torch.compile(mode=compile_mode)` in all compiled heavyball calls; `compile_mode=None` disables torch_compile
-* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz{iterations}, or svd or eigh to approximate
-  the eigenvectors. Eigh has the highest precision and cost
+* `zeroth_power_mode`, a string determining whether to use QR, newtonschulz, or svd or eigh to approximate
+  the eigenvectors.
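A short configuration sketch for the module-level knobs described above; the chosen values are illustrative, and the attributes live in `heavyball.utils` per the diff.

import torch
import heavyball
import heavyball.utils

heavyball.utils.set_torch()                          # TF32, benchmark mode, opt_einsum, ...
heavyball.utils.compile_mode = None                  # disable torch.compile in heavyball kernels
heavyball.utils.zeroth_power_mode = 'newtonschulz'   # Newton-Schulz instead of the QR default

opt = heavyball.ForeachAdamW([torch.nn.Parameter(torch.zeros(4, 4))], lr=1e-3)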
heavyball-1.1.1.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+heavyball/__init__.py,sha256=ZPadmK5Q9hGrBAeNYldMUw8vciy9dCDL3d7Zk9erC3E,12702
+heavyball/chainable.py,sha256=_KVV_bA_WYbyaiGDOaoQMHv-IM9jbgZx_cwzmRiKxl8,20321
+heavyball/utils.py,sha256=F4nwiyDCTGg3uU5XAaX9_p11qf5Uiw4WRseMbuLfq0Y,47223
+heavyball-1.1.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+heavyball-1.1.1.dist-info/METADATA,sha256=uLpyF5G1Stjxxj23qmxpnk3ViUd0M55Q0wIDGDY58qk,12022
+heavyball-1.1.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+heavyball-1.1.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+heavyball-1.1.1.dist-info/RECORD,,
heavyball-1.0.0.dist-info/RECORD REMOVED
@@ -1,8 +0,0 @@
-heavyball/__init__.py,sha256=1QPYBIH8amnk3-_rKe6L9FJ0rkV5wVNRr7Yw9BXjIYI,11636
-heavyball/chainable.py,sha256=cp-tpetPr4CNN9xJ85JSo89JYC5BWUygoE6dnET6tmc,18141
-heavyball/utils.py,sha256=qUoB9EIxl7GUyLkV5a5JAKOD6TvPc1FNsqyUbJ-HY6o,46343
-heavyball-1.0.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
-heavyball-1.0.0.dist-info/METADATA,sha256=9C2btIxngp26TRCJFU6B8ftkWQt1rfZZC10rkAhaORw,12074
-heavyball-1.0.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-heavyball-1.0.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
-heavyball-1.0.0.dist-info/RECORD,,