heavyball 0.25.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -1,56 +1,223 @@
1
- from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
2
- from .cached_psgd_kron import ForeachCachedPSGDKron
3
- from .delayed_psgd import ForeachDelayedPSGD
4
- from .foreach_adamw import ForeachAdamW
5
- from .foreach_adopt import ForeachADOPT
6
- from .foreach_laprop import ForeachLaProp
7
- from .foreach_sfadamw import ForeachSFAdamW
8
- from .foreach_soap import ForeachSOAP
9
- from .p_adam import ForeachPaLMPAdam
10
- from .palm_foreach_sfadamw import PaLMForeachSFAdamW
11
- from .palm_foreach_soap import PaLMForeachSOAP
12
- from .precond_schedule_foreach_soap import PrecondScheduleForeachSOAP
13
- from .precond_schedule_palm_foreach_soap import PrecondSchedulePaLMForeachSOAP
14
- from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
15
- from .psgd_kron import ForeachPSGDKron
16
- from .pure_psgd import ForeachPurePSGD
17
- from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
18
- from .foreach_solp import ForeachSOLP
19
- from .palm_foreach_solp import PaLMForeachSOLP
20
- from .precond_schedule_palm_foreach_solp import PrecondSchedulePaLMForeachSOLP
21
- from .precond_schedule_foreach_solp import PrecondScheduleForeachSOLP
1
+ import functools
2
+ from typing import Optional
3
+
4
+ from . import chainable as C
5
+ from . import utils
6
+
7
+
8
+ class ForeachAdamW(C.BaseOpt):
9
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
10
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
11
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
12
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
13
+ defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
14
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
15
+ mars_gamma=mars_gamma, beta2_scale=beta2_scale)
16
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
17
+
18
+
19
+ class ForeachRMSprop(C.BaseOpt):
20
+ """
21
+ Debiased RMSprop (not torch.optim.RMSprop)
22
+ """
23
+
24
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
25
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
26
+ caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
27
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
28
+ defaults = dict(lr=lr, betas=betas, eps=eps, warmup_steps=warmup_steps, weight_decay=weight_decay,
29
+ foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
30
+ beta2_scale=beta2_scale)
31
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)
32
+
33
+
34
+ class ForeachSFAdamW(C.ScheduleFree):
35
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
36
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
37
+ caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
38
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
39
+ defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
40
+ weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
41
+ foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
42
+ beta2_scale=beta2_scale)
43
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
44
+ C.update_by_schedule_free)
45
+
46
+
47
+ class PaLMForeachSFAdamW(ForeachSFAdamW):
48
+ palm: bool = True
49
+
50
+
51
+ class ForeachADOPT(C.BaseOpt):
52
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
53
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
54
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
55
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
56
+ defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
57
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
58
+ mars_gamma=mars_gamma, beta2_scale=beta2_scale)
59
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
60
+
61
+
62
+ class ForeachMuon(C.BaseOpt):
63
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
64
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
65
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
66
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
67
+ nesterov: bool = True):
68
+ defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
69
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
70
+ mars_gamma=mars_gamma, beta2_scale=beta2_scale)
71
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
72
+ C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
73
+
74
+
75
+ class ForeachLaProp(C.BaseOpt):
76
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
77
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
78
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
79
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
80
+ defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
81
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
82
+ mars_gamma=mars_gamma, beta2_scale=beta2_scale)
83
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
84
+
85
+
86
+ class ForeachSOAP(C.BaseOpt):
87
+ """
88
+ ForeachSOAP
89
+
90
+ Sources:
91
+ Baseline SOAP:
92
+ SOAP: Improving and Stabilizing Shampoo using Adam
93
+ Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
94
+ https://arxiv.org/abs/2409.11321
95
+ https://github.com/nikhilvyas/SOAP
96
+
97
+ ScheduleFree:
98
+ The Road Less Scheduled
99
+ Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
100
+ https://arxiv.org/abs/2405.15682
101
+ https://github.com/facebookresearch/schedule_free
102
+ """
103
+ use_precond_schedule: bool = False
104
+
105
+ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
106
+ weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
107
+ merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
108
+ data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
109
+ split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
110
+ mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
111
+ beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
112
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
113
+ use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
114
+
115
+ defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
116
+ "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
117
+ "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
118
+ "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
119
+ 'caution': caution, 'mars_gamma': mars_gamma, 'palm': palm, 'precond_scheduler': precond_scheduler,
120
+ 'beta2_scale': beta2_scale}
121
+ if use_precond_schedule:
122
+ del defaults['precondition_frequency']
123
+ else:
124
+ del defaults['precond_scheduler']
125
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
126
+ C.scale_by_soap)
22
127
 
23
- PalmForEachSoap = PaLMForeachSOAP
24
128
 
129
+ class PaLMForeachSOAP(ForeachSOAP):
130
+ use_precond_schedule: bool = False
131
+ palm: bool = True
132
+
133
+
134
+ class PrecondScheduleForeachSOAP(ForeachSOAP):
135
+ use_precond_schedule: bool = True
136
+
137
+
138
+ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
139
+ use_precond_schedule: bool = True
140
+ palm: bool = True
141
+
142
+
143
+ class ForeachPSGDKron(C.BaseOpt):
144
+ """
145
+ Originally from Evan Walters and Omead Pooladzandi, 2024
146
+ Modified under Creative Commons Attribution 4.0 International
147
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
148
+ """
149
+
150
+ delayed: bool = False
151
+ cached: bool = False
152
+ exp_avg_input: bool = True
153
+
154
+ def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
155
+ max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
156
+ momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
157
+ split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
158
+ stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
159
+ caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
160
+ cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
161
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
162
+ # expert parameters
163
+ precond_init_scale=1.0, precond_lr=0.1):
164
+ delayed = C.default(delayed, self.delayed)
165
+ cached = C.default(cached, self.cached)
166
+ exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
167
+ update_clipping = C.default(update_clipping, utils.trust_region_clip_)
168
+
169
+ defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
170
+ min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
171
+ momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
172
+ precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
173
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
174
+ storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
175
+ stochastic_schedule=stochastic_schedule)
176
+
177
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False, #
178
+ *(C.exp_avg,) * exp_avg_input, #
179
+ functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached,
180
+ prob=preconditioner_update_probability))
181
+
182
+
183
+ class ForeachPurePSGD(ForeachPSGDKron):
184
+ exp_avg_input: bool = False
185
+
186
+
187
+ class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
188
+ delayed: bool = True
189
+ cached: bool = True
190
+
191
+
192
+ class ForeachCachedPSGDKron(ForeachPSGDKron):
193
+ cached: bool = True
194
+
195
+
196
+ class ForeachDelayedPSGD(ForeachPSGDKron):
197
+ delayed: bool = True
198
+
199
+
200
+ PalmForEachSoap = PaLMForeachSOAP
25
201
  PaLMSOAP = PaLMForeachSOAP
26
202
  PaLMSFAdamW = PaLMForeachSFAdamW
27
- PaLMSFSoap = SFPaLMForeachSOAP
28
- PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
29
203
  SOAP = ForeachSOAP
30
204
  SFAdamW = ForeachSFAdamW
31
205
  LaProp = ForeachLaProp
32
206
  ADOPT = ForeachADOPT
207
+ RMSprop = ForeachRMSprop
33
208
  PrecondScheduleSOAP = PrecondScheduleForeachSOAP
34
209
  PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
35
210
  PSGDKron = ForeachPSGDKron
36
211
  AdamW = ForeachAdamW
37
212
  PurePSGD = ForeachPurePSGD
38
- PaLMPAdam = ForeachPaLMPAdam
39
213
  DelayedPSGD = ForeachDelayedPSGD
40
214
  CachedPSGDKron = ForeachCachedPSGDKron
41
215
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
42
- SOLP = ForeachSOLP
43
- PaLMSOLP = PaLMForeachSOLP
44
- PrecondSchedulePaLMSOLP = PrecondSchedulePaLMForeachSOLP
45
- PrecondScheduleSOLP = PrecondScheduleForeachSOLP
46
-
47
- __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
48
- 'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
49
- 'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
50
- 'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', 'ForeachSOLP',
51
- 'PaLMForeachSOLP', 'PrecondSchedulePaLMForeachSOLP', 'PrecondScheduleForeachSOLP',
52
- #
53
- 'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
54
- 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
55
- 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP', 'SOLP', 'PrecondScheduleSOLP',
56
- 'PrecondSchedulePaLMSOLP', 'PrecondScheduleSOLP']
216
+ Muon = ForeachMuon
217
+
218
+ __all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
219
+ "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
220
+ "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', #
221
+ "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
222
+ "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
223
+ "ForeachRMSprop", "ForeachMuon"]