heavyball 1.1.1.tar.gz → 1.1.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {heavyball-1.1.1 → heavyball-1.1.2}/PKG-INFO +1 -1
  2. {heavyball-1.1.1 → heavyball-1.1.2}/heavyball/__init__.py +40 -35
  3. {heavyball-1.1.1 → heavyball-1.1.2}/heavyball/chainable.py +17 -19
  4. {heavyball-1.1.1 → heavyball-1.1.2}/heavyball/utils.py +91 -78
  5. {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/PKG-INFO +1 -1
  6. {heavyball-1.1.1 → heavyball-1.1.2}/setup.py +1 -1
  7. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_channels_last.py +2 -1
  8. {heavyball-1.1.1 → heavyball-1.1.2}/LICENSE +0 -0
  9. {heavyball-1.1.1 → heavyball-1.1.2}/README.md +0 -0
  10. {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/SOURCES.txt +0 -0
  11. {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/dependency_links.txt +0 -0
  12. {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/requires.txt +0 -0
  13. {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/top_level.txt +0 -0
  14. {heavyball-1.1.1 → heavyball-1.1.2}/setup.cfg +0 -0
  15. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_bf16_params.py +0 -0
  16. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_bf16_q.py +0 -0
  17. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_bf16_storage.py +0 -0
  18. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_caution.py +0 -0
  19. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_closure.py +0 -0
  20. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_ema.py +0 -0
  21. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_foreach.py +0 -0
  22. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_hook.py +0 -0
  23. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_mars.py +0 -0
  24. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_memory.py +0 -0
  25. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_merge.py +0 -0
  26. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_no_grad.py +0 -0
  27. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_psgd.py +0 -0
  28. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_soap.py +0 -0
  29. {heavyball-1.1.1 → heavyball-1.1.2}/test/test_stochastic_updates.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.1.1
+ Version: 1.1.2
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -10,9 +10,9 @@ class ForeachAdamW(C.BaseOpt):
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
- mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)


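Note on the change above: every optimizer constructor now builds its defaults dict from locals() instead of spelling the arguments out a second time. A minimal sketch of the pattern; the toy class below is illustrative and not part of heavyball:

    class ToyOpt:
        def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), weight_decay=0.0):
            defaults = locals()                  # captures every constructor argument by name
            defaults.pop("self")                 # not a hyperparameter
            params = defaults.pop("params")      # handed to the base class separately
            # defaults now stays in sync with the signature automatically:
            # {'lr': 1e-3, 'betas': (0.9, 0.99), 'weight_decay': 0.0}
            self.param_groups = [{"params": list(params), **defaults}]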
@@ -25,9 +25,9 @@ class ForeachRMSprop(C.BaseOpt):
  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
  caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, warmup_steps=warmup_steps, weight_decay=weight_decay,
- foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
- beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)


@@ -36,10 +36,9 @@ class ForeachSFAdamW(C.ScheduleFree):
  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
  caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
- weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
- foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
- beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
  C.update_by_schedule_free)

@@ -53,9 +52,9 @@ class ForeachADOPT(C.BaseOpt):
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
- mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)


@@ -65,9 +64,9 @@ class ForeachMuon(C.BaseOpt):
  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
  nesterov: bool = True):
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
- mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
  C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)

@@ -77,12 +76,24 @@ class ForeachLaProp(C.BaseOpt):
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
- mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)


+ class MuonLaProp(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+ C.orthogonalize_update)
+
+
  class ForeachSOAP(C.BaseOpt):
  """
  ForeachSOAP
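The hunk above also introduces a new MuonLaProp optimizer (LaProp scaling chained with an orthogonalized update), which is exported further down in __all__. A hedged usage sketch; the model and training step are illustrative, only the constructor arguments come from the diff:

    import torch
    import heavyball

    model = torch.nn.Linear(16, 16)
    opt = heavyball.MuonLaProp(model.parameters(), lr=0.0025, betas=(0.9, 0.99), weight_decay=0)

    loss = model(torch.randn(4, 16)).square().mean()
    loss.backward()
    opt.step()           # LaProp-scaled, orthogonalized update
    opt.zero_grad()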
@@ -112,12 +123,10 @@ class ForeachSOAP(C.BaseOpt):
  gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

- defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
- "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
- "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
- "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
- 'caution': caution, 'mars_gamma': mars_gamma, 'palm': palm, 'precond_scheduler': precond_scheduler,
- 'beta2_scale': beta2_scale}
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+
  if use_precond_schedule:
  del defaults['precondition_frequency']
  else:
@@ -161,19 +170,15 @@ class ForeachPSGDKron(C.BaseOpt):
  gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+
  delayed = C.default(delayed, self.delayed)
  cached = C.default(cached, self.cached)
  exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
  update_clipping = C.default(update_clipping, utils.trust_region_clip_)

- defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
- min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
- momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
- precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
- storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
- stochastic_schedule=stochastic_schedule)
-
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False, #
  *(C.exp_avg,) * exp_avg_input, #
  functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached,
@@ -215,9 +220,9 @@ CachedPSGDKron = ForeachCachedPSGDKron
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
  Muon = ForeachMuon

- __all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+ __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
  "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
- "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', #
+ "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
  "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
  "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
  "ForeachRMSprop", "ForeachMuon"]
@@ -160,22 +160,21 @@ def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
- utils.fused_adam_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
- group['lr'], group['eps'], group['weight_decay'], group['caution'])
+ utils.fused_adam_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
+ group['step'], group['lr'], group['eps'], group['weight_decay'], group['caution'])
  raise SkipUpdate


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
- return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
- group['eps'])
+ return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'])


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
- utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group),
+ utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
  group['step'], group['lr'], group['weight_decay'], group['caution'])
  raise SkipUpdate

@@ -203,7 +202,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
  raise SkipUpdate

- utils.fused_adopt_(param, update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
+ utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
  group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution'])
  raise SkipUpdate

@@ -262,13 +261,13 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str


  @no_state_no_foreach
- def orthogonalize_update(group, update, grad, param):
+ def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
  if update.dim() == 1:
  return update
  original_shape = update.shape
  # doing it this way, as tmp and update are not guaranteed to share memory address or layout
  tmp = update.flatten(1, -1)
- utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+ utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp, scale_mode)
  return tmp.reshape(original_shape)


@@ -331,7 +330,7 @@ def _update_psgd_cache(cached, Q_cache, q):

  def _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache):
  if cached:
- return utils.precond_grad_cached_(cache_expr, update, *cache_expr)
+ return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
  return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)


@@ -352,7 +351,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
  update = update.to(memory_format=torch.contiguous_format)
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
  _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
- out = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+ out = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
  return torch.as_strided(out, old.shape, old.stride())


@@ -361,7 +360,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
  def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
- precond = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+ precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
  _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
  return precond

@@ -418,7 +417,6 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
  class ChainOpt(utils.StatefulOptimizer):
  def __init__(self, params, defaults, foreach: bool, *fns):
  super().__init__(params, defaults, foreach)
-
  self.fns = tuple(fns)

  def _step(self, group):
@@ -473,9 +471,10 @@ class BaseOpt(ChainOpt):
  update_clipping: str_or_fn = None
  palm: bool = False
  auto_fuse: bool = True
+ compile_step: bool = False

  def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
- palm: bool = None, *fns):
+ palm: bool = use_default, *fns):
  if default(update_clipping, self.update_clipping) is None:
  if fns and self.auto_fuse:
  args, kwargs = None, None
@@ -490,6 +489,7 @@ class BaseOpt(ChainOpt):
  else:
  if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
  raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
+
  fns = tuple(fns)

  if default(palm, self.palm):
@@ -505,9 +505,9 @@ class BaseOpt(ChainOpt):
505
505
  class ScheduleFree(BaseOpt):
506
506
  def eval(self):
507
507
  for group in self.param_groups:
508
- train_mode = group['train_mode']
508
+ group['train_mode'] = train_mode = not group.get('train_mode')
509
509
  beta1 = utils.get_beta1(group)
510
- if beta1 > 0 and train_mode:
510
+ if beta1 > 0 and not train_mode:
511
511
  for p in group['params']:
512
512
  state = self.state_(p)
513
513
  if 'z' in state:
@@ -516,13 +516,12 @@ class ScheduleFree(BaseOpt):
516
516
  p32 = utils.promote(p.data)
517
517
  p32.lerp_(end=z, weight=1 - 1 / beta1)
518
518
  utils.copy_stochastic_(p.data, p32)
519
- group['train_mode'] = False
520
519
 
521
520
  def train(self):
522
521
  for group in self.param_groups:
523
- train_mode = group['train_mode']
522
+ group['train_mode'] = train_mode = not group.get('train_mode')
524
523
  beta1 = utils.get_beta1(group)
525
- if beta1 > 0 and not train_mode:
524
+ if beta1 > 0 and train_mode:
526
525
  for p in group['params']:
527
526
  state = self.state_(p)
528
527
  if 'z' in state:
@@ -530,4 +529,3 @@ class ScheduleFree(BaseOpt):
530
529
  p32 = utils.promote(p.data)
531
530
  p32.lerp_(end=z, weight=1 - beta1)
532
531
  utils.copy_stochastic_(p.data, p32)
533
- group['train_mode'] = True
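The eval()/train() hunks above replace the unconditional train_mode writes with a single toggle expression. A minimal sketch of that toggle on a plain dict (no heavyball state involved):

    group = {}                       # 'train_mode' starts unset, like a fresh param group

    def toggle(group):
        # flip the flag and keep the post-toggle value in one expression
        group["train_mode"] = train_mode = not group.get("train_mode")
        return train_mode

    print(toggle(group))  # True  -> the branch guarded by `train_mode` runs (as in train())
    print(toggle(group))  # False -> the branch guarded by `not train_mode` runs (as in eval())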
@@ -163,7 +163,7 @@ def beta_debias(beta, step):
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
  out: List[Optional[Tensor]]):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- torch._foreach_mul_(s32, beta2)
+ s32 = torch._foreach_mul(s32, beta2)
  [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
  denom = torch._foreach_sqrt(s32)
  [d.clamp_(min=eps) for d in denom]
@@ -185,11 +185,11 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
  @decorator_knowngood
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- torch._foreach_mul_(s32, beta2)
+ s32 = torch._foreach_mul(s32, beta2)
  [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
  denom = torch._foreach_sqrt(s32)
  [d.clamp_(min=eps) for d in denom]
- out = torch._foreach_div_(g32, denom)
+ out = torch._foreach_div(g32, denom)
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, out)

@@ -204,7 +204,7 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
  @decorator_knowngood
  def _compilable_exp_avg_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- [s.lerp_(g, beta) for s, g in zip(s32, g32)]
+ s32 = [s.lerp(g, beta) for s, g in zip(s32, g32)]
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, s32)

@@ -225,7 +225,7 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
  torch._foreach_div_(p_norm, g_norm)
  torch._foreach_mul_(p_norm, clip_val)
  torch._foreach_minimum_(p_norm, 1)
- return torch._foreach_mul(gradients, p_norm)
+ torch._foreach_mul_(gradients, p_norm)


  def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
@@ -234,7 +234,8 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
  return gradients
  parameters, gradients = list_guard(parameters, gradients)
  clip_val = scalar_guard(clip_val, parameters[0])
- return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
+ _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
+ return gradients


  def is_compiling():
@@ -304,7 +305,7 @@ def ortho(x):
  @decorator_knowngood
  def _compilable_heavyball_momentum_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- torch._foreach_mul_(s32, beta)
+ s32 = torch._foreach_mul(s32, beta)
  torch._foreach_add_(s32, g32)
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, s32)
@@ -313,7 +314,7 @@ def _compilable_heavyball_momentum_(state, grad, beta):
  @decorator_knowngood
  def _compilable_nesterov_momentum_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- torch._foreach_mul_(s32, beta)
+ s32 = torch._foreach_mul(s32, beta)
  torch._foreach_add_(s32, g32)
  [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
  copy_stochastic_list_(state, s32)
@@ -334,17 +335,27 @@ def nesterov_momentum(state, grad, beta):
  return grad


+ # mode in ("newtonschulz", "qr", "svd")
+ # scale_mode in ("none", "scale", "graft")
  @decorator_knowngood
- def inplace_orthogonal_(x, mode, out):
+ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
  if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
  y = zeropower_via_newtonschulz5(x, 5)
  elif mode == 'qr':
- y = torch.linalg.qr(x).Q
+ y = torch.linalg.qr(promote(x)).Q
  elif mode == 'svd':
- u, s, v = torch.linalg.svd(x)
+ u, s, v = torch.linalg.svd(promote(x))
  y = u @ v.T
  else:
  raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+ if scale_mode == "none":
+ pass
+ elif scale_mode == "scale":
+ y *= max(1, x.size(0) / x.size(1)) ** 0.5
+ elif scale_mode == "graft":
+ y *= x.norm() / y.norm().clamp_(min=1e-6)
+ else:
+ raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
  set_(out, y)


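For reference, the three scale_mode branches added to inplace_orthogonal_ above behave as follows; this standalone sketch mirrors the diff but is not the heavyball implementation itself:

    import torch

    def rescale_orthogonal(x: torch.Tensor, y: torch.Tensor, scale_mode: str) -> torch.Tensor:
        # x: the original update, y: its orthogonalized counterpart
        if scale_mode == "none":
            return y
        if scale_mode == "scale":   # shape-based factor, as in the "scale" branch above
            return y * max(1, x.size(0) / x.size(1)) ** 0.5
        if scale_mode == "graft":   # carry the original update's norm over to y
            return y * (x.norm() / y.norm().clamp(min=1e-6))
        raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")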
@@ -378,7 +389,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
  est_eig = torch.einsum('ij,ij->j', o, tmp)
  sort_idx = torch.argsort(est_eig, descending=True)
  indices.append(sort_idx)
- inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q)
+ inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")

  indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
  for i, ind in enumerate(indices))
@@ -437,8 +448,7 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
  for x_, y_ in zip(x, y):
  x32 = promote(x_)
  y32 = promote(y_)
- x32.lerp_(y32, a)
- copy_stochastic_(x_, x32)
+ copy_stochastic_(x_, x32.lerp(y32, a))


  def get_beta1(group):
@@ -499,7 +509,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
  for x_, y_ in zip(x, y):
  x32 = promote(x_)
  y32 = promote(y_)
- x32.add_(y32, alpha=alpha)
+ x32.add_(y32, alpha=alpha) # can't use out-of-place here; torch.compile doesn't handle data-dependent inputs
  copy_stochastic_(x_, x32)


@@ -521,7 +531,7 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
  g0 = einsum_base[:grad.dim()]
  g1 = g0.replace(b, b.upper())
  outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
- GG[idx].lerp_(promote(outer_product), 1 - beta)
+ GG[idx].lerp_(outer_product, 1 - beta)


  def promote(x):
@@ -586,7 +596,8 @@ def project(grad, Q, back: bool):
  preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
  if preconditioners:
  out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
- grad = torch.einsum(f'{param},{preconditioners}->{out}', grad, *[q for q in Q if len(q) > 0])
+ out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if len(q) > 0])
+ grad = out.to(grad.dtype)
  return grad


@@ -739,20 +750,26 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
  copy_stochastic_(t, s)


+ def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
+ ea32 = list(map(promote, state))
+ grad = list(map(promote, grad))
+
+ ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
+ copy_stochastic_list_(state, ea32)
+ return ea32
+
+
  @decorator_knowngood
  def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
  step: Tensor):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
+ g32 = list(map(promote, grad))

- [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
- denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+ exp_avg32 = _lerp32(exp_avg, g32, beta1)
+ denom = exp_avg_sq_(exp_avg_sq, g32, beta2, 1e-8)
  u32 = torch._foreach_div(exp_avg32, denom)
-
- copy_stochastic_list_(exp_avg, exp_avg32)
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
  copy_stochastic_list_(grad, u32)


@@ -764,28 +781,26 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b


  @decorator_knowngood
- def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
- beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
- caution: bool):
+ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+ grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor,
+ eps: Tensor, caution: bool):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
+ u32, g32 = [list(map(promote, x)) for x in [update, grad]]

- [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
- denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+ exp_avg32 = _lerp32(exp_avg, u32, beta1)
+ denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
  u32 = torch._foreach_div(exp_avg32, denom)
+ _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)

- copy_stochastic_list_(exp_avg, exp_avg32)
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
- _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)

-
- def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
- beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
+ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
+ caution: bool):
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
- return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
+ return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)


  @decorator_knowngood
@@ -794,14 +809,13 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
+ gp32 = list(map(promote, grad))

- denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+ denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, 1e-8)
  gp32 = torch._foreach_div(gp32, denom)
- stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+ gp32 = _lerp32(exp_avg, gp32, beta1)

- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
- copy_stochastic_list_(grad, exp_avg)
+ copy_stochastic_list_(grad, gp32)


  def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
@@ -812,52 +826,50 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],


  @decorator_knowngood
- def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
- grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor,
- decay: Tensor, caution: bool):
+ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+ grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
+ caution: bool):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]
-
- denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
- gp32 = torch._foreach_div(gp32, denom)
- stochastic_lerp_(exp_avg, gp32, 1 - beta1)
- update_param_(y, gp32, lr, decay, caution=caution, grad=gp32)
+ u32, gp32 = [list(map(promote, x)) for x in [update, grad]]

- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+ u32 = torch._foreach_div(u32, denom)
+ u32 = _lerp32(exp_avg, u32, beta1)
+ _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, gp32)


- def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
- beta2: float, step: int, lr: float, decay: float, caution: bool):
+ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
- _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)


  @decorator_knowngood
- def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
- g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
- update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
+ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+ u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
+ _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)

  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
  [denom.clamp_(min=eps) for denom in denom]
- torch._foreach_mul_(exp_avg32, beta1)
- [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+ exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
+ [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, u32, denom)]
+ copy_stochastic_list_(exp_avg, exp_avg32)

  beta2 = beta_debias(beta2, step + 1)
- torch._foreach_mul_(exp_avg_sq32, beta2)
- [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-
- copy_stochastic_list_(exp_avg, exp_avg32)
+ exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
+ [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


- def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+
+ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
- _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
+ _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)


  @decorator_knowngood
@@ -868,21 +880,21 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
  [denom.clamp_(min=1e-8) for denom in denom]
- torch._foreach_mul_(exp_avg32, beta1)
+ exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
  [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+ copy_stochastic_list_(exp_avg, exp_avg32)

  beta2 = beta_debias(beta2, step + 1)
- torch._foreach_mul_(exp_avg_sq32, beta2)
+ exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
  [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-
- copy_stochastic_list_(exp_avg, exp_avg32)
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
  copy_stochastic_list_(grad, update)


  def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
- exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad, y)
- beta1, beta2, step = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+ exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+ beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
  _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
  return grad

@@ -927,7 +939,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:

  for p32_, u32_, g_ in zip(p32, u32, g): # lr is data-dependent -> can't compile a foreach
  if caution:
- _compilable_cautioning_(promote(g_), u32_)
+ u32_ = _compilable_cautioning(promote(g_), u32_)
  add_fn(p32_, u32_, lr)

  copy_stochastic_list_(p, p32)
@@ -1243,7 +1255,7 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
  prob = prob(group[f'{name}_prob_step'])
  if group['stochastic_schedule']:
  return rng.random() < prob
- cumulative_prob = state.get(name, 0)
+ cumulative_prob = group.get(name, 0)
  group[name] = cumulative_prob + prob
  return int(group[name]) > int(cumulative_prob)

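The one-line fix above reads the cumulative probability from group rather than an undefined state variable. The deterministic schedule it implements can be sketched in isolation (group is a plain dict here, not a heavyball param group):

    def should_update(group: dict, prob: float, name: str = "cumulative_prob") -> bool:
        # accumulate prob on every call; fire whenever the integer part of the accumulator advances
        cumulative_prob = group.get(name, 0)
        group[name] = cumulative_prob + prob
        return int(group[name]) > int(cumulative_prob)

    g = {}
    print([should_update(g, 0.4) for _ in range(5)])  # [False, False, True, False, True]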
@@ -1304,15 +1316,16 @@ def mars_correction(g, old_g, beta1, gamma):


  @decorator_knowngood
- def _compilable_cautioning_(g: Tensor, update: Tensor):
- mask = (g * update) > 0
- update.masked_fill_(~mask, 0)
- scale = mask.numel() / mask.sum().clamp(min=1)
+ def _compilable_cautioning(g: Tensor, update: Tensor):
+ mask = g.signbit() ^ update.signbit() # "Mask if they point in different directions"
+ update = update.masked_fill(mask, 0)
+ scale = mask.numel() / (mask.numel() - mask.sum()).clamp(min=1)
  update.mul_(scale)
+ return update


  def caution(g, update):
- _compilable_cautioning_(g, update)
+ return _compilable_cautioning(g, update)


  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
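The rewritten cautioning above zeroes update entries whose sign disagrees with the gradient and rescales by the fraction kept. An equivalent standalone sketch (not the compiled heavyball helper):

    import torch

    def cautious_update(grad: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
        mask = grad.signbit() ^ update.signbit()                      # True where directions disagree
        update = update.masked_fill(mask, 0)                          # drop disagreeing entries
        scale = mask.numel() / (mask.numel() - mask.sum()).clamp(min=1)
        return update * scale                                         # renormalize by the kept fraction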
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.1.1
+ Version: 1.1.2
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
@@ -10,7 +10,7 @@ setuptools.setup(
  name='heavyball',
  license='BSD',
  description='Efficient optimizers',
- version='1.1.1',
+ version='1.1.2',
  long_description=README,
  url='https://github.com/clashluke/heavyball',
  packages=setuptools.find_packages(),
@@ -11,6 +11,7 @@ from heavyball.utils import clean, set_torch
  from torch import nn
  from torch._dynamo import config

+ heavyball.utils.zeroth_power_mode = 'newtonschulz'
  heavyball.utils.compile_mode = 'default'
  config.cache_size_limit = 128

@@ -34,7 +35,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations
  if is_channels_last:
  model.to(memory_format=torch.channels_last)

- o = get_optim(opt, model.parameters(), lr=1e-5, weight_decay=1e-4, warmup_steps=16)
+ o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16)

  for _ in range(iterations):
  loss = model(torch.randn((1024, size, 4, 4), device='cuda')).square().mean()