heavyball 1.5.2__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heavyball/__init__.py +73 -17
- heavyball/chainable.py +76 -14
- heavyball/utils.py +322 -175
- {heavyball-1.5.2.dist-info → heavyball-1.6.0.dist-info}/METADATA +4 -6
- heavyball-1.6.0.dist-info/RECORD +8 -0
- heavyball-1.5.2.dist-info/RECORD +0 -8
- {heavyball-1.5.2.dist-info → heavyball-1.6.0.dist-info}/LICENSE +0 -0
- {heavyball-1.5.2.dist-info → heavyball-1.6.0.dist-info}/WHEEL +0 -0
- {heavyball-1.5.2.dist-info → heavyball-1.6.0.dist-info}/top_level.txt +0 -0
heavyball/__init__.py
CHANGED
@@ -104,23 +104,17 @@ class ForeachSOAP(C.BaseOpt):
  Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
  https://arxiv.org/abs/2409.11321
  https://github.com/nikhilvyas/SOAP
-
- ScheduleFree:
- The Road Less Scheduled
- Aaron Defazio, Xingyu Alice Yang, Harsh Mehta, Konstantin Mishchenko, Ahmed Khaled, Ashok Cutkosky
- https://arxiv.org/abs/2405.15682
- https://github.com/facebookresearch/schedule_free
  """
  use_precond_schedule: bool = False

  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-
-
-
-
-
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
  use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)

  defaults = locals()
@@ -137,6 +131,54 @@ class ForeachSOAP(C.BaseOpt):
  C.scale_by_soap)


+ class ForeachSignLaProp(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop, C.sign)
+
+
+ class ForeachSOLP(C.BaseOpt):
+ """
+ ForeachSOLP
+
+ Sources:
+ Baseline SOAP:
+ SOAP: Improving and Stabilizing Shampoo using Adam
+ Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, Sham Kakade
+ https://arxiv.org/abs/2409.11321
+ https://github.com/nikhilvyas/SOAP
+ """
+ use_precond_schedule: bool = False
+
+ def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
+ weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
+ merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
+ correct_bias: bool = True, warmup_steps: int = 0, split: bool = False, foreach: bool = True,
+ mars: bool = False, caution: bool = False, mars_gamma: float = 0.0025, palm: bool = C.use_default,
+ precond_scheduler=(1 / 3, 9), beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+ storage_dtype: str = 'float32', stochastic_schedule: bool = False):
+ use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
+
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+
+ if use_precond_schedule:
+ del defaults['precondition_frequency']
+ self.precond_schedule = utils.get_soap_precond_schedule(defaults.pop("precond_scheduler"))
+ else:
+ del defaults['precond_scheduler']
+ self.precond_schedule = 1 / defaults.pop("precondition_frequency")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, #
+ functools.partial(C.scale_by_soap, inner='laprop'))
+
+
  class PaLMForeachSOAP(ForeachSOAP):
  use_precond_schedule: bool = False
  palm: bool = True
@@ -163,6 +205,18 @@ class OrthoLaProp(C.BaseOpt):
  C.orthogonalize_grad_to_param, C.scale_by_laprop)


+ class LaPropOrtho(C.BaseOpt):
+ def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+ foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+ mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+ update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+ defaults = locals()
+ defaults.pop("self")
+ params = defaults.pop("params")
+ super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+ C.orthogonalize_grad_to_param)
+
+
  class ForeachPSGDKron(C.BaseOpt):
  """
  Originally from Evan Walters and Omead Pooladzandi, 2024
@@ -178,10 +232,10 @@ class ForeachPSGDKron(C.BaseOpt):
  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
  momentum_into_precond_update=True, warmup_steps: int = 0, merge_dims: bool = False,
  split: bool = False, store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
- stochastic_schedule: bool =
+ stochastic_schedule: bool = False, storage_dtype: str = 'float32', mars: bool = False,
  caution: bool = False, mars_gamma: float = 0.0025, delayed: Optional[bool] = C.use_default,
  cached: Optional[bool] = C.use_default, exp_avg_input: Optional[bool] = C.use_default,
- gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,
+ gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default, #
  # expert parameters
  precond_init_scale=1.0, precond_lr=0.1):
  defaults = locals()
@@ -238,10 +292,12 @@ DelayedPSGD = ForeachDelayedPSGD
  CachedPSGDKron = ForeachCachedPSGDKron
  CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
  Muon = ForeachMuon
+ SignLaProp = ForeachSignLaProp

  __all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
  "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
- "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', #
-
- "
- "ForeachRMSprop", "ForeachMuon",
+ "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp', 'ForeachSignLaProp' #
+ "ForeachAdamW", "ForeachSFAdamW",
+ "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron", "ForeachPurePSGD", "ForeachDelayedPSGD",
+ "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron", "ForeachRMSprop", "ForeachMuon",
+ 'ForeachCachedNewtonPSGD', 'OrthoLaProp', 'LaPropOrtho', 'SignLaProp']
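The new optimizers above are exported both under their Foreach names and the short aliases (SignLaProp, LaPropOrtho). A minimal usage sketch, assuming the 1.6.0 constructor signature shown in the hunk; the toy model and training loop are illustrative only:

```python
import torch
import heavyball

model = torch.nn.Linear(64, 64)
# ForeachSignLaProp chains C.scale_by_laprop with C.sign, i.e. a sign-style update
# grafted onto LaProp's second-moment normalization.
opt = heavyball.SignLaProp(model.parameters(), lr=0.0025, betas=(0.9, 0.99))

for _ in range(10):
    loss = model(torch.randn(8, 64)).square().mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```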
heavyball/chainable.py
CHANGED
@@ -1,6 +1,6 @@
  import functools
  import random
- from typing import Optional, Union, Literal
+ from typing import Optional, Union, Literal, List

  import torch

@@ -127,7 +127,7 @@ def zero_guard(*names):


  def copy_guard(index, *names):
- return functools.partial(CopyGuard, index=index, names=names
+ return functools.partial(CopyGuard, index=index, names=names)


  def general_guard(*names, init_fn, skip_first: bool = True):
@@ -152,6 +152,22 @@ def exp_avg(group, update, grad, param, exp_avg):
  return utils.scale_by_exp_avg_(exp_avg, update, utils.beta_debias(utils.get_beta1(group), group["step"]))


+ @zero_guard('exp_avg')
+ @no_state
+ def weight_decay_to_ema(group, update, grad, param, exp_avg):
+ utils.weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+ group['weight_decay_to_ema'] * group['lr'])
+ return update
+
+
+ @zero_guard('exp_avg')
+ @no_state
+ def l1_weight_decay_to_ema(group, update, grad, param, exp_avg):
+ utils.l1_weight_decay_to_ema_(exp_avg, update, utils.beta_debias(group['ema_beta'], group['step']),
+ group['weight_decay_to_ema'] * group['lr'])
+ return update
+
+
  @zero_guard("exp_avg_sq")
  @no_state
  def scale_by_exp_avg_sq(group, update, grad, param, exp_avg_sq):
@@ -206,14 +222,15 @@ def update_by_schedule_free(group, update, grad, param, z):
  @no_state
  def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  if group['step'] == 1:
- utils.
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
  raise SkipUpdate

  if group['step'] == 2:
  update = utils.promote(update)
  easq = utils.promote(exp_avg_sq)
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
- utils.
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+ group['eps'])
  raise SkipUpdate

  utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
@@ -225,21 +242,22 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  @no_state
  def scale_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
  if group['step'] == 1:
- utils.
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, 0, group['eps'])
  raise SkipUpdate

  if group['step'] == 2:
  update = utils.promote(update)
  easq = utils.promote(exp_avg_sq)
  [utils.set_(ea, u / easq_.sqrt().clamp_(min=group['eps'])) for ea, u, easq_ in zip(exp_avg, update, easq)]
- utils.
+ utils.scale_by_exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']),
+ group['eps'])
  raise SkipUpdate

  return utils.adopt(update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 2)


- def _init_soap(state, group, update, grad, param):
- utils.init_preconditioner(grad, state,
+ def _init_soap(state, group, update, grad, param, inner: str = ''):
+ utils.init_preconditioner(grad, state, group['max_precond_dim'], group['precondition_1d'])


  def _init_psgd(state, group, update, grad, param, cached: bool = False, prob: Optional[callable] = None):
@@ -295,6 +313,25 @@ def nesterov_momentum(group, updates, grads, params, momentum):
  return utils.nesterov_momentum(momentum, updates, utils.get_beta1(group))


+ @zero_guard('momentum')
+ @no_state
+ def nesterov_ema(group, updates, grads, params, momentum):  # equivalent to Grokfast
+ return utils.nesterov_ema(momentum, updates, utils.get_beta1(group))
+
+
+ def _store_std(state, group, update, grad, param):
+ state['init_std'] = torch.std(grad, dim=0)
+
+
+ @general_guard("init_std", init_fn=_store_std)
+ @no_state
+ def mup_approx(group, updates, grads, params, init_std):
+ _updates = [(u, i) for u, i in zip(updates, init_std) if u.ndim > 1]
+ _updates, _init_std = zip(*_updates)
+ utils.stochastic_multiply_(_updates, _init_std)
+ return updates
+
+
  @zero_guard("momentum")
  @no_state
  def heavyball_momentum(group, updates, grads, params, momentum):
@@ -308,15 +345,16 @@ _optim_fns = {'adam': utils.adam_, 'laprop': utils.laprop_}
  @general_guard("Q", "GG", init_fn=_init_soap)
  @no_state
  def scale_by_soap(group, update, grad, param, exp_avg, exp_avg_sq, Q, GG, inner: str = 'adam'):
- update = utils.promote(update)
+ update = utils.promote(update)  # Promote to highest precision if needed

  grad_projected = [utils.project(u, q, False) for u, q in zip(update, Q)]
  fn = _optim_fns[inner]
- precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step']
+ precond = fn(exp_avg, exp_avg_sq, grad_projected, utils.get_beta1(group), utils.get_beta2(group), group['step'] - 1,
+ group['eps'])
  precond = [utils.project(p, q, True) for p, q in zip(precond, Q)]

- for u, q, gg,
- utils.update_preconditioner(u, q, gg,
+ for u, q, gg, ea in zip(update, Q, GG, exp_avg):
+ utils.update_preconditioner(u, q, gg, ea, group['max_precond_dim'], group['precondition_1d'],
  utils.beta_debias(group['shampoo_beta'], group['step']),
  group['is_preconditioning'])
  return precond
@@ -414,6 +452,11 @@ def update_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: st
  raise SkipUpdate


+ @no_state
+ def sign(group, update, grad, param, graft: bool = True):
+ return utils.sign_(update, graft)
+
+
  @general_guard("Q", "exprs", ("Q_cache", None), ("cache_expr", None), init_fn=_init_psgd, skip_first=False)
  @no_state_no_foreach
  def update_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
@@ -439,8 +482,7 @@ def apply_to_idx(fn, idx):
  return _fn


- def
- update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+ def _inner_chain(state, group, update, grad, param, *fns):
  skip_update = False
  for fn in fns:
  try:
@@ -450,10 +492,30 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
  continue
  if update is None:
  break
+ return update, skip_update
+
+
+ def chain(state: Union[callable, dict], group, grad, param, *fns):
+ update = [torch.clone(g, memory_format=torch.preserve_format) for g in grad]
+ update, skip_update = _inner_chain(state, group, update, grad, param, *fns)
  if not skip_update and update is not None:
  utils.update_param_(param, update, group['lr'], group['weight_decay'], caution=group['caution'], grad=grad)


+ def create_branch(branches: List[List[callable]], merge_fn: callable):
+ def _branch(state, group, update, grad, param):
+ outputs = []
+ for branch in branches:
+ branch_update = [torch.clone(u, memory_format=torch.preserve_format) for u in update]
+ branch_update, skip_update = _inner_chain(state, group, branch_update, grad, param, *branch)
+ if skip_update:
+ raise ValueError("Branches should not skip updates")
+ outputs.append(branch_update)
+ return merge_fn(outputs)
+
+ return _branch
+
+
  class ChainOpt(utils.StatefulOptimizer):
  promote: bool = False

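The new `create_branch` helper runs several transform chains on clones of the same update and merges their outputs. A sketch of how it might be wired up, assuming the branch entries are the same chainable transforms BaseOpt accepts and that `merge_fn` receives one list of per-parameter updates per branch; `mean_merge` is our own illustrative helper, not a heavyball API:

```python
import torch
from heavyball import chainable as C

def mean_merge(outputs):
    # outputs: one list of per-parameter update tensors per branch
    return [torch.stack(updates).mean(0) for updates in zip(*outputs)]

# Average an Adam-style branch with a sign-based branch before the parameter update.
branched = C.create_branch([[C.exp_avg, C.scale_by_exp_avg_sq], [C.sign]], mean_merge)
```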
heavyball/utils.py
CHANGED
@@ -1,24 +1,40 @@
+ import copy
  import functools
  import gc
+ import inspect
  import math
  import random
  import string
+ import sys
+ import time
  import warnings
+ from datetime import datetime
  from typing import List, Optional, Tuple, Callable, Union
+ from unittest.mock import patch

+ import hyperopt
  import numpy as np
  import torch
  from torch import Tensor
+ from torch._dynamo import config
  from torch._dynamo.exc import TorchDynamoException
  from torch.backends import cudnn, opt_einsum
  from torch.utils._pytree import tree_map

+ config.cache_size_limit = 2 ** 16
+
+ np.warnings = warnings
+
  compile_mode = "max-autotune-no-cudagraphs"
  dynamic = False
  compile_mode_recommended_to_none = None
  zeroth_power_mode = 'qr'  # 'qr' is baseline, 'newtonschulz' converges better and faster
  tiny_bf16 = torch.finfo(torch.bfloat16).tiny

+ base_args = {'betas': (0.9, 0.999), 'precondition_frequency': 1, 'merge_dims': False, 'warmup_steps': 100,
+ 'max_precond_dim': 2 ** 16, 'beta': 0.9, 'max_size_triangular': 2 ** 16, 'split': False, 'eps': 1e-8,
+ 'weight_decay': 1e-4}
+

  def decorator(func):
  compiled = None
@@ -51,7 +67,7 @@ def decorator_knowngood(func: Callable):
  return _fn


- einsum_base = string.ascii_lowercase
+ einsum_base = string.ascii_lowercase


  @decorator_knowngood
@@ -103,29 +119,29 @@ def dim_merger(grad, max_precond_dim, split: bool = False):
  but we want to merge conv kernels into fan-in or at least merge the kernel
  so, [128, 64, 3, 3] should result in [128, 576] or [128, 64, 9] instead of [73728] or [8192, 3, 3] the baseline
  would've done
+
+ By @francois-rozet (commit: 68cde41eaf7e73b4c46eacb6a944865dcc081f1d), re-commited due to faulty merge
  """
- shape = grad.shape
  new_shape = []
-
-
-
-
-
-
-
-
- curr_shape = sh
+ cum_size = 1
+
+ for s in grad.shape[1:][::-1]:
+ temp_size = cum_size * s
+ if temp_size > max_precond_dim:
+ if cum_size > 1:
+ new_shape.append(cum_size)
+ cum_size = s
  else:
- new_shape.append(
-
+ new_shape.append(s)
+ cum_size = 1
  else:
-
- new_shape = [*shape[:1], *new_shape[::-1]]
+ cum_size = temp_size

- if
- new_shape.append(
+ if cum_size > 1:
+ new_shape.append(cum_size)

-
+ new_shape = [grad.shape[0], *new_shape[::-1]]
+ new_grad = grad.reshape(new_shape)
  if not split:
  return new_grad

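The merging rule in `dim_merger` above keeps the leading (fan-out) dimension and greedily folds trailing dimensions together until the product would exceed `max_precond_dim`. A standalone sketch of that shape computation, assuming the same greedy rule; the helper name `merged_shape` is ours:

```python
def merged_shape(shape, max_precond_dim):
    # Keep dim 0, then greedily merge trailing dims (right to left) while the
    # running product stays within max_precond_dim, mirroring dim_merger above.
    new_shape, cum_size = [], 1
    for s in shape[1:][::-1]:
        temp_size = cum_size * s
        if temp_size > max_precond_dim:
            if cum_size > 1:
                new_shape.append(cum_size)
                cum_size = s
            else:
                new_shape.append(s)
                cum_size = 1
        else:
            cum_size = temp_size
    if cum_size > 1:
        new_shape.append(cum_size)
    return [shape[0], *new_shape[::-1]]

assert merged_shape([128, 64, 3, 3], 2048) == [128, 576]
assert merged_shape([128, 64, 3, 3], 512) == [128, 64, 9]
```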
@@ -153,12 +169,11 @@ def beta_debias(beta, step):
  def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
  out: List[Optional[Tensor]]):
  s32, g32 = [list(map(promote, x)) for x in (state, grad)]
- s32 =
- s32 = [s + g * g * (1 - beta2) for s, g in zip(s32, g32)]
- denom = torch._foreach_sqrt(s32)
- denom = [d.clamp(min=eps) for d in denom]
+ s32 = [s * beta2 + g * g * (1 - beta2) for s, g in zip(s32, g32)]
  copy_stochastic_list_(state, s32)

+ denom = [d.sqrt().clamp(min=eps) for d in s32]
+
  if out[0] is None:
  return denom

@@ -174,7 +189,7 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):

  @decorator_knowngood
  def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
- g32 = promote
+ g32 = list(map(promote, grad))
  denom = _compilable_exp_avg_sq_(state, g32, beta2, eps, [None])
  out = torch._foreach_div(g32, denom)
  copy_stochastic_list_(grad, out)
@@ -189,7 +204,7 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):

  @decorator_knowngood
  def _compilable_exp_avg_(state, grad, beta):
- lerped =
+ lerped = _lerp(state, grad, beta)
  copy_stochastic_list_(grad, lerped)


@@ -240,29 +255,31 @@ def clean():
  gc.collect()


- def
+ def _ignore_warning(msg):
+ warnings.filterwarnings('ignore', f'.*{msg}.*')
+
+
+ def set_torch(benchmark_limit: int = 32):
  cudnn.benchmark = True
  cudnn.deterministic = False
+ cudnn.benchmark_limit = benchmark_limit
  torch.use_deterministic_algorithms(False)
  torch.set_float32_matmul_precision("high")  # highest: FP32, high: TF32, medium: bf16
  opt_einsum.enabled = True
- opt_einsum.strategy = "
+ opt_einsum.strategy = "dp"
+
+ # Torch calls these for 2nd-order optimization in HeavyBall, but they are explicitly handled.
+ _ignore_warning(
+ 'Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak')
+ _ignore_warning(
+ 'We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak')


  @decorator
  def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
- """
- Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
- quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
- of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
- zero even beyond the point where the iteration no longer converges all the way to one everywhere
- on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
- where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
- performance at all relative to UV^T, where USV^T = G is the SVD.
- """
  assert len(G.shape) == 2
  a, b, c = (3.4445, -4.7750, 2.0315)
- X = G.bfloat16
+ X = G.to(torch.bfloat16 if G.dtype != torch.float64 else G.dtype)  # Preserve float64 if present
  X /= (X.norm() + eps)  # ensure top singular value <= 1
  if G.size(0) > G.size(1):
  X = X.T
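The hunk above only shows the preamble of `zeropower_via_newtonschulz5`; the loop body is not part of this diff. The sketch below uses the standard quintic Newton-Schulz step with the same (a, b, c) coefficients described in the removed docstring, and may differ in detail from heavyball's implementation:

```python
import torch

def newton_schulz_quintic(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Approximate orthogonalization (zeroth power) of a 2-D matrix G: returns
    # roughly U S' V^T with S'_ii near 1, not the exact UV^T from an SVD.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.float()
    X = X / (X.norm() + eps)  # ensure top singular value <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X

Q = newton_schulz_quintic(torch.randn(32, 64))
# Rows are close to orthonormal; the residual stays nonzero because the
# singular values are only pushed toward 1, not set to 1 exactly.
print((Q @ Q.T - torch.eye(32)).norm())
```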
@@ -317,12 +334,24 @@ def nesterov_momentum(state, grad, beta):
  return grad


+ @decorator_knowngood
+ def _compilable_nesterov_ema_(state, grad, beta):
+ ema32 = _lerp(state, grad, beta)
+ stochastic_add_(grad, ema32, 1)
+
+
+ def nesterov_ema(state, grad, beta):
+ state, grad = list_guard(state, grad)
+ beta = scalar_guard(beta, state[0])
+ _compilable_nesterov_ema_(state, grad, beta)
+ return grad
+
+
+ @decorator_knowngood
  def _compilable_grafting(magnitude, direction):
  return direction * (magnitude.norm() / direction.norm().clamp(min=1e-6))


- # mode in ("newtonschulz", "qr", "svd")
- # scale_mode in ("none", "scale", "graft")
  @decorator_knowngood
  def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
  if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
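`nesterov_ema` above (tagged "equivalent to Grokfast" in chainable.py) keeps an EMA of the gradient and adds it on top of the raw gradient. A one-tensor sketch of the core arithmetic, omitting heavyball's bias debiasing and stochastic rounding; the beta value is an assumed example:

```python
import torch

def grokfast_ema_step(ema: torch.Tensor, grad: torch.Tensor, beta: float = 0.98) -> torch.Tensor:
    # Move the running EMA toward the current gradient, then boost the gradient with it.
    ema.lerp_(grad, 1 - beta)
    return grad + ema

ema = torch.zeros(4)
g = torch.ones(4)
print(grokfast_ema_step(ema, g))  # tensor([1.0200, 1.0200, 1.0200, 1.0200])
```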
@@ -350,74 +379,82 @@ def _compilable_scatter_set(target, source, index):
  target[:] = source.contiguous()[index].reshape_as(target)


-
+ @decorator_knowngood
+ def get_orthogonal_matrix_QR(GG: List[Tensor], Q: List[Tensor], exp_avg: Optional[Tensor] = None):
  """
  Computes the eigenbases of the preconditioner using one round of power iteration
- followed by torch.linalg.qr decomposition.
+ followed by torch.linalg.qr decomposition, and updates exp_avg in-place from old to new eigenspace.
+
+ :param GG: List of accumulated gradient outer products.
+ :param Q: List of current eigenbases (updated in-place to Q_new).
+ :param exp_avg: Exponential moving average in the old eigenspace (updated in-place if provided).
  """
-
-
-
-
-
- orth_matrix.append([])
- continue
- if m.data.dtype != torch.float:
- matrix.append(promote(m.data))
- orth_matrix.append(promote(o.data))
- else:
- matrix.append(promote(m.data))
- orth_matrix.append(promote(o.data))
+ if isinstance(Q, list) and not Q:
+ return
+
+ if exp_avg is not None and exp_avg.dim() != len(Q):
+ raise ValueError(f"exp_avg dim {exp_avg.dim()} does not match Q length {len(Q)}")

-
+ new_qs = []

- for
+ for m, q in zip(GG, Q):
  if len(m) == 0:
- indices.append(None)
  continue

-
-
+ m = promote(m.data)
+ q_old = promote(q.data)
+
+ tmp = m @ q_old
+ est_eig = torch.einsum('ij,ij->j', q_old, tmp)
  sort_idx = torch.argsort(est_eig, descending=True)
- indices.append(sort_idx)
- inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")

-
-
-
+ tmp[:, sort_idx], _ = torch.linalg.qr(tmp[:, sort_idx])
+ new_qs.append(tmp)
+
+ if exp_avg is None:
+ for q, q_new in zip(Q, new_qs):
+ copy_stochastic_(q, q_new)
+ return
+
+ assert exp_avg.ndim < 13, "exp_avg.ndim must be less than 13"
+ in_str = einsum_base[:exp_avg.dim()]
+ out_str = einsum_base[exp_avg.dim():2 * exp_avg.dim()]
+
+ from_shampoo = ",".join([o + i for m, i, o in zip(Q, in_str, in_str.upper()) if len(m) > 0])
+ if not from_shampoo:
+ return
+
+ to_shampoo = ','.join([i + o for m, i, o in zip(new_qs, in_str.upper(), out_str) if len(m) > 0])
+ out_str = ''.join([o if o in to_shampoo else i for i, o in zip(in_str, out_str)])
+
+ subscripts = f'{in_str},{from_shampoo},{to_shampoo}->{out_str}'
+ exp_avg_new = torch.einsum(subscripts, exp_avg, *[q for q in Q], *[q for q in new_qs])
+ copy_stochastic_(exp_avg, exp_avg_new)
+
+ for q, q_new in zip(Q, new_qs):
+ copy_stochastic_(q, q_new)


  def get_orthogonal_matrix(mat):
  """
  Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
  """
- matrix = []
- for m in mat:
- if len(m) == 0:
- matrix.append([])
- continue
- if m.data.dtype != torch.float:
- float_data = False
- original_type = m.data.dtype
- original_device = m.data.device
- matrix.append(promote(m.data))
- else:
- float_data = True
- matrix.append(m.data)

  final = []
- for m in
+ for m in mat:
  if len(m) == 0:
  final.append([])
  continue

+ m = promote(m.data)
+
  device, dtype = m.device, m.dtype
  for modifier in (None, torch.double, 'cpu'):
  if modifier is not None:
  m = m.to(modifier)
  try:
-
-
+ eigval, eigvec = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
+ eigvec = eigvec.to(device=device, dtype=dtype)
  break
  except torch.OutOfMemoryError:
  pass
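The new `exp_avg` handling above rotates the Adam EMA from the old Shampoo eigenbasis into the freshly computed one, so the moment statistics survive a basis refresh instead of being reset. A generic change-of-basis sketch for the 2-D case, assuming orthonormal columns in each Q; it illustrates the idea, not heavyball's exact einsum convention:

```python
import torch

def rotate_ema_2d(exp_avg, q1_old, q2_old, q1_new, q2_new):
    # exp_avg is stored in the old eigenbasis: exp_avg = Q1_old^T M Q2_old for some M
    # in the original coordinates. Map it back to the original space, then re-express
    # it in the new eigenbasis.
    m_orig = q1_old @ exp_avg @ q2_old.T
    return q1_new.T @ m_orig @ q2_new

# Sanity check with random orthogonal bases:
q1_old, _ = torch.linalg.qr(torch.randn(8, 8))
q2_old, _ = torch.linalg.qr(torch.randn(6, 6))
q1_new, _ = torch.linalg.qr(torch.randn(8, 8))
q2_new, _ = torch.linalg.qr(torch.randn(6, 6))
m = torch.randn(8, 6)
ema_old = q1_old.T @ m @ q2_old
ema_new = rotate_ema_2d(ema_old, q1_old, q2_old, q1_new, q2_new)
assert torch.allclose(q1_new @ ema_new @ q2_new.T, m, atol=1e-5)
```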
@@ -427,9 +464,9 @@ def get_orthogonal_matrix(mat):
  else:
  raise RuntimeError("Failed to compute eigenvalues.")

-
+ eigvec = torch.flip(eigvec, [1])

- final.append(
+ final.append(eigvec)

  return final

@@ -454,7 +491,7 @@ def get_beta1(group):


  def get_beta2(group):
- if 'beta2_scale' in group:
+ if 'palm' in group and group['palm'] is True and 'beta2_scale' in group:
  step = max(group.get("step", 1), 1)
  return 1 - step ** -group['beta2_scale']
  if 'betas' in group:
@@ -509,21 +546,46 @@ def stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[float, int, T
  _compilable_stochastic_add_(x, y, alpha)


+ @decorator_knowngood
+ def _compilable_stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+ for x_, y_ in zip(x, y):
+ x32 = promote(x_)
+ y32 = promote(y_)
+ copy_stochastic_(x_, x32 * y32)
+
+
+ def stochastic_multiply_(x: List[Tensor], y: List[Tensor]):
+ x, y = list_guard(x, y)
+ _compilable_stochastic_multiply_(x, y)
+
+
  @decorator
- def
+ def update_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
+ """
+ Simplified by @francois-rozet in commit 704ccc4bab52429f945df421647ec82c54cdd65f
+ Re-commited due to faulty merge
+ """
  if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
  return

- for idx,
- if
+ for idx, m in enumerate(GG):
+ if not isinstance(m, Tensor):
  continue
  b = einsum_base[idx]
  g0 = einsum_base[:grad.dim()]
  g1 = g0.replace(b, b.upper())
  outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
-
+ m.lerp_(outer_product, 1 - beta)
+
+
+ def tree_apply(fn):
+ def _fn(*args):
+ return tree_map(fn, *args)
+
+ return _fn


+ @tree_apply
  def promote(x):
  if isinstance(x, torch.dtype) and x in (torch.bfloat16, torch.float16):
  return torch.float32
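The `get_beta2` change above gates the PaLM-style schedule behind `palm=True`; with it enabled, the second-moment decay grows toward 1 over training as beta2 = 1 - step^(-beta2_scale). A small illustration, assuming the default beta2_scale of 0.8 seen in the constructors earlier in this diff:

```python
def palm_beta2(step: int, beta2_scale: float = 0.8) -> float:
    # beta2 rises toward 1.0 as training progresses, cf. get_beta2 above.
    return 1 - max(step, 1) ** -beta2_scale

for step in (1, 10, 100, 1000, 10000):
    print(step, round(palm_beta2(step), 4))
# 1 0.0, 10 0.8415, 100 0.9749, 1000 0.996, 10000 0.9994
```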
@@ -540,45 +602,38 @@ def min_dtype(xs: List[Tensor]):
  return torch.float32


- def update_preconditioner(grad, Q, GG,
+ def update_preconditioner(grad, Q, GG, exp_avg, max_precond_dim, precondition_1d, beta, update_precond):
  """
  Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
  """
-
+ update_ggt(grad, GG, max_precond_dim, precondition_1d, beta)
  if update_precond:
- get_orthogonal_matrix_QR(GG, Q,
+ get_orthogonal_matrix_QR(GG, Q, exp_avg)


- def init_preconditioner(grad, state,
+ def init_preconditioner(grad, state, max_precond_dim, precondition_1d):
  """
  Initializes the preconditioner matrices (L and R in the paper).
  """
  state['GG'] = []  # Will hold all the preconditioner matrices (L and R in the paper).
- if grad.
- if precondition_1d or grad.shape[0] > max_precond_dim:
- state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
- else:
- state['GG'].append([])
-
- else:
+ if grad.numel() > 1 and (grad.ndim > 1 or precondition_1d):
  for sh in grad.shape:
  if sh > max_precond_dim:
- state['GG'].append(
+ state['GG'].append(None)
  else:
  state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))
+ else:
+ state['GG'].append(None)

-
+ update_ggt(grad, state['GG'], max_precond_dim, precondition_1d, 0)
  state['Q'] = get_orthogonal_matrix(state['GG'])


  @decorator
  def project(grad, Q, back: bool):
  """
-
  :param grad:
  :param Q:
- :param merge_dims:
- :param max_precond_dim:
  :param back: whether to project to Shampoo eigenbases or back to original space
  :return:
  """
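Per `init_preconditioner` above, each tensor dimension no larger than `max_precond_dim` gets a square accumulator and larger dimensions get `None`. A quick sketch of the resulting `GG` shapes, assuming the 2048 default used by `ForeachSOAP`; the example shapes are illustrative:

```python
def gg_shapes(shape, max_precond_dim=2048, precondition_1d=False):
    # Mirrors the branch structure of init_preconditioner above.
    if len(shape) == 1 and not precondition_1d:
        return [None]
    return [(s, s) if s <= max_precond_dim else None for s in shape]

print(gg_shapes((128, 576)))    # [(128, 128), (576, 576)]
print(gg_shapes((50304, 768)))  # [None, (768, 768)]  (oversized dim gets no accumulator)
print(gg_shapes((768,)))        # [None]  (1-D params skipped by default)
```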
@@ -591,12 +646,40 @@ def project(grad, Q, back: bool):
  return grad


+ def modify_closure(closure):
+ """
+ Modifies the closure function to use create_graph=True in backward().
+
+ Args:
+ closure: The closure function passed to the optimizer.
+
+ Returns:
+ The return value of the modified closure.
+ """
+
+ def patched_backward(self, *args, **kwargs):
+ kwargs['create_graph'] = True
+ return original_backward(self, *args, **kwargs)
+
+ original_backward = torch.Tensor.backward
+
+ with patch.object(torch.Tensor, 'backward', patched_backward):
+ return closure()
+
+
  class StatefulOptimizer(torch.optim.Optimizer):
+ """
+ finite_differences saves memory, but needs more compute. (Alternative is true HVP)
+ Both `True` and `False` have some edge cases they don't support, so experiment with it.
+ The previous (heavyball<=1.5.3) default was `True`, which is incompatible with some benchmarks but works better with RevNet
+ Further notice that both methods have different numerics outputs
+ """
  ema_decay: float = 0.001
  compile_step: bool = False
  hessian_approx: bool = False
  precond_schedule: Union[Callable, float, None] = None
  stochastic_schedule: bool = False
+ finite_differences: bool = False

  def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
  super().__init__(params, {**defaults, 'foreach': foreach})
@@ -709,38 +792,68 @@ class StatefulOptimizer(torch.optim.Optimizer):
  set_(self.state_(p)['param_ema'], p.data)
  set_(p.data, ema_clone)

- def
- if self.precond_schedule is None:
- self._is_preconditioning = False
- else:
- self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
+ def _handle_closure(self, closure):
  hessian_approx = self.hessian_approx and self._is_preconditioning
+
  if closure is None:
  if hessian_approx:
  raise ValueError("Hessian approximation requires a closure.")
-
-
+ return None
+
+ if not hessian_approx:
  with torch.enable_grad():
  loss = closure()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ return loss
+
+ if self.finite_differences:
+ with torch.enable_grad():
+ loss = closure()  # closure without retain_graph=True
+
+ grads = []
+ for group in self.param_groups:
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+ grads.append(g)
+ p.vector = torch.randn_like(p)
+ p.orig = p.data.clone()
+ stochastic_add_(p.data, p.vector, tiny_bf16)
+ else:
+ with torch.enable_grad():
+ loss = modify_closure(closure)
+
+ if self.finite_differences:
+ with torch.enable_grad():
+ closure()
+
+ for group in self.param_groups:
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+ p.grad = grads.pop(0)
+ stochastic_add_(g, p.grad, -1)
+ p.hessian_vector = g
+ p.data.copy_(p.orig)
+ del p.orig
+ else:
+ for group in self.param_groups:
+ for p, g in self.split_p_and_g_in_group(group, skip_none=True, should_promote=False):
+ p.grad = g
+ params, grads = zip(*[x for group in self.param_groups for x in
+ self.split_p_and_g_in_group(group, skip_none=True, should_promote=False)])
+ vs = [torch.randn_like(p) for p in params]
+ with torch.enable_grad():
+ hvs = torch.autograd.grad(grads, params, vs)
+
+ for p, g, v, hv in zip(params, grads, vs, hvs):
+ p.hessian_vector = hv
+ p.grad = g
+ p.vector = v
+
+ return loss
+
+ def step(self, closure: Optional[Callable] = None):
+ if self.precond_schedule is None:
+ self._is_preconditioning = False
+ else:
+ self._is_preconditioning = psgd_should_update(self._inner_group, self.precond_schedule, self._precond_rng)
+ loss = self._handle_closure(closure)

  # we assume that parameters are constant and that there are no excessive recompiles
  with torch.no_grad(), torch._dynamo.utils.disable_cache_limit():
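The new `_handle_closure` supports two Hessian-vector-product paths: a finite-difference probe (re-evaluate the closure at a perturbed point and subtract gradients) and a true double-backward HVP via `torch.autograd.grad` on graph-retaining gradients. A self-contained comparison on a toy quadratic, assuming nothing about heavyball's internal bookkeeping:

```python
import torch

w = torch.randn(5, requires_grad=True)
A = torch.randn(5, 5)
A = A @ A.T  # symmetric PSD, so the Hessian of the loss below is exactly A

def loss():
    return 0.5 * w @ A @ w

v = torch.randn(5)

# True HVP: differentiate through the gradient (needs create_graph=True).
(g,) = torch.autograd.grad(loss(), w, create_graph=True)
(hv_true,) = torch.autograd.grad(g, w, grad_outputs=v)

# Finite-difference HVP: (grad(w + eps*v) - grad(w)) / eps, one extra backward, no retained graph.
eps = 1e-3
(g0,) = torch.autograd.grad(loss(), w)
with torch.no_grad():
    w.add_(v, alpha=eps)
(g1,) = torch.autograd.grad(loss(), w)
with torch.no_grad():
    w.sub_(v, alpha=eps)
hv_fd = (g1 - g0) / eps

print(torch.allclose(hv_true, A @ v, atol=1e-4), torch.allclose(hv_fd, A @ v, atol=1e-2))
```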
@@ -748,7 +861,7 @@ class StatefulOptimizer(torch.optim.Optimizer):
  group['is_preconditioning'] = self._is_preconditioning
  self._step(group)
  if self.use_ema:
- self.ema_update(
+ self.ema_update()

  return loss

@@ -758,12 +871,12 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
  copy_stochastic_(t, s)


-
+ @decorator_knowngood
+ def _lerp(state: List[Tensor], grad: List[Tensor], beta):
  ea32 = list(map(promote, state))
  grad = list(map(promote, grad))
  beta = promote(beta)
-
- ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
+ ea32 = [e * beta + g * (1 - beta) for e, g in zip(ea32, grad)]
  copy_stochastic_list_(state, ea32)
  return ea32

@@ -775,15 +888,14 @@ def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: Lis
  beta2 = beta_debias(beta2, step)

  g32 = list(map(promote, grad))
-
-
-
- u32 = torch._foreach_div(exp_avg32, denom)
+ exp_avg32 = _lerp(exp_avg, g32, beta1)
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, g32, beta2, eps, [None])
+ u32 = [ea / d for ea, d in zip(exp_avg32, denom)]
  copy_stochastic_list_(grad, u32)


  def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
- eps: float):
+ eps: float = 1e-8):
  exp_avg, exp_avg_sq, grad = map(list_guard, (exp_avg, exp_avg_sq, grad))
  beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
  _compilable_adam_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
@@ -798,9 +910,8 @@ def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq:
  beta2 = beta_debias(beta2, step)

  u32, g32 = [list(map(promote, x)) for x in [update, grad]]
-
-
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+ exp_avg32 = _lerp(exp_avg, u32, beta1)
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
  u32 = torch._foreach_div(exp_avg32, denom)
  _compilable_update_(y, u32, decay, lr, caution, g32)

@@ -810,51 +921,50 @@ def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor]
  caution: bool):
  y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
  beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
-
+ _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)


  @decorator_knowngood
  def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor,
- beta2: Tensor, step: Tensor):
+ beta2: Tensor, step: Tensor, eps: Tensor):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

  gp32 = list(map(promote, grad))
-
- denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, 1e-8)
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, gp32, beta2, eps, [None])
  gp32 = torch._foreach_div(gp32, denom)
- gp32 =
-
+ gp32 = _lerp(exp_avg, gp32, beta1)
  copy_stochastic_list_(grad, gp32)


- def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int
+ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int,
+ eps: float = 1e-8):
  exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
- beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
- _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step)
+ beta1, beta2, step, eps = scalar_guard(beta1, beta2, step, eps, exp_avg[0])
+ _compilable_laprop_(exp_avg, exp_avg_sq, grad, beta1, beta2, step, eps)
  return grad


  @decorator_knowngood
  def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
  grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
- caution: bool):
+ caution: bool, eps: Tensor):
  beta1 = beta_debias(beta1, step)
  beta2 = beta_debias(beta2, step)

  u32, gp32 = [list(map(promote, x)) for x in [update, grad]]
-
- denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+ denom = _compilable_exp_avg_sq_(exp_avg_sq, u32, beta2, eps, [None])
  u32 = torch._foreach_div(u32, denom)
- u32 =
+ u32 = _lerp(exp_avg, u32, beta1)
  _compilable_update_(y, u32, decay, lr, caution, gp32)


  def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
- grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool
+ grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool,
+ eps: float = 1e-8):
  exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
- beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
- _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)
+ beta1, beta2, step, lr, eps = scalar_guard(beta1, beta2, step, lr, eps, exp_avg[0])
+ _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution, eps)


  @decorator_knowngood
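`_compilable_adam_` and `_compilable_laprop_` above differ mainly in where the first-moment EMA sits: Adam averages the raw gradient and then divides by the second-moment denominator, while LaProp normalizes first and averages the normalized update. A compact single-tensor sketch of that ordering, with bias correction and stochastic rounding omitted:

```python
import torch

def adam_step(m, v, g, beta1=0.9, beta2=0.99, eps=1e-8):
    m.lerp_(g, 1 - beta1)               # EMA of the raw gradient
    v.lerp_(g * g, 1 - beta2)
    return m / v.sqrt().clamp(min=eps)  # normalize the averaged gradient

def laprop_step(m, v, g, beta1=0.9, beta2=0.99, eps=1e-8):
    v.lerp_(g * g, 1 - beta2)
    u = g / v.sqrt().clamp(min=eps)     # normalize first ...
    m.lerp_(u, 1 - beta1)               # ... then average the normalized update
    return m.clone()
```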
@@ -891,7 +1001,7 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
  copy_stochastic_list_(exp_avg, exp_avg32)

  beta2 = beta_debias(beta2, step + 1)
- exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32,
+ exp_avg_sq32 = [eas32.lerp(g * g, 1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
  copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)

  copy_stochastic_list_(grad, update)
@@ -970,12 +1080,15 @@ def get_soap_precond_schedule(precond_scheduler):
  return _inner


+ def _max_idx(x: List[int]):
+ return len(x) - 1 - np.argmax(x[::-1])  # we want to start counting from the back, as torch is fan-out/fan-in
+
+
  def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype=None):
  """For a scalar or tensor t, we initialize its preconditioner Q and
  reusable einsum expressions for updating Q and preconditioning gradient.
  """
  letters = string.ascii_lowercase + string.ascii_uppercase
-
  dtype = dtype if dtype is not None else t.dtype
  shape = t.shape

@@ -992,17 +1105,20 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp

  scale = scale ** (1 / len(shape))

+ dim_diag = [False for _ in shape]
  if memory_save_mode is None:
-
+ pass
  elif memory_save_mode == "one_diag":
-
-
-
+ dim_diag[_max_idx(shape)] = True
+ elif memory_save_mode == "smart_one_diag":
+ sorted_shape = sorted(shape)
+ if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
+ dim_diag[_max_idx(shape)] = True
  elif memory_save_mode == "all_diag":
  dim_diag = [True for _ in shape]
  else:
  raise ValueError(f"Invalid memory_save_mode: {memory_save_mode}, must be one of "
- "[None, 'one_diag', 'all_diag']")
+ "[None, 'one_diag', 'all_diag', 'smart_one_diag']")

  Q = []
  piece1A, piece2A, piece3A = ([], "", "")
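The new `smart_one_diag` mode above only demotes a dimension to a diagonal preconditioner when there is a strictly largest dimension, and `_max_idx` picks the last occurrence of the maximum so ties resolve toward the fan-in side. A small reimplementation of just that decision, with assumed example shapes:

```python
import numpy as np

def max_idx(x):
    # Last index of the maximum, counting from the back (cf. _max_idx above).
    return len(x) - 1 - int(np.argmax(x[::-1]))

def smart_one_diag(shape):
    dim_diag = [False] * len(shape)
    sorted_shape = sorted(shape)
    if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
        dim_diag[max_idx(shape)] = True
    return dim_diag

print(smart_one_diag((50304, 768)))  # [True, False]   huge vocab dim goes diagonal
print(smart_one_diag((1024, 1024)))  # [False, False]  tie, keep both triangular
```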
@@ -1016,11 +1132,9 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
  piece1A.append(letters[i])
  piece2A = piece2A + letters[i]
  piece3A = piece3A + letters[i]
-
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
  subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
  exprGs.append(subscripts)
-
  piece1P.append(letters[i + 13])
  piece2P.append(letters[i + 13])
  piece3P = piece3P + letters[i + 13]
@@ -1028,16 +1142,13 @@ def init_Q_exprs(t, scale, max_size, min_ndim_triangular, memory_save_mode, dtyp
  else:
  # use triangular matrix as preconditioner for this dim
  Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
-
  piece1A.append(letters[i] + letters[i + 13])
  piece2A = piece2A + letters[i + 13]
  piece3A = piece3A + letters[i]
-
  piece1 = "".join([(letters[i + 13] if j == i else letters[j]) for j in range(len(shape))])
  piece2 = "".join([(letters[i + 26] if j == i else letters[j]) for j in range(len(shape))])
  subscripts = (piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26])
  exprGs.append(subscripts)
-
  a, b, c = (letters[i], letters[i + 13], letters[i + 26])
  piece1P.append(a + b)
  piece2P.append(a + c)
@@ -1058,7 +1169,7 @@ def psgd_balance_Q(Q_in):


  def psgd_calc_A_and_conjB(exprA, G, Q, V=None):
- eps = scalar_guard(math.sqrt(torch.finfo(
+ eps = scalar_guard(math.sqrt(torch.finfo(G.dtype).eps), G)
  eps *= G.norm() / G.numel()
  G = G + torch.randn_like(G) * eps
  md = min_dtype(Q + [G])
@@ -1084,9 +1195,7 @@ def psgd_lb(A, max_abs):
  A /= max_abs
  a0 = torch.einsum('ij,ij->j', A, A)
  i = torch.argmax(a0)
-
  x = torch.index_select(A, 1, i).flatten().contiguous()
-
  x = torch.einsum('i,ij->j', x, A)
  x /= x.norm()
  x = torch.einsum('j,kj->k', x, A)
@@ -1099,15 +1208,12 @@ def psgd_lb(A, max_abs):
  def psgd_update_precond(Q, exprs, G, precond_lr, oq, store_triu_as_line, V):
  """Update Kronecker product preconditioner Q with pair (V, G)."""
  exprA, exprGs, _ = exprs
-
  A, conjB = psgd_calc_A_and_conjB(exprA, G, Q, V)

  for q, exprG, o in zip(Q, exprGs, oq):
  term1 = promote(torch.einsum(exprG, A, A))
  term2 = promote(torch.einsum(exprG, conjB, conjB))
-
  term1, term2 = term1 - term2, term1 + term2
-
  term1 *= precond_lr
  norm = term2.norm(float('inf'))
  if q.dim() < 2:
@@ -1221,6 +1327,48 @@ def identity(x):
  return x


+ @decorator_knowngood
+ def _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+ ema32 = _lerp(ema, p, ema_decay)
+ _lerp(p, ema32, 1 - weight_decay)
+
+
+ def weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+ p, ema = list_guard(p, ema)
+ ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+ _compilable_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+ @decorator_knowngood
+ def _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+ ema32 = _lerp(ema, p, ema_decay)
+ for p_, e_ in zip(p, ema32):
+ p32 = promote(p_)
+ p32 = p32 + (p32 - e_).sign() * weight_decay
+ copy_stochastic_(p_, p32)
+
+
+ def l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay):
+ p, ema = list_guard(p, ema)
+ ema_decay, weight_decay = scalar_guard(ema_decay, weight_decay, p[0])
+ _compilable_l1_weight_decay_to_ema_(p, ema, ema_decay, weight_decay)
+
+
+ @decorator_knowngood
+ def _compilable_sign_(grad: List[Tensor], graft: bool):
+ for g_ in grad:
+ gs = g_.sign()
+ if graft:
+ gs = _compilable_grafting(g_, gs)
+ copy_stochastic_(g_, gs)
+
+
+ def sign_(grad: List[Tensor], graft: bool = True):
+ grad = list_guard(grad)
+ _compilable_sign_(grad, graft)
+ return grad
+
+
  @decorator_knowngood
  def _compilable_trust_region_clip_(grad, lerp, scale):
  # (sgn(x) * log(1 + |x|) * 0.1 + tanh(x) * 0.9).clamp_(min=-2, max=2)
@@ -1372,7 +1520,6 @@ def _compilable_orthogonalization(weight: List[Tensor], grad: List[Tensor], eps:

  if graft:
  out = _compilable_grafting(g, out)
-
  copy_stochastic_(g, out)


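`sign_` above takes the sign of each update and, when `graft=True`, rescales it to the original update's norm via `_compilable_grafting`. A one-tensor sketch of that behavior:

```python
import torch

def sign_with_graft(update: torch.Tensor, graft: bool = True) -> torch.Tensor:
    s = update.sign()
    if graft:
        # Keep the direction of sign(update) but the magnitude of the original update.
        s = s * (update.norm() / s.norm().clamp(min=1e-6))
    return s

u = torch.tensor([0.3, -2.0, 0.1, 0.0])
print(sign_with_graft(u))         # signs rescaled so the result's norm matches ||u|| ~= 2.025
print(sign_with_graft(u, False))  # plain sign: tensor([ 1., -1.,  1.,  0.])
```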
{heavyball-1.5.2.dist-info → heavyball-1.6.0.dist-info}/METADATA
CHANGED
@@ -1,9 +1,9 @@
  Metadata-Version: 2.1
  Name: heavyball
- Version: 1.
+ Version: 1.6.0
  Summary: Efficient optimizers
- Home-page: https://github.com/
- Author:
+ Home-page: https://github.com/HomebrewML/HeavyBall
+ Author: HeavyBall Authors
  Author-email: github.heavyball@nestler.sh
  License: BSD
  Classifier: Development Status :: 5 - Production/Stable
@@ -300,7 +300,7 @@ class ForeachSOAP(C.BaseOpt):
  def __init__(self, params, lr: float = 3e-3, betas=(0.9, 0.95), shampoo_beta: float = 0.95, eps: float = 1e-8,
  weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
  merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
-
+ correct_bias: bool = True, warmup_steps: int = 1,
  split: bool = False, foreach: bool = True, mars: bool = False, caution: bool = False,
  mars_gamma: float = 0.0025, palm: bool = C.use_default, precond_scheduler=(1 / 3, 9),
  beta2_scale: float = 0.8, use_precond_schedule: bool = C.use_default,
@@ -324,7 +324,6 @@ the second-order statistics of the gradients to accelerate convergence.
  * **`merge_dims`**: Whether to merge dimensions when forming the preconditioner.
  * **`precondition_1d`**: Whether to use a 1D preconditioner for 1D parameters.
  * **`normalize_grads`**: Whether to normalize gradients before applying SOAP.
- * **`data_format`**: `"channels_first"` or `"channels_last"`. Specifies the data format of the input tensors.
  * **`correct_bias`**: Enables/disables bias correction for the running averages.
  * **`warmup_steps`**: Number of steps for linear learning rate warmup.
  * **`split`**: Whether to split large dimensions when forming the preconditioner.
@@ -931,4 +930,3 @@ tasks. However, the best choice always depends on your specific model, dataset,
  * **`heavyball.utils`:** Remember to utilize the settings and functions in `heavyball.utils` (e.g., `set_torch`,
  `compile_mode`, `zeroth_power_mode`, clipping functions) to optimize performance and experiment with different
  configurations.
-
heavyball-1.6.0.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+ heavyball/__init__.py,sha256=DKp8rEAf7mt2-j9XRVlgjaLjyfuwUsyl_uXJoOKWAHg,15362
+ heavyball/chainable.py,sha256=n_u0QS92WitbtnENvNQ0m4dZTHuJ5ObQ88XA3cmhCfo,27298
+ heavyball/utils.py,sha256=Nk0q_sfv47F-QC9Wwi5KCt-C_71OhuzM98XHlYGvl24,55905
+ heavyball-1.6.0.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
+ heavyball-1.6.0.dist-info/METADATA,sha256=5suezTlZCOBwCgHeFgkLaywYwjAWN1SPg6yhvAv1WgE,43441
+ heavyball-1.6.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ heavyball-1.6.0.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
+ heavyball-1.6.0.dist-info/RECORD,,
heavyball-1.5.2.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
- heavyball/__init__.py,sha256=f0wWIjsibgA4_YwkPP8HFD7-snggYsAOFc84W0WnNMA,12049
- heavyball/chainable.py,sha256=ygeQU-t3RT0Q1BWrEQ_0b4SlXYy8aGDt0DCZAfbiNiw,25040
- heavyball/utils.py,sha256=D7ENwrIex_dgFiUHezymmsIdruoQ4_hYztIolCXo2KE,50636
- heavyball-1.5.2.dist-info/LICENSE,sha256=CGdGJim64YifGmUVPaeyRsxkvyExtClswhRNIp8FY_U,1322
- heavyball-1.5.2.dist-info/METADATA,sha256=n_2fW7Wcz_btxBRWFibTe8wnM10B2su100bJzW0bfZY,43584
- heavyball-1.5.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- heavyball-1.5.2.dist-info/top_level.txt,sha256=SzCxSVg_qCUPA4kZObW3Zyo4v-d_mMOD-p7a-WXTl2E,10
- heavyball-1.5.2.dist-info/RECORD,,
{heavyball-1.5.2.dist-info → heavyball-1.6.0.dist-info}/LICENSE
File without changes
{heavyball-1.5.2.dist-info → heavyball-1.6.0.dist-info}/WHEEL
File without changes
{heavyball-1.5.2.dist-info → heavyball-1.6.0.dist-info}/top_level.txt
File without changes