heavyball 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +23 -0
- heavyball/utils.py +6 -5
- {heavyball-1.5.0.dist-info → heavyball-1.5.1.dist-info}/METADATA +1 -1
- heavyball-1.5.1.dist-info/RECORD +8 -0
- heavyball-1.5.0.dist-info/RECORD +0 -8
- {heavyball-1.5.0.dist-info → heavyball-1.5.1.dist-info}/LICENSE +0 -0
- {heavyball-1.5.0.dist-info → heavyball-1.5.1.dist-info}/WHEEL +0 -0
- {heavyball-1.5.0.dist-info → heavyball-1.5.1.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -163,6 +163,17 @@ class OrthoLaProp(C.BaseOpt):
|
|
163
163
|
C.orthogonalize_grad_to_param, C.scale_by_laprop)
|
164
164
|
|
165
165
|
|
166
|
+
class ForeachAdamW(C.BaseOpt):
|
167
|
+
def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
|
168
|
+
foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
|
169
|
+
mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
|
170
|
+
update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
|
171
|
+
defaults = locals()
|
172
|
+
defaults.pop("self")
|
173
|
+
params = defaults.pop("params")
|
174
|
+
super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
|
175
|
+
|
176
|
+
|
166
177
|
class OrthoAdamW(C.BaseOpt):
|
167
178
|
def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
|
168
179
|
foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
|
@@ -175,6 +186,18 @@ class OrthoAdamW(C.BaseOpt):
|
|
175
186
|
C.orthogonalize_grad_to_param, C.scale_by_adam)
|
176
187
|
|
177
188
|
|
189
|
+
class AdamWOrtho(C.BaseOpt):
|
190
|
+
def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
|
191
|
+
foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
|
192
|
+
mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
|
193
|
+
update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
|
194
|
+
defaults = locals()
|
195
|
+
defaults.pop("self")
|
196
|
+
params = defaults.pop("params")
|
197
|
+
super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_adam,
|
198
|
+
C.orthogonalize_grad_to_param)
|
199
|
+
|
200
|
+
|
178
201
|
class ForeachPSGDKron(C.BaseOpt):
|
179
202
|
"""
|
180
203
|
Originally from Evan Walters and Omead Pooladzandi, 2024
|
heavyball/utils.py
CHANGED
@@ -770,22 +770,23 @@ def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
|
|
770
770
|
|
771
771
|
@decorator_knowngood
|
772
772
|
def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
|
773
|
-
step: Tensor):
|
773
|
+
step: Tensor, eps: Tensor):
|
774
774
|
beta1 = beta_debias(beta1, step)
|
775
775
|
beta2 = beta_debias(beta2, step)
|
776
776
|
|
777
777
|
g32 = list(map(promote, grad))
|
778
778
|
|
779
779
|
exp_avg32 = _lerp32(exp_avg, g32, beta1)
|
780
|
-
denom = exp_avg_sq_(exp_avg_sq, g32, beta2, 1e-8)
|
780
|
+
denom = exp_avg_sq_(exp_avg_sq, g32, beta2, eps)
|
781
781
|
u32 = torch._foreach_div(exp_avg32, denom)
|
782
782
|
copy_stochastic_list_(grad, u32)
|
783
783
|
|
784
784
|
|
785
|
-
def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
|
785
|
+
def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
|
786
|
+
eps: float):
|
786
787
|
exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
|
787
|
-
beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
|
788
|
-
_compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
|
788
|
+
beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
|
789
|
+
_compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
|
789
790
|
return grad
|
790
791
|
|
791
792
|
|
@@ -0,0 +1,8 @@
|
|
1
|
+
heavyball/__init__.py,sha256=fz-jC7m7XIYNf4PRaJ0rkSnWPYzMWEK5JQl4vp_yw_w,14166
|
2
|
+
heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
|
3
|
+
heavyball/utils.py,sha256=hae6gPVONG5lZiKm-Wqk0Sjjq3prfZIjCP5UoWcpptA,50338
|
4
|
+
heavyball-1.5.1.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
+
heavyball-1.5.1.dist-info/METADATA,sha256=ww9KSe8MJDnjz1blmtnubpE20bkuXJ8NeMOeDK40OJk,43584
|
6
|
+
heavyball-1.5.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
+
heavyball-1.5.1.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
+
heavyball-1.5.1.dist-info/RECORD,,
|
heavyball-1.5.0.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
|
|
1
|
-
heavyball/__init__.py,sha256=AL3oSbNB1HQ0cwEG6aPZGVMbpCXXCYOxREX7JwK4Byc,12773
|
2
|
-
heavyball/chainable.py,sha256=4xIaufYcIMgrasSIm9ZHwqRXD2vvUbHsW0FJqGB68EM,24782
|
3
|
-
heavyball/utils.py,sha256=NFvQcQemNOugH1vAi_UH3jnnttPSgVopmS1q6jbhxkQ,50289
|
4
|
-
heavyball-1.5.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
|
5
|
-
heavyball-1.5.0.dist-info/METADATA,sha256=fUOCJvDcBQ5280TCLhUCuIRwNVMvp3ysp4qrDuJCUeI,43584
|
6
|
-
heavyball-1.5.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
7
|
-
heavyball-1.5.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
|
8
|
-
heavyball-1.5.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|