heavyball 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -163,6 +163,17 @@ class OrthoLaProp(C.BaseOpt):
163
163
  C.orthogonalize_grad_to_param, C.scale_by_laprop)
164
164
 
165
165
 
166
+ class ForeachAdamW(C.BaseOpt):
167
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
168
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
169
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
170
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
171
+ defaults = locals()
172
+ defaults.pop("self")
173
+ params = defaults.pop("params")
174
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
175
+
176
+
166
177
  class OrthoAdamW(C.BaseOpt):
167
178
  def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
168
179
  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
@@ -175,6 +186,18 @@ class OrthoAdamW(C.BaseOpt):
175
186
  C.orthogonalize_grad_to_param, C.scale_by_adam)
176
187
 
177
188
 
189
+ class AdamWOrtho(C.BaseOpt):
190
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
191
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
192
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
193
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
194
+ defaults = locals()
195
+ defaults.pop("self")
196
+ params = defaults.pop("params")
197
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_adam,
198
+ C.orthogonalize_grad_to_param)
199
+
200
+
178
201
  class ForeachPSGDKron(C.BaseOpt):
179
202
  """
180
203
  Originally from Evan Walters and Omead Pooladzandi, 2024
heavyball/utils.py CHANGED
@@ -770,22 +770,23 @@ def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
770
770
 
771
771
  @decorator_knowngood
772
772
  def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
773
- step: Tensor):
773
+ step: Tensor, eps: Tensor):
774
774
  beta1 = beta_debias(beta1, step)
775
775
  beta2 = beta_debias(beta2, step)
776
776
 
777
777
  g32 = list(map(promote, grad))
778
778
 
779
779
  exp_avg32 = _lerp32(exp_avg, g32, beta1)
780
- denom = exp_avg_sq_(exp_avg_sq, g32, beta2, 1e-8)
780
+ denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
781
781
  u32 = torch._foreach_div(exp_avg32, denom)
782
782
  copy_stochastic_list_(grad, u32)
783
783
 
784
784
 
785
- def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
785
+ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
786
+ eps: float):
786
787
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
787
- beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
788
- _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
788
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
789
+ _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
789
790
  return grad
790
791
 
791
792
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: heavyball
3
- Version: 1.5.0
3
+ Version: 1.5.1
4
4
  Summary: Efficient optimizers
5
5
  Home-page: https://github.com/clashluke/heavyball
6
6
  Author: Lucas Nestler
@@ -0,0 +1,8 @@
1
+ heavyball/__init__.py,sha256=fz-jC7m7XIYNf4PRaJ0rkSnWPYzMWEK5JQl4vp_yw_w,14166
2
+ heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
3
+ heavyball/utils.py,sha256=hae6gPVONG5lZiKm-Wqk0Sjjq3prfZIjCP5UoWcpptA,50338
4
+ heavyball-1.5.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
+ heavyball-1.5.1.dist-info/METADATA,sha256=ww9KSe8MJDnjz1blmtnubpE20bkuXJ8NeMOeDK40OJk,43584
6
+ heavyball-1.5.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
+ heavyball-1.5.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
+ heavyball-1.5.1.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- heavyball/__init__.py,sha256=AL3oSbNB1HQ0cwEG6aPZGVMbpCXXCYOxREX7JwK4Byc,12773
2
- heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
3
- heavyball/utils.py,sha256=NFvQcQemNOugH1vAi_UH3jnnttPSgVopmS1q6jbhxkQ,50289
4
- heavyball-1.5.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
5
- heavyball-1.5.0.dist-info/METADATA,sha256=fUOCJvDcBQ5280TCLhUCuIRwNVMvp3ysp4qrDuJCUeI,43584
6
- heavyball-1.5.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
7
- heavyball-1.5.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
8
- heavyball-1.5.0.dist-info/RECORD,,