heavyball-1.1.0-py3-none-any.whl → heavyball-1.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -10,9 +10,9 @@ class ForeachAdamW(C.BaseOpt):
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
- mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)


@@ -25,9 +25,9 @@ class ForeachRMSprop(C.BaseOpt):
  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
  caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, warmup_steps=warmup_steps, weight_decay=weight_decay,
- foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
- beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)


@@ -36,10 +36,9 @@ class ForeachSFAdamW(C.ScheduleFree):
  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
  caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
- weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
- foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
- beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
  C.update_by_schedule_free)

@@ -53,9 +52,9 @@ class ForeachADOPT(C.BaseOpt):
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
- mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)


@@ -65,9 +64,9 @@ class ForeachMuon(C.BaseOpt):
  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
  nesterov: bool = True):
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
- mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
  C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)

@@ -77,12 +76,24 @@ class ForeachLaProp(C.BaseOpt):
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
- defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
- lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
- mars_gamma=mars_gamma, beta2_scale=beta2_scale)
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)


+ class MuonLaProp(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+ C.orthogonalize_update)
+
+
  class ForeachSOAP(C.BaseOpt):
  """
  ForeachSOAP
@@ -112,12 +123,10 @@ class ForeachSOAP(C.BaseOpt):
  gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

- defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
- "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
- "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
- "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
- 'caution': caution, 'mars_gamma': mars_gamma, 'palm': palm, 'precond_scheduler': precond_scheduler,
- 'beta2_scale': beta2_scale}
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+
  if use_precond_schedule:
  del defaults['precondition_frequency']
  else:
@@ -161,19 +170,15 @@ class ForeachPSGDKron(C.BaseOpt):
  gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+
  delayed = C.default(delayed, self.delayed)
  cached = C.default(cached, self.cached)
  exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
  update_clipping = C.default(update_clipping, utils.trust_region_clip_)

- defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
- min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
- momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
- precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
- split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
- storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
- stochastic_schedule=stochastic_schedule)
-
  super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False, #
  *(C.exp_avg,) * exp_avg_input, #
  functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached,
@@ -215,9 +220,9 @@ CachedPSGDKron = ForeachCachedPSGDKron
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
  Muon = ForeachMuon

- __all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+ __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
  "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
- "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', #
+ "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
  "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
  "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
  "ForeachRMSprop", "ForeachMuon"]
heavyball/chainable.py CHANGED
@@ -140,16 +140,14 @@ class SkipUpdate(ValueError):
  @zero_guard("exp_avg")
  @no_state
  def exp_avg(group, update, grad, param, exp_avg):
- utils.stochastic_lerp_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))
- return exp_avg
+ return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


  @zero_guard("exp_avg_sq")
  @no_state
  def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
- out = utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
- group['eps'])
- return out
+ return utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group["step"]),
+ group['eps'])


  @zero_guard("exp_avg", "exp_avg_sq")
@@ -162,22 +160,21 @@ def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
- utils.fused_adam_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
- group['lr'], group['eps'], group['weight_decay'], group['caution'])
+ utils.fused_adam_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
+ group['step'], group['lr'], group['eps'], group['weight_decay'], group['caution'])
  raise SkipUpdate


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
- return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
- group['eps'])
+ return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'])


  @zero_guard("exp_avg", "exp_avg_sq")
  @no_state
  def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
- utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group),
+ utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
  group['step'], group['lr'], group['weight_decay'], group['caution'])
  raise SkipUpdate

@@ -205,7 +202,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
  raise SkipUpdate

- utils.fused_adopt_(param, update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
+ utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
  group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution'])
  raise SkipUpdate

@@ -264,13 +261,13 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str


  @no_state_no_foreach
- def orthogonalize_update(group, update, grad, param):
+ def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"): # explore scale_mode="graft"
  if update.dim() == 1:
  return update
  original_shape = update.shape
  # doing it this way, as tmp and update are not guaranteed to share memory address or layout
  tmp = update.flatten(1, -1)
- utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+ utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp, scale_mode)
  return tmp.reshape(original_shape)


@@ -333,7 +330,7 @@ def _update_psgd_cache(cached, Q_cache, q):

  def _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache):
  if cached:
- return utils.precond_grad_cached_(cache_expr, update, *cache_expr)
+ return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
  return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)


@@ -350,9 +347,12 @@ def _fused_cached_psgd_precond_grad(group, grad, param, cached, cache_expr, expr
  @no_state_no_foreach
  def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
+ old = update
+ update = update.to(memory_format=torch.contiguous_format)
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
  _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
- return _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+ out = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
+ return torch.as_strided(out, old.shape, old.stride())


  @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd)
@@ -360,7 +360,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
  def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
  prob: Optional[callable] = None):
  Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
- precond = _cached_psgd_precond_grad(False, cache_expr, exprs, update, Q_mat, Q_cache)
+ precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
  _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
  return precond

@@ -400,7 +400,7 @@ def apply_to_idx(fn, idx):


  def chain(state: Union[callable, dict], group, grad, param, *fns):
- update = [torch.clone(g) for g in grad]
+ update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
  skip_update = False
  for fn in fns:
  try:
@@ -417,7 +417,6 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
  class ChainOpt(utils.StatefulOptimizer):
  def __init__(self, params, defaults, foreach: bool, *fns):
  super().__init__(params, defaults, foreach)
-
  self.fns = tuple(fns)

  def _step(self, group):
@@ -472,9 +471,10 @@ class BaseOpt(ChainOpt):
  update_clipping: str_or_fn = None
  palm: bool = False
  auto_fuse: bool = True
+ compile_step: bool = False

  def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
- palm: bool = None, *fns):
+ palm: bool = use_default, *fns):
  if default(update_clipping, self.update_clipping) is None:
  if fns and self.auto_fuse:
  args, kwargs = None, None
@@ -489,6 +489,7 @@ class BaseOpt(ChainOpt):
  else:
  if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
  raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
+
  fns = tuple(fns)

  if default(palm, self.palm):
@@ -504,9 +505,9 @@ class BaseOpt(ChainOpt):
  class ScheduleFree(BaseOpt):
  def eval(self):
  for group in self.param_groups:
- train_mode = group['train_mode']
+ group['train_mode'] = train_mode = not group.get('train_mode')
  beta1 = utils.get_beta1(group)
- if beta1 > 0 and train_mode:
+ if beta1 > 0 and not train_mode:
  for p in group['params']:
  state = self.state_(p)
  if 'z' in state:
@@ -515,13 +516,12 @@ class ScheduleFree(BaseOpt):
  p32 = utils.promote(p.data)
  p32.lerp_(end=z, weight=1 - 1 / beta1)
  utils.copy_stochastic_(p.data, p32)
- group['train_mode'] = False

  def train(self):
  for group in self.param_groups:
- train_mode = group['train_mode']
+ group['train_mode'] = train_mode = not group.get('train_mode')
  beta1 = utils.get_beta1(group)
- if beta1 > 0 and not train_mode:
+ if beta1 > 0 and train_mode:
  for p in group['params']:
  state = self.state_(p)
  if 'z' in state:
@@ -529,4 +529,3 @@ class ScheduleFree(BaseOpt):
  p32 = utils.promote(p.data)
  p32.lerp_(end=z, weight=1 - beta1)
  utils.copy_stochastic_(p.data, p32)
- group['train_mode'] = True
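
The `scale_by_psgd` hunk above now densifies the incoming update before preconditioning and views the result back with the original shape and strides. A tiny standalone illustration of that round-trip in plain PyTorch (independent of the heavyball helpers; the doubling is only a placeholder for the preconditioning step):

    import torch

    update = torch.randn(4, 6).t()                              # non-contiguous view, stride (1, 6)
    old = update
    dense = update.to(memory_format=torch.contiguous_format)    # dense copy for the heavy math
    out = dense * 2                                             # placeholder for the preconditioned update
    restored = torch.as_strided(out, old.shape, old.stride())   # reinterpret with the original layout
    print(restored.shape, restored.stride())                    # torch.Size([6, 4]) (1, 6)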
heavyball/utils.py CHANGED
@@ -163,7 +163,7 @@ def beta_debias(beta, step):
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
  out: List[Optional[Tensor]]):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- torch._foreach_mul_(s32, beta2)
+ s32 = torch._foreach_mul(s32, beta2)
  [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
  denom = torch._foreach_sqrt(s32)
  [d.clamp_(min=eps) for d in denom]
@@ -185,11 +185,11 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
  @decorator_knowngood
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- torch._foreach_mul_(s32, beta2)
+ s32 = torch._foreach_mul(s32, beta2)
  [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
  denom = torch._foreach_sqrt(s32)
  [d.clamp_(min=eps) for d in denom]
- out = torch._foreach_div_(g32, denom)
+ out = torch._foreach_div(g32, denom)
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, out)

@@ -201,6 +201,21 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
  return grad


+ @decorator_knowngood
+ def _compilable_exp_avg_(state, grad, beta):
+ s32, g32 = [list(map(promote, x)) for x in (state, grad)]
+ s32 = [s.lerp(g, beta) for s, g in zip(s32, g32)]
+ copy_stochastic_list_(state, s32)
+ copy_stochastic_list_(grad, s32)
+
+
+ def scale_by_exp_avg_(state, grad, beta):
+ state, grad = list_guard(state, grad)
+ beta = scalar_guard(beta, state[0])
+ _compilable_exp_avg_(state, grad, beta)
+ return grad
+
+
  @decorator_knowngood
  def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float, minimum: float, eps: float):
  p_norm = torch._foreach_norm(parameters)
@@ -210,7 +225,7 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
  torch._foreach_div_(p_norm, g_norm)
  torch._foreach_mul_(p_norm, clip_val)
  torch._foreach_minimum_(p_norm, 1)
- return torch._foreach_mul(gradients, p_norm)
+ torch._foreach_mul_(gradients, p_norm)


  def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
@@ -219,7 +234,8 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
  return gradients
  parameters, gradients = list_guard(parameters, gradients)
  clip_val = scalar_guard(clip_val, parameters[0])
- return _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
+ _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
+ return gradients


  def is_compiling():
@@ -289,7 +305,7 @@ def ortho(x):
  @decorator_knowngood
  def _compilable_heavyball_momentum_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- torch._foreach_mul_(s32, beta)
+ s32 = torch._foreach_mul(s32, beta)
  torch._foreach_add_(s32, g32)
  copy_stochastic_list_(state, s32)
  copy_stochastic_list_(grad, s32)
@@ -298,7 +314,7 @@ def _compilable_heavyball_momentum_(state, grad, beta):
  @decorator_knowngood
  def _compilable_nesterov_momentum_(state, grad, beta):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- torch._foreach_mul_(s32, beta)
+ s32 = torch._foreach_mul(s32, beta)
  torch._foreach_add_(s32, g32)
  [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
  copy_stochastic_list_(state, s32)
@@ -319,17 +335,27 @@ def nesterov_momentum(state, grad, beta):
  return grad


+ # mode in ("newtonschulz", "qr", "svd")
+ # scale_mode in ("none", "scale", "graft")
  @decorator_knowngood
- def inplace_orthogonal_(x, mode, out):
- if mode == 'qr':
- y = torch.linalg.qr(x).Q
+ def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
+ if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
+ y = zeropower_via_newtonschulz5(x, 5)
+ elif mode == 'qr':
+ y = torch.linalg.qr(promote(x)).Q
  elif mode == 'svd':
- u, s, v = torch.linalg.svd(x)
+ u, s, v = torch.linalg.svd(promote(x))
  y = u @ v.T
- elif mode == 'newtonschulz':
- y = zeropower_via_newtonschulz5(x, 5)
  else:
  raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+ if scale_mode == "none":
+ pass
+ elif scale_mode == "scale":
+ y *= max(1, x.size(0) / x.size(1)) ** 0.5
+ elif scale_mode == "graft":
+ y *= x.norm() / y.norm().clamp_(min=1e-6)
+ else:
+ raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
  set_(out, y)


@@ -363,7 +389,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
  est_eig = torch.einsum('ij,ij->j', o, tmp)
  sort_idx = torch.argsort(est_eig, descending=True)
  indices.append(sort_idx)
- inplace_orthogonal_(tmp[:, sort_idx], q)
+ inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")

  indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1)) #
  for i, ind in enumerate(indices))
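
The `inplace_orthogonal_` rewrite above adds a `scale_mode` argument applied after the zeroth-power step. A rough standalone sketch of the two non-trivial modes, with semantics read off the hunk above (plain PyTorch, not the heavyball helper):

    import torch

    x = torch.randn(8, 4)                                        # raw update
    y = torch.linalg.qr(x).Q                                     # stand-in for the orthogonalized update

    y_scale = y * max(1, x.size(0) / x.size(1)) ** 0.5           # "scale": grow tall matrices by sqrt(rows / cols)
    y_graft = y * (x.norm() / y.norm().clamp_min(1e-6))          # "graft": keep directions, re-impose the input norm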
@@ -422,8 +448,7 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
  for x_, y_ in zip(x, y):
  x32 = promote(x_)
  y32 = promote(y_)
- x32.lerp_(y32, a)
- copy_stochastic_(x_, x32)
+ copy_stochastic_(x_, x32.lerp(y32, a))


  def get_beta1(group):
@@ -484,7 +509,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
  for x_, y_ in zip(x, y):
  x32 = promote(x_)
  y32 = promote(y_)
- x32.add_(y32, alpha=alpha)
+ x32.add_(y32, alpha=alpha) # can't use out-of-place here; torch.compile doesn't handle data-dependent inputs
  copy_stochastic_(x_, x32)


@@ -506,7 +531,7 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
  g0 = einsum_base[:grad.dim()]
  g1 = g0.replace(b, b.upper())
  outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
- GG[idx].lerp_(promote(outer_product), 1 - beta)
+ GG[idx].lerp_(outer_product, 1 - beta)


  def promote(x):
@@ -571,7 +596,8 @@ def project(grad, Q, back: bool):
  preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
  if preconditioners:
  out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
- grad = torch.einsum(f'{param},{preconditioners}->{out}', grad, *[q for q in Q if len(q) > 0])
+ out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if len(q) > 0])
+ grad = out.to(grad.dtype)
  return grad


@@ -724,20 +750,26 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
  copy_stochastic_(t, s)


+ def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
+ ea32 = list(map(promote, state))
+ grad = list(map(promote, grad))
+
+ ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
+ copy_stochastic_list_(state, ea32)
+ return ea32
+
+
  @decorator_knowngood
  def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
  step: Tensor):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
+ g32 = list(map(promote, grad))

- [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
- denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+ exp_avg32 = _lerp32(exp_avg, g32, beta1)
+ denom = exp_avg_sq_(exp_avg_sq, g32, beta2, 1e-8)
  u32 = torch._foreach_div(exp_avg32, denom)
-
- copy_stochastic_list_(exp_avg, exp_avg32)
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
  copy_stochastic_list_(grad, u32)


@@ -749,28 +781,26 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b


  @decorator_knowngood
- def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
- beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor, eps: Tensor,
- caution: bool):
+ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+ grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor,
+ eps: Tensor, caution: bool):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [grad, exp_avg_sq, exp_avg]]
+ u32, g32 = [list(map(promote, x)) for x in [update, grad]]

- [ea32.lerp_(g, 1 - beta1) for ea32, g in zip(exp_avg32, g32)]
- denom = exp_avg_sq_(exp_avg_sq32, g32, beta2, 1e-8)
+ exp_avg32 = _lerp32(exp_avg, u32, beta1)
+ denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
  u32 = torch._foreach_div(exp_avg32, denom)
-
- copy_stochastic_list_(exp_avg, exp_avg32)
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
- _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)
+ _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)


- def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
- beta2: float, step: int, lr: float, eps: float, decay: float, caution: bool):
+ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
+ caution: bool):
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
- return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
+ return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)


  @decorator_knowngood
@@ -779,14 +809,13 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg_sq]]
+ gp32 = list(map(promote, grad))

- denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+ denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, 1e-8)
  gp32 = torch._foreach_div(gp32, denom)
- stochastic_lerp_(exp_avg, gp32, 1 - beta1)
+ gp32 = _lerp32(exp_avg, gp32, beta1)

- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
- copy_stochastic_list_(grad, exp_avg)
+ copy_stochastic_list_(grad, gp32)


  def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
@@ -797,52 +826,50 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],


  @decorator_knowngood
- def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
- grad_projected: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor,
- decay: Tensor, caution: bool):
+ def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+ grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
+ caution: bool):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

- gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad_projected, exp_avg_sq]]
+ u32, gp32 = [list(map(promote, x)) for x in [update, grad]]

- denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
- gp32 = torch._foreach_div(gp32, denom)
- stochastic_lerp_(exp_avg, gp32, 1 - beta1)
- update_param_(y, gp32, lr, decay, caution=caution, grad=gp32)
-
- copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+ denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+ u32 = torch._foreach_div(u32, denom)
+ u32 = _lerp32(exp_avg, u32, beta1)
+ _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, gp32)


- def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float,
- beta2: float, step: int, lr: float, decay: float, caution: bool):
+ def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
- _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)


  @decorator_knowngood
- def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
- g32, exp_avg32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, exp_avg, exp_avg_sq]]
- update_param_(y, exp_avg, lr, decay, caution=caution, grad=g32)
+ def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+ u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
+ _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)

  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
  [denom.clamp_(min=eps) for denom in denom]
- torch._foreach_mul_(exp_avg32, beta1)
- [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+ exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
+ [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, u32, denom)]
+ copy_stochastic_list_(exp_avg, exp_avg32)

  beta2 = beta_debias(beta2, step + 1)
- torch._foreach_mul_(exp_avg_sq32, beta2)
- [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-
- copy_stochastic_list_(exp_avg, exp_avg32)
+ exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
+ [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)


- def fused_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+
+ def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
- _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
+ _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)


  @decorator_knowngood
@@ -853,21 +880,21 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
  beta1 = beta_debias(beta1, step)
  denom = torch._foreach_sqrt(exp_avg_sq32)
  [denom.clamp_(min=1e-8) for denom in denom]
- torch._foreach_mul_(exp_avg32, beta1)
+ exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
  [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+ copy_stochastic_list_(exp_avg, exp_avg32)

  beta2 = beta_debias(beta2, step + 1)
- torch._foreach_mul_(exp_avg_sq32, beta2)
+ exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
  [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-
- copy_stochastic_list_(exp_avg, exp_avg32)
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
  copy_stochastic_list_(grad, update)


  def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
- exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad, y)
- beta1, beta2, step = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
+ exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+ beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
  _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
  return grad

@@ -912,7 +939,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:

  for p32_, u32_, g_ in zip(p32, u32, g): # lr is data-dependent -> can't compile a foreach
  if caution:
- _compilable_cautioning_(promote(g_), u32_)
+ u32_ = _compilable_cautioning(promote(g_), u32_)
  add_fn(p32_, u32_, lr)

  copy_stochastic_list_(p, p32)
@@ -1228,7 +1255,7 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
  prob = prob(group[f'{name}_prob_step'])
  if group['stochastic_schedule']:
  return rng.random() < prob
- cumulative_prob = state.get(name, 0)
+ cumulative_prob = group.get(name, 0)
  group[name] = cumulative_prob + prob
  return int(group[name]) > int(cumulative_prob)

@@ -1289,15 +1316,16 @@ def mars_correction(g, old_g, beta1, gamma):


  @decorator_knowngood
- def _compilable_cautioning_(g: Tensor, update: Tensor):
- mask = (g * update) > 0
- update.masked_fill_(~mask, 0)
- scale = mask.numel() / mask.sum().clamp(min=1)
+ def _compilable_cautioning(g: Tensor, update: Tensor):
+ mask = g.signbit() ^ update.signbit() # "Mask if they point in different directions"
+ update = update.masked_fill(mask, 0)
+ scale = mask.numel() / (mask.numel() - mask.sum()).clamp(min=1)
  update.mul_(scale)
+ return update


  def caution(g, update):
- _compilable_cautioning_(g, update)
+ return _compilable_cautioning(g, update)


  def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
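
The rewritten `_compilable_cautioning` above zeroes the update wherever its sign disagrees with the raw gradient and rescales the survivors; the extra `grad` argument threaded through `fused_adam_`, `fused_laprop_`, and `fused_adopt_` exists to feed that comparison. A small single-tensor sketch mirroring that rule in plain PyTorch (not a call into the heavyball helper):

    import torch

    def cautioned(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        mask = update.signbit() ^ grad.signbit()                         # True where the signs disagree
        kept = update.masked_fill(mask, 0)                               # drop disagreeing entries
        scale = mask.numel() / (mask.numel() - mask.sum()).clamp(min=1)
        return kept * scale                                              # keep the mean magnitude roughly constant

    print(cautioned(torch.randn(5), torch.randn(5)))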
heavyball-1.1.0.dist-info/METADATA → heavyball-1.1.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.1.0
+ Version: 1.1.2
  Summary: Efficient optimizers
  Home-page: https://github.com/clashluke/heavyball
  Author: Lucas Nestler
heavyball-1.1.2.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=miRgcXlzLWTNzojeRF5hEcg-x_GqfMHjRzOaiR_zO3U,10981
+ heavyball/chainable.py,sha256=Zp7q6RHYU4RgdZ_ezgc8NWPwsNfyFjRvhEK-IEqr4b4,20379
+ heavyball/utils.py,sha256=0j5wRDYeI9Elz9m8tcP7CZNhj_9OIWEF_uQpb0LTrYM,47814
+ heavyball-1.1.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.1.2.dist-info/METADATA,sha256=bhXVJpcuwNZaOKFydknhtqqYx0ZZsQp2wkEdUAoDfN4,12022
+ heavyball-1.1.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-1.1.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.1.2.dist-info/RECORD,,
heavyball-1.1.0.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
- heavyball/__init__.py,sha256=ZPadmK5Q9hGrBAeNYldMUw8vciy9dCDL3d7Zk9erC3E,12702
- heavyball/chainable.py,sha256=NLyJZS_vuQxzLE3RjQy0isQGCY5xGQBCs1wirT9BsQY,20172
- heavyball/utils.py,sha256=BEbvwWswWMyEa4zozlnLaQIKEMH_k7OUGu84jlcj5t0,46736
- heavyball-1.1.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-1.1.0.dist-info/METADATA,sha256=hVz8TPY0A9SqkjVpLX6LVXXlBVmAC8my96DHO8jtSb8,12022
- heavyball-1.1.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-1.1.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-1.1.0.dist-info/RECORD,,