heavyball 0.24.2__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,7 +22,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
22
22
  params (iterable): Iterable of parameters to optimize or dicts defining
23
23
  parameter groups.
24
24
  lr (float): Learning rate.
25
- b1 (float): Momentum parameter.
25
+ beta (float): Momentum parameter.
26
26
  weight_decay (float): Weight decay (L2 penalty).
27
27
  preconditioner_update_probability (callable or float, optional): Probability of
28
28
  updating the preconditioner. If None, defaults to a schedule that anneals
@@ -19,7 +19,7 @@ class ForeachCachedPSGDKron(PSGDBase):
19
19
  params (iterable): Iterable of parameters to optimize or dicts defining
20
20
  parameter groups.
21
21
  lr (float): Learning rate.
22
- b1 (float): Momentum parameter.
22
+ beta (float): Momentum parameter.
23
23
  weight_decay (float): Weight decay (L2 penalty).
24
24
  preconditioner_update_probability (callable or float, optional): Probability of
25
25
  updating the preconditioner. If None, defaults to a schedule that anneals
@@ -41,6 +41,7 @@ class ForeachCachedPSGDKron(PSGDBase):
41
41
  split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
42
42
  foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
43
43
  storage_dtype: str = 'float32', mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025,
44
+ orthogonalize_output: bool = False,
44
45
  #
45
46
  # expert parameters
46
47
  precond_init_scale=1.0, precond_lr=0.1):
@@ -59,7 +60,8 @@ class ForeachCachedPSGDKron(PSGDBase):
59
60
  momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
60
61
  precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
61
62
  split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
62
- storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars)
63
+ storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
64
+ orthogonalize_output=orthogonalize_output)
63
65
  super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)
64
66
 
65
67
  def _step(self, group):
@@ -75,6 +77,7 @@ class ForeachCachedPSGDKron(PSGDBase):
75
77
  store_triu_as_line = group['store_triu_as_line']
76
78
  q_dtype = getattr(torch, group['q_dtype'])
77
79
  storage_dtype = getattr(torch, group['storage_dtype'])
80
+ orthogonalize_output = group['orthogonalize_output']
78
81
  should_update = self.should_update(group)
79
82
 
80
83
  vals = []
heavyball/delayed_psgd.py CHANGED
@@ -25,7 +25,7 @@ class ForeachDelayedPSGD(PSGDBase):
25
25
  params (iterable): Iterable of parameters to optimize or dicts defining
26
26
  parameter groups.
27
27
  lr (float): Learning rate.
28
- b1 (float): Momentum parameter.
28
+ beta (float): Momentum parameter.
29
29
  weight_decay (float): Weight decay (L2 penalty).
30
30
  preconditioner_update_probability (callable or float, optional): Probability of
31
31
  updating the preconditioner. If None, defaults to a schedule that anneals
heavyball/utils.py CHANGED
@@ -23,7 +23,8 @@ def decorator(func):
23
23
 
24
24
  @functools.wraps(func)
25
25
  def _fn(*args, **kwargs):
26
- if is_compiling() or compile_mode is None:
26
+ disable = compile_mode_recommended_to_none is None
27
+ if is_compiling() or compile_mode_recommended_to_none is None:
27
28
  return func(*args, **kwargs)
28
29
  nonlocal compiled
29
30
  if compiled is None:
@@ -874,7 +875,7 @@ def psgd_lb(A, max_abs):
874
875
  return x
875
876
 
876
877
 
877
- @decorator_knowngood
878
+ @decorator
878
879
  def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
879
880
  """Update Kronecker product preconditioner Q with pair (V, G)."""
880
881
  exprA, exprGs, _ = exprs
@@ -1130,11 +1131,11 @@ def merge_group(group, *tensors):
1130
1131
  'max_precond_dim'], group.get('split', False)))
1131
1132
  return out
1132
1133
 
1134
+
1133
1135
  def hook_optimizer_into_model(model, optimizer, *args, **kwargs):
1134
1136
  def _step(p: Tensor, o: torch.optim.Optimizer):
1135
1137
  o.step()
1136
1138
  o.zero_grad()
1137
1139
 
1138
-
1139
1140
  for p in model.parameters():
1140
- p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
1141
+ p.register_post_accumulate_grad_hook(functools.partial(_step, o=optimizer([p], *args, **kwargs)))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 0.24.2
3
+ Version: 0.24.3
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -1,7 +1,7 @@
1
1
  heavyball/__init__.py,sha256=icHYN-MGsmHkLUlHCMcZkOlwY7GT63_ayR_a5iPKmzM,2226
2
- heavyball/cached_delayed_psgd_kron.py,sha256=cHwVDq-_284_eMt09rAq26D_8fv3N0e0wdN1woCHU1M,6864
3
- heavyball/cached_psgd_kron.py,sha256=ttg6bemNDRpCJBV3aJg2DSyVfsfTMZAnhErgwC2jXlw,6815
4
- heavyball/delayed_psgd.py,sha256=yHy83YQ_PKWtwQq1R_OVyj3cjmcbsZAXX1M-hGyciss,6332
2
+ heavyball/cached_delayed_psgd_kron.py,sha256=HEyT6vW6Le6FmWpf-vAEzgbAkPH2mByqXcVZn07KCMk,6866
3
+ heavyball/cached_psgd_kron.py,sha256=rOgWAeVMENI7kdoBuRo3ywrCeatAnIqBdeYPHuVk2aU,6998
4
+ heavyball/delayed_psgd.py,sha256=L6qRLPxJmJ_1e0Mk2zLYUEVxkt8NGHq6v3HKawlgFcU,6334
5
5
  heavyball/foreach_adamw.py,sha256=K4xTes4drylAqaqWky8O_Bg_mmbAmcHZ5DEBs5vMD-s,2860
6
6
  heavyball/foreach_adopt.py,sha256=fHnbEqvKKc5IKPDWC9Qo9PiISSjj1MEViy0Jb3BRgZQ,3582
7
7
  heavyball/foreach_laprop.py,sha256=EXkwFQ-H7hHWLmiNUsxUcmXhzNNLMjieHjfOlY_6kmo,2868
@@ -16,9 +16,9 @@ heavyball/precond_schedule_sfpsoap.py,sha256=KUKdZzd336w24zPRcqwRatj7IVmd1Us0a_V
16
16
  heavyball/psgd_kron.py,sha256=PtTe6eR547Y-4CvgjpchgkQsr_kWr4AN-uY9L_JO_C8,6088
17
17
  heavyball/pure_psgd.py,sha256=344NdVNHwUFX3fU2R1S_Xh9SXAML3E4ryHr7xfMh9Cc,5076
18
18
  heavyball/schedule_free_palm_foreach_soap.py,sha256=KTQY37MZH7YnOSTLKY8uVySUXxWXbFVUA1QXN3iv8Ds,7244
19
- heavyball/utils.py,sha256=FglgQfiE206I07rql3qP-X2C1j0hY3N5VcQwKUh08aA,40025
20
- heavyball-0.24.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
- heavyball-0.24.2.dist-info/METADATA,sha256=lTThJQbW6qbnQqy9lGlTTOttJcX5vfQ_s6Cm0arqfC8,11926
22
- heavyball-0.24.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
- heavyball-0.24.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
- heavyball-0.24.2.dist-info/RECORD,,
19
+ heavyball/utils.py,sha256=AxhcHzbFAvhTgTFyIcdxs9TJkH4AgVEaNeBRjOLzoBM,40095
20
+ heavyball-0.24.3.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
21
+ heavyball-0.24.3.dist-info/METADATA,sha256=32T-Q-a4k096KjxoR-3DQt25XpO_h0zs7lWKTDQLugI,11926
22
+ heavyball-0.24.3.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
23
+ heavyball-0.24.3.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
24
+ heavyball-0.24.3.dist-info/RECORD,,