heavyball 0.25.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the differences between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
heavyball/__init__.py CHANGED
@@ -1,56 +1,209 @@
1
- from .cached_delayed_psgd_kron import ForeachCachedDelayedPSGDKron
2
- from .cached_psgd_kron import ForeachCachedPSGDKron
3
- from .delayed_psgd import ForeachDelayedPSGD
4
- from .foreach_adamw import ForeachAdamW
5
- from .foreach_adopt import ForeachADOPT
6
- from .foreach_laprop import ForeachLaProp
7
- from .foreach_sfadamw import ForeachSFAdamW
8
- from .foreach_soap import ForeachSOAP
9
- from .p_adam import ForeachPaLMPAdam
10
- from .palm_foreach_sfadamw import PaLMForeachSFAdamW
11
- from .palm_foreach_soap import PaLMForeachSOAP
12
- from .precond_schedule_foreach_soap import PrecondScheduleForeachSOAP
13
- from .precond_schedule_palm_foreach_soap import PrecondSchedulePaLMForeachSOAP
14
- from .precond_schedule_sfpsoap import PrecondScheduleSFPaLMSOAP
15
- from .psgd_kron import ForeachPSGDKron
16
- from .pure_psgd import ForeachPurePSGD
17
- from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
18
- from .foreach_solp import ForeachSOLP
19
- from .palm_foreach_solp import PaLMForeachSOLP
20
- from .precond_schedule_palm_foreach_solp import PrecondSchedulePaLMForeachSOLP
21
- from .precond_schedule_foreach_solp import PrecondScheduleForeachSOLP
1
+ import functools
2
+ from typing import Optional
22
3
 
23
- PalmForEachSoap = PaLMForeachSOAP
4
+ from . import chainable as C
5
+ from . import utils
6
+
7
+
8
+ class ForeachAdamW(C.BaseOpt):
9
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
10
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
11
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
12
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
13
+ defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
14
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
15
+ mars_gamma=mars_gamma, beta2_scale=beta2_scale)
16
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
17
+
18
+
19
+ class ForeachRMSprop(C.BaseOpt):
20
+ """
21
+ Debiased RMSprop (not torch.optim.RMSprop)
22
+ """
23
+
24
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
25
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
26
+ caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
27
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
28
+ defaults = dict(lr=lr, betas=betas, eps=eps, warmup_steps=warmup_steps, weight_decay=weight_decay,
29
+ foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
30
+ beta2_scale=beta2_scale)
31
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)
32
+
33
+
34
+ class ForeachSFAdamW(C.ScheduleFree):
35
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-6, weight_decay=0, warmup_steps=0, r=0.0,
36
+ weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
37
+ caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
38
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
39
+ defaults = dict(lr=lr, betas=betas, eps=eps, r=r, k=0, warmup_steps=warmup_steps, train_mode=True,
40
+ weight_sum=0.0, lr_max=-1.0, weight_lr_power=weight_lr_power, weight_decay=weight_decay,
41
+ foreach=foreach, storage_dtype=storage_dtype, mars=mars, caution=caution, mars_gamma=mars_gamma,
42
+ beta2_scale=beta2_scale)
43
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
44
+ C.update_by_schedule_free)
45
+
46
+
47
+ class PaLMForeachSFAdamW(ForeachSFAdamW):
48
+ palm: bool = True
49
+
50
+
51
+ class ForeachADOPT(C.BaseOpt):
52
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
53
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
54
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
55
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
56
+ defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
57
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
58
+ mars_gamma=mars_gamma, beta2_scale=beta2_scale)
59
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
60
+
61
+
62
+ class ForeachLaProp(C.BaseOpt):
63
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
64
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
65
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
66
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
67
+ defaults = dict(lr=lr, betas=betas, eps=eps, k=0, warmup_steps=warmup_steps, train_mode=True, weight_sum=0.0,
68
+ lr_max=-1.0, weight_decay=weight_decay, storage_dtype=storage_dtype, mars=mars, caution=caution,
69
+ mars_gamma=mars_gamma, beta2_scale=beta2_scale)
70
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
71
+
72
+
73
+ class ForeachSOAP(C.BaseOpt):
74
+ """
75
+ ForeachSOAP
76
+
77
+ Sources:
78
+ Baseline SOAP:
79
+ SOAP: Improving and Stabilizing Shampoo using Adam
80
+ Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
81
+ https://arxiv.org/abs/2409.11321
82
+ https://github.com/nikhilvyas/SOAP
83
+
84
+ ScheduleFree:
85
+ The Road Less Scheduled
86
+ Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
87
+ https://arxiv.org/abs/2405.15682
88
+ https://github.com/facebookresearch/schedule_free
89
+ """
90
+ use_precond_schedule: bool = False
91
+
92
+ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
93
+ weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
94
+ merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
95
+ data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
96
+ split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
97
+ mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
98
+ beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
99
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
100
+ use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
101
+
102
+ defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
103
+ "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
104
+ "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
105
+ "correct_bias": correct_bias, 'warmup_steps': warmup_steps, 'split': split, 'mars': mars,
106
+ 'caution': caution, 'mars_gamma': mars_gamma, 'palm': palm, 'precond_scheduler': precond_scheduler,
107
+ 'beta2_scale': beta2_scale}
108
+ if use_precond_schedule:
109
+ del defaults['precondition_frequency']
110
+ else:
111
+ del defaults['precond_scheduler']
112
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
113
+ C.scale_by_soap)
114
+
115
+
116
+ class PaLMForeachSOAP(ForeachSOAP):
117
+ use_precond_schedule: bool = False
118
+ palm: bool = True
24
119
 
120
+
121
+ class PrecondScheduleForeachSOAP(ForeachSOAP):
122
+ use_precond_schedule: bool = True
123
+
124
+
125
+ class PrecondSchedulePaLMForeachSOAP(ForeachSOAP):
126
+ use_precond_schedule: bool = True
127
+ palm: bool = True
128
+
129
+
130
+ class ForeachPSGDKron(C.BaseOpt):
131
+ """
132
+ Originally from Evan Walters and Omead Pooladzandi, 2024
133
+ Modified under Creative Commons Attribution 4.0 International
134
+ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
135
+ """
136
+
137
+ delayed: bool = False
138
+ cached: bool = False
139
+ exp_avg_input: bool = True
140
+
141
+ def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
142
+ max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
143
+ momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
144
+ split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
145
+ stochastic_schedule: bool = True, storage_dtype: str = 'float32', mars: bool = False,
146
+ caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
147
+ cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
148
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
149
+ # expert parameters
150
+ precond_init_scale=1.0, precond_lr=0.1):
151
+ delayed = C.default(delayed, self.delayed)
152
+ cached = C.default(cached, self.cached)
153
+ exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
154
+ update_clipping = C.default(update_clipping, utils.trust_region_clip_)
155
+
156
+ defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
157
+ min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
158
+ momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
159
+ precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
160
+ split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
161
+ storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
162
+ stochastic_schedule=stochastic_schedule)
163
+
164
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False, #
165
+ *(C.exp_avg,) * exp_avg_input, #
166
+ functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached,
167
+ prob=preconditioner_update_probability))
168
+
169
+
170
+ class ForeachPurePSGD(ForeachPSGDKron):
171
+ exp_avg_input: bool = False
172
+
173
+
174
+ class ForeachCachedDelayedPSGDKron(ForeachPSGDKron):
175
+ delayed: bool = True
176
+ cached: bool = True
177
+
178
+
179
+ class ForeachCachedPSGDKron(ForeachPSGDKron):
180
+ cached: bool = True
181
+
182
+
183
+ class ForeachDelayedPSGD(ForeachPSGDKron):
184
+ delayed: bool = True
185
+
186
+
187
+ PalmForEachSoap = PaLMForeachSOAP
25
188
  PaLMSOAP = PaLMForeachSOAP
26
189
  PaLMSFAdamW = PaLMForeachSFAdamW
27
- PaLMSFSoap = SFPaLMForeachSOAP
28
- PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
29
190
  SOAP = ForeachSOAP
30
191
  SFAdamW = ForeachSFAdamW
31
192
  LaProp = ForeachLaProp
32
193
  ADOPT = ForeachADOPT
194
+ RMSprop = ForeachRMSprop
33
195
  PrecondScheduleSOAP = PrecondScheduleForeachSOAP
34
196
  PrecondSchedulePaLMSOAP = PrecondSchedulePaLMForeachSOAP
35
197
  PSGDKron = ForeachPSGDKron
36
198
  AdamW = ForeachAdamW
37
199
  PurePSGD = ForeachPurePSGD
38
- PaLMPAdam = ForeachPaLMPAdam
39
200
  DelayedPSGD = ForeachDelayedPSGD
40
201
  CachedPSGDKron = ForeachCachedPSGDKron
41
202
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
42
- SOLP = ForeachSOLP
43
- PaLMSOLP = PaLMForeachSOLP
44
- PrecondSchedulePaLMSOLP = PrecondSchedulePaLMForeachSOLP
45
- PrecondScheduleSOLP = PrecondScheduleForeachSOLP
46
-
47
- __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
48
- 'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
49
- 'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
50
- 'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', 'ForeachCachedDelayedPSGDKron', 'ForeachSOLP',
51
- 'PaLMForeachSOLP', 'PrecondSchedulePaLMForeachSOLP', 'PrecondScheduleForeachSOLP',
52
- #
53
- 'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PrecondScheduleSFPaLMSOAP', 'SOAP', 'SFAdamW',
54
- 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD', 'CachedPSGDKron',
55
- 'CachedDelayedPSGDKron', 'PrecondScheduleSOAP', 'PrecondSchedulePaLMSOAP', 'SOLP', 'PrecondScheduleSOLP',
56
- 'PrecondSchedulePaLMSOLP', 'PrecondScheduleSOLP']
203
+
204
+ __all__ = ["PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron", "CachedDelayedPSGDKron",
205
+ "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT", "PrecondScheduleSOAP",
206
+ "PrecondSchedulePaLMSOAP", 'RMSprop', #
207
+ "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
208
+ "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
209
+ "ForeachRMSprop"]