heavyball 1.1.1__tar.gz → 1.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-1.1.1 → heavyball-1.1.2}/PKG-INFO +1 -1
- {heavyball-1.1.1 → heavyball-1.1.2}/heavyball/__init__.py +40 -35
- {heavyball-1.1.1 → heavyball-1.1.2}/heavyball/chainable.py +17 -19
- {heavyball-1.1.1 → heavyball-1.1.2}/heavyball/utils.py +91 -78
- {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-1.1.1 → heavyball-1.1.2}/setup.py +1 -1
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_channels_last.py +2 -1
- {heavyball-1.1.1 → heavyball-1.1.2}/LICENSE +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/README.md +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/SOURCES.txt +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/setup.cfg +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_bf16_params.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_bf16_q.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_bf16_storage.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_caution.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_closure.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_ema.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_foreach.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_hook.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_mars.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_memory.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_merge.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_no_grad.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_psgd.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_soap.py +0 -0
- {heavyball-1.1.1 → heavyball-1.1.2}/test/test_stochastic_updates.py +0 -0
{heavyball-1.1.1 → heavyball-1.1.2}/heavyball/__init__.py

@@ -10,9 +10,9 @@ class ForeachAdamW(C.BaseOpt):
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adam)
 
 
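Across these optimizer constructors, 1.1.2 replaces hand-maintained `defaults = dict(...)` literals with a capture of the constructor's own keyword arguments via `locals()`, so newly added hyperparameters flow into the param-group defaults automatically. A minimal sketch of the pattern outside heavyball (the `ToyOpt` class and its arguments are illustrative, not part of the package):

```python
import torch

class ToyOpt(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, weight_decay=0.0, caution=False):
        # Capture every keyword argument in one shot instead of listing each twice.
        defaults = locals()              # {'self': ..., 'params': ..., 'lr': ..., ...}
        defaults.pop("self")             # not a hyperparameter
        params = defaults.pop("params")  # positional, passed separately to the base class
        super().__init__(params, defaults)

opt = ToyOpt(torch.nn.Linear(4, 4).parameters(), lr=3e-4)
print(opt.defaults)  # {'lr': 0.0003, 'weight_decay': 0.0, 'caution': False}
```

Because `locals()` is taken as the first statement of `__init__`, it contains exactly the constructor arguments, which is what the hunks below rely on as well.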
@@ -25,9 +25,9 @@ class ForeachRMSprop(C.BaseOpt):
                  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq)
 
 
@@ -36,10 +36,9 @@ class ForeachSFAdamW(C.ScheduleFree):
                  weight_lr_power=2.0, foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False,
                  caution: bool = False, mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
-                        beta2_scale=beta2_scale)
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_exp_avg_sq,
                          C.update_by_schedule_free)
 
@@ -53,9 +52,9 @@ class ForeachADOPT(C.BaseOpt):
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_adopt)
 
 
@@ -65,9 +64,9 @@ class ForeachMuon(C.BaseOpt):
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8,
                  nesterov: bool = True):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm,
                          C.nesterov_momentum if nesterov else C.heavyball_momentum, C.orthogonalize_update)
 
@@ -77,12 +76,24 @@ class ForeachLaProp(C.BaseOpt):
                  foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
                  mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
                  update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
-        defaults =
-
-
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.update_by_laprop)
 
 
+class MuonLaProp(C.BaseOpt):
+    def __init__(self, params, lr=0.0025, betas=(0.9, 0.99), eps=1e-8, weight_decay=0, warmup_steps=0,
+                 foreach: bool = True, storage_dtype: str = 'float32', mars: bool = False, caution: bool = False,
+                 mars_gamma: float = 0.0025, gradient_clipping: C.str_or_fn = C.use_default,
+                 update_clipping: C.str_or_fn = C.use_default, palm: bool = C.use_default, beta2_scale: float = 0.8):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+        super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, palm, C.scale_by_laprop,
+                         C.orthogonalize_update)
+
+
 class ForeachSOAP(C.BaseOpt):
     """
     ForeachSOAP
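The hunk above also introduces `MuonLaProp`, which chains `C.scale_by_laprop` with `C.orthogonalize_update`: a LaProp-style second-moment rescaling followed by Muon's orthogonalized update. It is added to `__all__` in a later hunk, so usage presumably follows the usual `torch.optim` pattern; the training loop below is an illustrative sketch, not taken from the package:

```python
import torch
import heavyball

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
opt = heavyball.MuonLaProp(model.parameters(), lr=1e-3, weight_decay=1e-4)

for _ in range(10):
    x, y = torch.randn(32, 64), torch.randint(0, 10, (32,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```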
@@ -112,12 +123,10 @@ class ForeachSOAP(C.BaseOpt):
                  gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default):
         use_precond_schedule = C.default(use_precond_schedule, self.use_precond_schedule)
 
-        defaults =
-
-
-
-                    'caution': caution, 'mars_gamma': mars_gamma, 'palm': palm, 'precond_scheduler': precond_scheduler,
-                    'beta2_scale': beta2_scale}
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+
         if use_precond_schedule:
             del defaults['precondition_frequency']
         else:
@@ -161,19 +170,15 @@ class ForeachPSGDKron(C.BaseOpt):
                  gradient_clipping: C.str_or_fn = C.use_default, update_clipping: C.str_or_fn = C.use_default,  #
                  # expert parameters
                  precond_init_scale=1.0, precond_lr=0.1):
+        defaults = locals()
+        defaults.pop("self")
+        params = defaults.pop("params")
+
         delayed = C.default(delayed, self.delayed)
         cached = C.default(cached, self.cached)
         exp_avg_input = C.default(exp_avg_input, self.exp_avg_input)
         update_clipping = C.default(update_clipping, utils.trust_region_clip_)
 
-        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
-                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
-                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-                        precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-                        split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
-                        storage_dtype=storage_dtype, caution=caution, mars_gamma=mars_gamma, mars=mars,
-                        stochastic_schedule=stochastic_schedule)
-
         super().__init__(params, defaults, foreach, gradient_clipping, update_clipping, False,  #
                          *(C.exp_avg,) * exp_avg_input,  #
                          functools.partial(C.scale_by_delayed_psgd if delayed else C.scale_by_psgd, cached=cached,
@@ -215,9 +220,9 @@ CachedPSGDKron = ForeachCachedPSGDKron
 CachedDelayedPSGDKron = ForeachCachedDelayedPSGDKron
 Muon = ForeachMuon
 
-__all__ = ["Muon","RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
+__all__ = ["Muon", "RMSprop", "PrecondSchedulePaLMSOAP", "PSGDKron", "PurePSGD", "DelayedPSGD", "CachedPSGDKron",
            "CachedDelayedPSGDKron", "PalmForEachSoap", "PaLMSOAP", "PaLMSFAdamW", "LaProp", "ADOPT",
-           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop',
+           "PrecondScheduleSOAP", "PrecondSchedulePaLMSOAP", 'RMSprop', 'MuonLaProp',  #
           "ForeachAdamW", "ForeachSFAdamW", "ForeachLaProp", "ForeachADOPT", "ForeachSOAP", "ForeachPSGDKron",
           "ForeachPurePSGD", "ForeachDelayedPSGD", "ForeachCachedPSGDKron", "ForeachCachedDelayedPSGDKron",
           "ForeachRMSprop", "ForeachMuon"]
{heavyball-1.1.1 → heavyball-1.1.2}/heavyball/chainable.py

@@ -160,22 +160,21 @@ def scale_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_adam(group, update, grad, param, exp_avg, exp_avg_sq):
-    utils.fused_adam_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group),
-                      group['lr'], group['eps'], group['weight_decay'], group['caution'])
+    utils.fused_adam_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
+                      group['step'], group['lr'], group['eps'], group['weight_decay'], group['caution'])
     raise SkipUpdate
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def scale_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
-    return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'],
-                         group['eps'])
+    return utils.laprop_(exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group), group['step'])
 
 
 @zero_guard("exp_avg", "exp_avg_sq")
 @no_state
 def update_by_laprop(group, update, grad, param, exp_avg, exp_avg_sq):
-    utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, utils.get_beta1(group), utils.get_beta2(group),
+    utils.fused_laprop_(param, exp_avg, exp_avg_sq, update, grad, utils.get_beta1(group), utils.get_beta2(group),
                         group['step'], group['lr'], group['weight_decay'], group['caution'])
     raise SkipUpdate
 
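The hunk above threads two extra arguments through the fused paths: the raw `grad` (kept separate from the possibly transformed `update`, so cautioning can compare against it) and the `step` counter for bias correction. It also drops the unused `eps` argument from `laprop_`. For orientation, the ordering difference between the two update rules touched here, with bias correction (`beta_debias`) and mixed-precision handling omitted; the helper names are illustrative, not heavyball's API:

```python
import torch

def adam_update(exp_avg, exp_avg_sq, grad, beta1=0.9, beta2=0.99, eps=1e-8):
    # Adam: average the gradient first, then divide by the RMS denominator.
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    return exp_avg / exp_avg_sq.sqrt().clamp(min=eps)

def laprop_update(exp_avg, exp_avg_sq, grad, beta1=0.9, beta2=0.99, eps=1e-8):
    # LaProp: divide the gradient by the RMS denominator first, then average.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    exp_avg.lerp_(grad / exp_avg_sq.sqrt().clamp(min=eps), 1 - beta1)
    return exp_avg.clone()
```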
@@ -203,7 +202,7 @@ def update_by_adopt(group, update, grad, param, exp_avg, exp_avg_sq):
         utils.exp_avg_sq_(exp_avg_sq, update, utils.beta_debias(utils.get_beta2(group), group['step']), 1)
         raise SkipUpdate
 
-    utils.fused_adopt_(param, update, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
+    utils.fused_adopt_(param, update, grad, exp_avg_sq, exp_avg, utils.get_beta1(group), utils.get_beta2(group),
                        group['step'] - 2, group['lr'], group['eps'], group['weight_decay'], group['caution'])
     raise SkipUpdate
 
@@ -262,13 +261,13 @@ def precond_schedule(group, prob: Union[callable, float, None] = None, name: str
 
 
 @no_state_no_foreach
-def orthogonalize_update(group, update, grad, param):
+def orthogonalize_update(group, update, grad, param, scale_mode: str = "scale"):  # explore scale_mode="graft"
     if update.dim() == 1:
         return update
     original_shape = update.shape
     # doing it this way, as tmp and update are not guaranteed to share memory address or layout
     tmp = update.flatten(1, -1)
-    utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp)
+    utils.inplace_orthogonal_(tmp, utils.zeroth_power_mode, tmp, scale_mode)
     return tmp.reshape(original_shape)
 
 
@@ -331,7 +330,7 @@ def _update_psgd_cache(cached, Q_cache, q):
 
 def _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache):
     if cached:
-        return utils.precond_grad_cached_(cache_expr, update, *
+        return utils.precond_grad_cached_(cache_expr, update, *Q_cache)
     return utils.psgd_precond_grad(exprs[-1], update, *Q_mat)
 
 
@@ -352,7 +351,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
     update = update.to(memory_format=torch.contiguous_format)
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
-    out = _cached_psgd_precond_grad(
+    out = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
     return torch.as_strided(out, old.shape, old.stride())
 
 
@@ -361,7 +360,7 @@ def scale_by_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str
 def scale_by_delayed_psgd(group, update, grad, param, Q, exprs, Q_cache, cache_expr: str, cached: bool = False,
                           prob: Optional[callable] = None):
     Q_mat = utils.line_to_triu(Q) if group['store_triu_as_line'] else Q
-    precond = _cached_psgd_precond_grad(
+    precond = _cached_psgd_precond_grad(cached, cache_expr, exprs, update, Q_mat, Q_cache)
     _update_psgd_precond(group, param, update, Q_mat, Q, exprs, prob)
     return precond
 
@@ -418,7 +417,6 @@ def chain(state: Union[callable, dict], group, grad, param, *fns):
 class ChainOpt(utils.StatefulOptimizer):
     def __init__(self, params, defaults, foreach: bool, *fns):
         super().__init__(params, defaults, foreach)
-
         self.fns = tuple(fns)
 
     def _step(self, group):
@@ -473,9 +471,10 @@ class BaseOpt(ChainOpt):
     update_clipping: str_or_fn = None
     palm: bool = False
     auto_fuse: bool = True
+    compile_step: bool = False
 
     def __init__(self, params, defaults, foreach: bool, gradient_clipping: str_or_fn, update_clipping: str_or_fn,
-                 palm: bool =
+                 palm: bool = use_default, *fns):
         if default(update_clipping, self.update_clipping) is None:
             if fns and self.auto_fuse:
                 args, kwargs = None, None
@@ -490,6 +489,7 @@ class BaseOpt(ChainOpt):
         else:
             if any(fn in (update_by_adopt, update_by_adam, update_by_laprop, update_by_schedule_free) for fn in fns):
                 raise ValueError("`update_by` functions do not support update clipping. Use `scale_by`")
+
         fns = tuple(fns)
 
         if default(palm, self.palm):
@@ -505,9 +505,9 @@ class BaseOpt(ChainOpt):
 class ScheduleFree(BaseOpt):
     def eval(self):
         for group in self.param_groups:
-            train_mode = group
+            group['train_mode'] = train_mode = not group.get('train_mode')
             beta1 = utils.get_beta1(group)
-            if beta1 > 0 and train_mode:
+            if beta1 > 0 and not train_mode:
                 for p in group['params']:
                     state = self.state_(p)
                     if 'z' in state:
@@ -516,13 +516,12 @@ class ScheduleFree(BaseOpt):
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - 1 / beta1)
                         utils.copy_stochastic_(p.data, p32)
-            group['train_mode'] = False
 
     def train(self):
         for group in self.param_groups:
-            train_mode = group
+            group['train_mode'] = train_mode = not group.get('train_mode')
             beta1 = utils.get_beta1(group)
-            if beta1 > 0 and
+            if beta1 > 0 and train_mode:
                 for p in group['params']:
                     state = self.state_(p)
                     if 'z' in state:
@@ -530,4 +529,3 @@ class ScheduleFree(BaseOpt):
                         p32 = utils.promote(p.data)
                         p32.lerp_(end=z, weight=1 - beta1)
                         utils.copy_stochastic_(p.data, p32)
-            group['train_mode'] = True
{heavyball-1.1.1 → heavyball-1.1.2}/heavyball/utils.py

@@ -163,7 +163,7 @@ def beta_debias(beta, step):
 def _compilable_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor,
                             out: List[Optional[Tensor]]):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    torch.
+    s32 = torch._foreach_mul(s32, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
     [d.clamp_(min=eps) for d in denom]
@@ -185,11 +185,11 @@ def exp_avg_sq_(state, grad, beta2, eps, out=None):
 @decorator_knowngood
 def _compilable_scale_by_exp_avg_sq_(state: List[Tensor], grad: List[Tensor], beta2: Tensor, eps: Tensor):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    torch.
+    s32 = torch._foreach_mul(s32, beta2)
     [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
     denom = torch._foreach_sqrt(s32)
     [d.clamp_(min=eps) for d in denom]
-    out = torch.
+    out = torch._foreach_div(g32, denom)
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, out)
 
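These helpers follow heavyball's mixed-precision pattern: optimizer state may be stored in a lower precision, so it is promoted to float32, updated with out-of-place `torch._foreach_*` ops, and written back via `copy_stochastic_list_` (stochastic rounding). A self-contained sketch of the same flow; the `ema_sq_step` helper is illustrative and uses a plain `copy_` instead of heavyball's stochastic copy:

```python
from typing import List
import torch

def ema_sq_step(state_bf16: List[torch.Tensor], grad: List[torch.Tensor], beta2: float) -> List[torch.Tensor]:
    # Promote low-precision state to fp32 for the arithmetic.
    s32 = [s.float() for s in state_bf16]
    g32 = [g.float() for g in grad]
    # Out-of-place foreach ops, as in the hunks above.
    s32 = torch._foreach_mul(s32, beta2)
    [s.addcmul_(g, g, value=1 - beta2) for s, g in zip(s32, g32)]
    # heavyball writes back with copy_stochastic_list_; a plain copy keeps the sketch dependency-free.
    for dst, src in zip(state_bf16, s32):
        dst.copy_(src)
    return s32

state = [torch.zeros(4, dtype=torch.bfloat16)]
ema_sq_step(state, [torch.ones(4)], beta2=0.99)
```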
@@ -204,7 +204,7 @@ def scale_by_exp_avg_sq_(exp_avg_sq, grad, beta2, eps):
 @decorator_knowngood
 def _compilable_exp_avg_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    [s.
+    s32 = [s.lerp(g, beta) for s, g in zip(s32, g32)]
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, s32)
 
@@ -225,7 +225,7 @@ def _compilable_agc_(parameters: List[Tensor], gradients: List[Tensor], clip_val
     torch._foreach_div_(p_norm, g_norm)
     torch._foreach_mul_(p_norm, clip_val)
     torch._foreach_minimum_(p_norm, 1)
-
+    torch._foreach_mul_(gradients, p_norm)
 
 
 def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor], clip_val: float,
@@ -234,7 +234,8 @@ def adaptive_gradient_clipping_(parameters: List[Tensor], gradients: List[Tensor
         return gradients
     parameters, gradients = list_guard(parameters, gradients)
     clip_val = scalar_guard(clip_val, parameters[0])
-
+    _compilable_agc_(parameters, gradients, clip_val, minimum, eps)
+    return gradients
 
 
 def is_compiling():
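The adaptive-gradient-clipping fix above makes `_compilable_agc_` actually scale the gradients and makes `adaptive_gradient_clipping_` invoke it and return the clipped list. The underlying rule multiplies each gradient by min(1, clip_val · ||p|| / ||g||). A per-tensor sketch of that rule (the `agc` helper and its `eps` handling are simplified, not the package's unit-wise implementation):

```python
import torch

def agc(param: torch.Tensor, grad: torch.Tensor, clip_val: float, eps: float = 1e-3) -> torch.Tensor:
    # Cap the gradient norm at clip_val times the parameter norm.
    p_norm = param.norm().clamp(min=eps)
    g_norm = grad.norm().clamp(min=1e-6)
    factor = (clip_val * p_norm / g_norm).clamp(max=1.0)
    return grad * factor

clipped = agc(torch.randn(128, 128), 100 * torch.randn(128, 128), clip_val=0.01)
```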
@@ -304,7 +305,7 @@ def ortho(x):
 @decorator_knowngood
 def _compilable_heavyball_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    torch.
+    s32 = torch._foreach_mul(s32, beta)
     torch._foreach_add_(s32, g32)
     copy_stochastic_list_(state, s32)
     copy_stochastic_list_(grad, s32)
@@ -313,7 +314,7 @@ def _compilable_heavyball_momentum_(state, grad, beta):
 @decorator_knowngood
 def _compilable_nesterov_momentum_(state, grad, beta):
     s32, g32 = [list(map(promote, x)) for x in (state, grad)]
-    torch.
+    s32 = torch._foreach_mul(s32, beta)
     torch._foreach_add_(s32, g32)
     [g.add_(s, alpha=beta) for g, s in zip(g32, s32)]
     copy_stochastic_list_(state, s32)
@@ -334,17 +335,27 @@ def nesterov_momentum(state, grad, beta):
     return grad
 
 
+# mode in ("newtonschulz", "qr", "svd")
+# scale_mode in ("none", "scale", "graft")
 @decorator_knowngood
-def inplace_orthogonal_(x, mode, out):
+def inplace_orthogonal_(x: Tensor, mode: str, out: Tensor, scale_mode: str):
     if mode == 'newtonschulz' or x.shape[0] != x.shape[1]:
         y = zeropower_via_newtonschulz5(x, 5)
     elif mode == 'qr':
-        y = torch.linalg.qr(x).Q
+        y = torch.linalg.qr(promote(x)).Q
     elif mode == 'svd':
-        u, s, v = torch.linalg.svd(x)
+        u, s, v = torch.linalg.svd(promote(x))
         y = u @ v.T
     else:
         raise NotImplementedError(f"Unknown zeroth_power_mode: {mode}")
+    if scale_mode == "none":
+        pass
+    elif scale_mode == "scale":
+        y *= max(1, x.size(0) / x.size(1)) ** 0.5
+    elif scale_mode == "graft":
+        y *= x.norm() / y.norm().clamp_(min=1e-6)
+    else:
+        raise NotImplementedError(f"Unknown scale_mode: {scale_mode}")
     set_(out, y)
 
 
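`inplace_orthogonal_` now takes an explicit `scale_mode`: `'none'` keeps the raw orthogonal factor, `'scale'` multiplies by `max(1, rows / cols) ** 0.5` (the scaling Muon uses for tall matrices), and `'graft'` transfers the original update's norm onto the orthogonalized direction. A small standalone sketch of the three modes on one matrix, using QR rather than heavyball's Newton-Schulz default (the `orthogonalize` helper is illustrative):

```python
import torch

def orthogonalize(x: torch.Tensor, scale_mode: str) -> torch.Tensor:
    y = torch.linalg.qr(x).Q                      # orthogonal factor of the update
    if scale_mode == "none":
        return y
    if scale_mode == "scale":
        return y * max(1, x.size(0) / x.size(1)) ** 0.5
    if scale_mode == "graft":
        return y * (x.norm() / y.norm().clamp(min=1e-6))
    raise NotImplementedError(scale_mode)

x = torch.randn(256, 64)
for mode in ("none", "scale", "graft"):
    print(mode, orthogonalize(x, mode).norm().item())
```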
@@ -378,7 +389,7 @@ def get_orthogonal_matrix_QR(GG, Q, exp_avg_sq):
         est_eig = torch.einsum('ij,ij->j', o, tmp)
         sort_idx = torch.argsort(est_eig, descending=True)
         indices.append(sort_idx)
-        inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q)
+        inplace_orthogonal_(tmp[:, sort_idx], zeroth_power_mode, q, "none")
 
     indices = tuple(slice(None) if ind is None else ind.view(*(1,) * i, -1, *(1,) * (exp_avg_sq.dim() - i - 1))  #
                     for i, ind in enumerate(indices))
@@ -437,8 +448,7 @@ def _compilable_stochastic_lerp_(x: List[Tensor], y: List[Tensor], a: Union[floa
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
-        x32.
-        copy_stochastic_(x_, x32)
+        copy_stochastic_(x_, x32.lerp(y32, a))
 
 
 def get_beta1(group):
@@ -499,7 +509,7 @@ def _compilable_stochastic_add_(x: List[Tensor], y: List[Tensor], alpha: Union[f
     for x_, y_ in zip(x, y):
         x32 = promote(x_)
         y32 = promote(y_)
-        x32.add_(y32, alpha=alpha)
+        x32.add_(y32, alpha=alpha)  # can't use out-of-place here; torch.compile doesn't handle data-dependent inputs
         copy_stochastic_(x_, x32)
 
 
@@ -521,7 +531,7 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
         g0 = einsum_base[:grad.dim()]
         g1 = g0.replace(b, b.upper())
         outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
-        GG[idx].lerp_(
+        GG[idx].lerp_(outer_product, 1 - beta)
 
 
 def promote(x):
@@ -586,7 +596,8 @@ def project(grad, Q, back: bool):
     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
     if preconditioners:
         out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
-
+        out = torch.einsum(f'{param},{preconditioners}->{out}', promote(grad), *[q for q in Q if len(q) > 0])
+        grad = out.to(grad.dtype)
     return grad
 
 
@@ -739,20 +750,26 @@ def copy_stochastic_list_(target: List[Tensor], source: List[Tensor]):
         copy_stochastic_(t, s)
 
 
+def _lerp32(state: List[Tensor], grad: List[Tensor], beta):
+    ea32 = list(map(promote, state))
+    grad = list(map(promote, grad))
+
+    ea32 = [e.lerp(g, 1 - beta) for e, g in zip(ea32, grad)]
+    copy_stochastic_list_(state, ea32)
+    return ea32
+
+
 @decorator_knowngood
 def _compilable_adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: Tensor, beta2: Tensor,
                       step: Tensor):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-    g32
+    g32 = list(map(promote, grad))
 
-
-    denom = exp_avg_sq_(
+    exp_avg32 = _lerp32(exp_avg, g32, beta1)
+    denom = exp_avg_sq_(exp_avg_sq, g32, beta2, 1e-8)
     u32 = torch._foreach_div(exp_avg32, denom)
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
     copy_stochastic_list_(grad, u32)
 
 
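The new `_lerp32` helper centralizes the "promote to fp32, EMA via lerp, stochastic copy-back" step used by the Adam and LaProp paths above. The lerp form is the usual exponential moving average: `state.lerp(g, 1 - beta)` equals `beta * state + (1 - beta) * g`. A quick check of that identity with plain tensors (no stochastic rounding involved):

```python
import torch

state, g, beta = torch.randn(5), torch.randn(5), 0.9
ema_lerp = state.lerp(g, 1 - beta)
ema_classic = beta * state + (1 - beta) * g
assert torch.allclose(ema_lerp, ema_classic)
```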
@@ -764,28 +781,26 @@ def adam_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], b
 
 
 @decorator_knowngood
-def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-                            beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor,
-                            caution: bool):
+def _fused_compilable_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+                            grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, decay: Tensor, lr: Tensor,
+                            eps: Tensor, caution: bool):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-
+    u32, g32 = [list(map(promote, x)) for x in [update, grad]]
 
-
-    denom = exp_avg_sq_(
+    exp_avg32 = _lerp32(exp_avg, u32, beta1)
+    denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
     u32 = torch._foreach_div(exp_avg32, denom)
+    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)
 
-    copy_stochastic_list_(exp_avg, exp_avg32)
-    copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
-    _compilable_update_(y, u32, decay, lambda a, b, c: a.add_(b, alpha=c), lr, caution, g32)
 
-
-
-
+def fused_adam_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+                grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, eps: float, decay: float,
+                caution: bool):
     y, exp_avg, exp_avg_sq, grad = list_guard(y, exp_avg, exp_avg_sq, grad)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, y[0])
-    return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, decay, lr, eps, caution)
+    return _fused_compilable_adam_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, decay, lr, eps, caution)
 
 
 @decorator_knowngood
@@ -794,14 +809,13 @@ def _compilable_laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: L
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-    gp32
+    gp32 = list(map(promote, grad))
 
-    denom = exp_avg_sq_(
+    denom = exp_avg_sq_(exp_avg_sq, gp32, beta2, 1e-8)
     gp32 = torch._foreach_div(gp32, denom)
-
+    gp32 = _lerp32(exp_avg, gp32, beta1)
 
-    copy_stochastic_list_(
-    copy_stochastic_list_(grad, exp_avg)
+    copy_stochastic_list_(grad, gp32)
 
 
 def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor], beta1: float, beta2: float, step: int):
@@ -812,52 +826,50 @@ def laprop_(exp_avg: List[Tensor], exp_avg_sq: List[Tensor], grad: List[Tensor],
 
 
 @decorator_knowngood
-def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-
-
+def _fused_compilable_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+                              grad: List[Tensor], beta1: Tensor, beta2: Tensor, step: Tensor, lr: Tensor, decay: Tensor,
+                              caution: bool):
     beta1 = beta_debias(beta1, step)
     beta2 = beta_debias(beta2, step)
 
-
-
-    denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
-    gp32 = torch._foreach_div(gp32, denom)
-    stochastic_lerp_(exp_avg, gp32, 1 - beta1)
-    update_param_(y, gp32, lr, decay, caution=caution, grad=gp32)
+    u32, gp32 = [list(map(promote, x)) for x in [update, grad]]
 
-
+    denom = exp_avg_sq_(exp_avg_sq, u32, beta2, 1e-8)
+    u32 = torch._foreach_div(u32, denom)
+    u32 = _lerp32(exp_avg, u32, beta1)
+    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, gp32)
 
 
-def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor],
-                  beta2: float, step: int, lr: float, decay: float, caution: bool):
+def fused_laprop_(y: List[Tensor], exp_avg: List[Tensor], exp_avg_sq: List[Tensor], update: List[Tensor],
+                  grad: List[Tensor], beta1: float, beta2: float, step: int, lr: float, decay: float, caution: bool):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
-    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, grad, beta1, beta2, step, lr, decay, caution)
+    _fused_compilable_laprop_(y, exp_avg, exp_avg_sq, update, grad, beta1, beta2, step, lr, decay, caution)
 
 
 @decorator_knowngood
-def _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
-
-
+def _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
+    u32, g32, exp_avg_sq32, exp_avg32 = [list(map(promote, x)) for x in [update, grad, exp_avg_sq, exp_avg]]
+    _compilable_update_(y, u32, decay, stochastic_add_, lr, caution, g32)
 
     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)
     [denom.clamp_(min=eps) for denom in denom]
-    torch.
-    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32,
+    exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
+    [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, u32, denom)]
+    copy_stochastic_list_(exp_avg, exp_avg32)
 
     beta2 = beta_debias(beta2, step + 1)
-    torch.
-    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32,
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
+    exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
+    [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, u32)]
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
 
 
-
+
+def fused_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution):
     exp_avg, exp_avg_sq, grad, y = list_guard(exp_avg, exp_avg_sq, grad, y)
     beta1, beta2, step, lr = scalar_guard(beta1, beta2, step, lr, exp_avg[0])
-    _fused_compilable_adopt_(y, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
+    _fused_compilable_adopt_(y, update, grad, exp_avg_sq, exp_avg, beta1, beta2, step, lr, eps, decay, caution)
 
 
 @decorator_knowngood
@@ -868,21 +880,21 @@ def _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
     beta1 = beta_debias(beta1, step)
     denom = torch._foreach_sqrt(exp_avg_sq32)
     [denom.clamp_(min=1e-8) for denom in denom]
-    torch.
+    exp_avg32 = torch._foreach_mul(exp_avg32, beta1)
     [ea32.addcdiv_(g, d, value=1 - beta1) for ea32, g, d in zip(exp_avg32, g32, denom)]
+    copy_stochastic_list_(exp_avg, exp_avg32)
 
     beta2 = beta_debias(beta2, step + 1)
-    torch.
+    exp_avg_sq32 = torch._foreach_mul(exp_avg_sq32, beta2)
     [eas32.addcmul_(g, g, value=1 - beta2) for eas32, g in zip(exp_avg_sq32, g32)]
-
-    copy_stochastic_list_(exp_avg, exp_avg32)
     copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+
     copy_stochastic_list_(grad, update)
 
 
 def adopt(grad, exp_avg_sq, exp_avg, beta1, beta2, step):
-    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad
-    beta1, beta2, step = scalar_guard(beta1, beta2, step,
+    exp_avg, exp_avg_sq, grad = list_guard(exp_avg, exp_avg_sq, grad)
+    beta1, beta2, step = scalar_guard(beta1, beta2, step, exp_avg[0])
     _compilable_adopt_(grad, exp_avg_sq, exp_avg, beta1, beta2, step)
     return grad
 
@@ -927,7 +939,7 @@ def _compilable_update_(p: List[Tensor], u: List[Tensor], decay: Tensor, add_fn:
 
     for p32_, u32_, g_ in zip(p32, u32, g):  # lr is data-dependent -> can't compile a foreach
         if caution:
-
+            u32_ = _compilable_cautioning(promote(g_), u32_)
         add_fn(p32_, u32_, lr)
 
     copy_stochastic_list_(p, p32)
@@ -1243,7 +1255,7 @@ def psgd_should_update(group, prob: Union[float, callable], rng: Optional[random
         prob = prob(group[f'{name}_prob_step'])
     if group['stochastic_schedule']:
         return rng.random() < prob
-    cumulative_prob =
+    cumulative_prob = group.get(name, 0)
     group[name] = cumulative_prob + prob
     return int(group[name]) > int(cumulative_prob)
 
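`psgd_should_update` now reads the accumulated probability with `group.get(name, 0)` instead of assuming the key already exists. In the non-stochastic schedule, a preconditioner refresh fires whenever the accumulator crosses an integer boundary, which deterministically yields the requested update frequency. A standalone sketch of that accumulator (not the package function itself):

```python
def should_update(state: dict, prob: float, name: str = "cumulative_prob") -> bool:
    cumulative = state.get(name, 0)      # default to 0 on the first call, as in 1.1.2
    state[name] = cumulative + prob
    return int(state[name]) > int(cumulative)

state = {}
updates = sum(should_update(state, 0.25) for _ in range(100))
print(updates)  # 25 of 100 steps trigger a preconditioner update
```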
@@ -1304,15 +1316,16 @@ def mars_correction(g, old_g, beta1, gamma):
 
 
 @decorator_knowngood
-def
-    mask = (
-    update.
-    scale = mask.numel() / mask.sum().clamp(min=1)
+def _compilable_cautioning(g: Tensor, update: Tensor):
+    mask = g.signbit() ^ update.signbit()  # "Mask if they point in different directions"
+    update = update.masked_fill(mask, 0)
+    scale = mask.numel() / (mask.numel() - mask.sum()).clamp(min=1)
     update.mul_(scale)
+    return update
 
 
 def caution(g, update):
-
+    return _compilable_cautioning(g, update)
 
 
 def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
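The rewritten `_compilable_cautioning` zeros update entries whose sign disagrees with the gradient, rescales the survivors by `numel / kept` so the overall update magnitude is preserved, and now returns the result. A numeric sketch of the same masking rule (standalone helper, not a call into heavyball):

```python
import torch

def cautious(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    mask = grad.signbit() ^ update.signbit()           # True where directions disagree
    update = update.masked_fill(mask, 0)
    kept = (mask.numel() - mask.sum()).clamp(min=1)    # number of surviving entries
    return update * (mask.numel() / kept)

u = torch.tensor([0.5, -0.5, 1.0, -1.0])
g = torch.tensor([1.0, 1.0, -1.0, -1.0])               # signs agree at indices 0 and 3
print(cautious(u, g))                                   # tensor([ 1.,  0.,  0., -2.])
```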
{heavyball-1.1.1 → heavyball-1.1.2}/test/test_channels_last.py

@@ -11,6 +11,7 @@ from heavyball.utils import clean, set_torch
 from torch import nn
 from torch._dynamo import config
 
+heavyball.utils.zeroth_power_mode = 'newtonschulz'
 heavyball.utils.compile_mode = 'default'
 config.cache_size_limit = 128
 
@@ -34,7 +35,7 @@ def test_foreach(opt, size, depth: int, iterations: int = 1024, outer_iterations
     if is_channels_last:
         model.to(memory_format=torch.channels_last)
 
-    o = get_optim(opt, model.parameters(), lr=1e-
+    o = get_optim(opt, model.parameters(), lr=1e-3, weight_decay=1e-4, warmup_steps=16)
 
     for _ in range(iterations):
         loss = model(torch.randn((1024, size, 4, 4), device='cuda')).square().mean()