heavyball 0.19.0__tar.gz → 0.21.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.19.0 → heavyball-0.21.0}/PKG-INFO +2 -2
- {heavyball-0.19.0 → heavyball-0.21.0}/README.md +1 -1
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/cached_delayed_psgd_kron.py +11 -11
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/cached_psgd_kron.py +13 -12
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/delayed_psgd.py +15 -18
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_soap.py +4 -7
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/p_adam.py +9 -9
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/palm_foreach_soap.py +6 -6
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/precond_schedule_foreach_soap.py +6 -10
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/precond_schedule_palm_foreach_soap.py +4 -4
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/precond_schedule_sfpsoap.py +20 -10
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/psgd_kron.py +15 -12
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/pure_psgd.py +3 -6
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/schedule_free_palm_foreach_soap.py +17 -8
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/utils.py +169 -58
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/PKG-INFO +2 -2
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/SOURCES.txt +1 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/setup.py +1 -1
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_bf16_params.py +2 -1
- heavyball-0.21.0/test/test_ema.py +61 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/LICENSE +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/__init__.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/setup.cfg +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_bf16_q.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_bf16_storage.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_closure.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_foreach.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_memory.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_merge.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_no_grad.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_psgd.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_soap.py +0 -0
- {heavyball-0.19.0 → heavyball-0.21.0}/test/test_stochastic_updates.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.19.0
+Version: 0.21.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.

-Currently (2024-11-22, 0.
+Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

 ## Features
README.md
@@ -8,7 +8,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.

-Currently (2024-11-22, 0.
+Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

 ## Features
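For context, the recommended optimizer plugs into a training loop like any `torch.optim` optimizer. A minimal sketch, assuming `PrecondSchedulePaLMSOAP` is exported from the package root under that name (the class shown in this diff is `PrecondSchedulePaLMForeachSOAP`; the README's shorter name is taken to be an alias) and that the constructor accepts the usual `params` and `lr` arguments:

import torch
from torch import nn

import heavyball

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 1))
# Hypothetical usage; check heavyball.__all__ for the exact exported names.
opt = heavyball.PrecondSchedulePaLMSOAP(model.parameters(), lr=1e-3)

for _ in range(10):
    loss = model(torch.randn(64, 128)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()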
heavyball/cached_delayed_psgd_kron.py
@@ -9,7 +9,7 @@ from typing import Optional
 import torch

 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-line_to_triu, triu_to_line,
+line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias


 class ForeachCachedDelayedPSGDKron(PSGDBase):
@@ -41,7 +41,8 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -58,7 +59,7 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

 def _step(self, group):
@@ -74,14 +75,15 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])

 vals = []

-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)

 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -105,7 +107,9 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):

 group["step"] += 1

-
+stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])

 grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
 exp_avg_list)
@@ -127,8 +131,4 @@ class ForeachCachedDelayedPSGDKron(PSGDBase):
 else:
 torch.mul(q_.conj(), q_, out=c_)

-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+update_param_([p], self.clip_fn([new]), lr, weight_decay)
heavyball/cached_psgd_kron.py
@@ -9,7 +9,7 @@ from typing import Optional
 import torch

 from .utils import update_param_, warmup, init_Q_exprs, trust_region_clip_, PSGDBase, split_p_and_g_in_group, \
-line_to_triu, triu_to_line,
+line_to_triu, triu_to_line, einsum_base, promote, stochastic_lerp_, beta_debias


 class ForeachCachedPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachCachedPSGDKron(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -56,7 +57,8 @@ class ForeachCachedPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype,
+storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

 def _step(self, group):
@@ -71,15 +73,16 @@ class ForeachCachedPSGDKron(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])
 should_update = self.should_update(group)

 vals = []

-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)

 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,7 +106,9 @@ class ForeachCachedPSGDKron(PSGDBase):

 group["step"] += 1

-
+stochastic_lerp_(exp_avg_list, grad_list, 1 - beta_debias(beta, group['step']))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])

 grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
 exp_avg_list)
@@ -123,9 +128,5 @@ class ForeachCachedPSGDKron(PSGDBase):
 else:
 torch.mul(q_.conj(), q_, out=c_)

-
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+g = torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea)
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
heavyball/delayed_psgd.py
@@ -5,10 +5,10 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """

 import torch
-from heavyball.utils import copy_stochastic_list_

+from heavyball.utils import stochastic_lerp_, beta_debias
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-
+split_p_and_g_in_group, triu_to_line, line_to_triu, promote


 class ForeachDelayedPSGD(PSGDBase):
@@ -38,8 +38,8 @@ class ForeachDelayedPSGD(PSGDBase):
 def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True,
-
+split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True, foreach: bool = True,
+q_dtype='float32', stochastic_schedule: bool = True, storage_dtype:str='float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -55,12 +55,10 @@ class ForeachDelayedPSGD(PSGDBase):
 defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
-precond_init_scale=precond_init_scale,
-
-store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

-
 def _step(self, group):
 should_update = self.should_update(group)
 momentum_into_precond_update = group.get("momentum_into_precond_update", True)
@@ -74,14 +72,15 @@ class ForeachDelayedPSGD(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])

 vals = []

-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)

 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
@@ -96,7 +95,9 @@ class ForeachDelayedPSGD(PSGDBase):

 group["step"] += 1

-
+stochastic_lerp_(exp_avg_list, grad_list, beta_debias(beta, group["step"]))
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])

 Q_list, exp_avg_list = list(Q_list), list(exp_avg_list)
 for i, (p, g) in enumerate(zip(p_list, grad_list)):
@@ -106,10 +107,6 @@ class ForeachDelayedPSGD(PSGDBase):
 new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
 if should_update:
 q32 = [promote(q_) for q_ in q]
-self.do_update(group,[p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
-
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
+store_triu_as_line)
+update_param_([p], self.clip_fn([new]), lr, weight_decay)
heavyball/foreach_soap.py
@@ -1,7 +1,7 @@
 import torch

 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-split_p_and_g_in_group, StatefulOptimizer
+split_p_and_g_in_group, StatefulOptimizer, exp_avg_


 class ForeachSOAP(StatefulOptimizer):
@@ -26,8 +26,7 @@ class ForeachSOAP(StatefulOptimizer):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-split: bool = False,
-foreach: bool = True):
+split: bool = False, foreach: bool = True):
 defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -65,14 +64,12 @@ class ForeachSOAP(StatefulOptimizer):
 p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
 beta1, beta2 = group["betas"]

-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)

 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
-denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
heavyball/p_adam.py
@@ -39,7 +39,7 @@ class ForeachPaLMPAdam(PSGDBase):
 momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
 store_triu_as_line: bool = True, foreach: bool = True, q_dtype='float32',
-stochastic_schedule: bool = True,
+stochastic_schedule: bool = True, storage_dtype:str ='float32',#
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
@@ -57,7 +57,7 @@ class ForeachPaLMPAdam(PSGDBase):
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, beta=beta,
 beta2_scale=beta2_scale, merge_dims=merge_dims, split=split,
-store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

 def _step(self, group):
@@ -71,15 +71,16 @@ class ForeachPaLMPAdam(PSGDBase):
 lr = group['lr']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])

 vals = []

-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)

 if 'Q' not in state:
-state['exp_avg'] = torch.zeros_like(g)
-state['exp_avg_sq'] = torch.zeros_like(g)
+state['exp_avg'] = torch.zeros_like(g, dtype=storage_dtype)
+state['exp_avg_sq'] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -103,6 +104,8 @@ class ForeachPaLMPAdam(PSGDBase):

 beta2 = 1 - group['step'] ** -group['beta2_scale']

+lr = -warmup(lr, group['step'], group['warmup_steps'])
+
 for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
 psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
 ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
@@ -112,8 +115,5 @@ class ForeachPaLMPAdam(PSGDBase):
 divide by g here, because g == denom (from exp_avg_sq_(out=g)), avoids denom allocation
 divide into g so we can deallocate ea, avoids one allocation (-> less memory than equivalent foreach)
 """
+update_param_([p], self.clip_fn([g]), lr, weight_decay)

-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
heavyball/palm_foreach_soap.py
@@ -1,7 +1,8 @@
 import torch

 from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_sq_, update_param_, set_, \
-split_p_and_g_in_group, StatefulOptimizer
+split_p_and_g_in_group, StatefulOptimizer, exp_avg_
+


 class PaLMForeachSOAP(StatefulOptimizer):
@@ -32,8 +33,7 @@ class PaLMForeachSOAP(StatefulOptimizer):
 max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-beta2_scale: float = 0.8, split: bool = False,
-foreach: bool = True):
+beta2_scale: float = 0.8, split: bool = False, foreach: bool = True):
 if betas[0] is not None:
 beta = betas[0]
 defaults = {"lr": lr, "beta": beta, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
@@ -75,13 +75,13 @@ class PaLMForeachSOAP(StatefulOptimizer):
 beta1 = group["beta"]

 beta2 = 1 - step ** -group['beta2_scale']
-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)

 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
+beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
 state = self.state_(p)
heavyball/precond_schedule_foreach_soap.py
@@ -2,8 +2,8 @@ import random

 import torch

-from .utils import init_preconditioner, update_preconditioner, project, beta_debias,
-precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, update_param_, \
+precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer, exp_avg_


 class PrecondScheduleForeachSOAP(StatefulOptimizer):
@@ -27,8 +27,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1,
-precond_scheduler=(1 / 3, 9), split: bool = False,
-foreach: bool = True):
+precond_scheduler=(1 / 3, 9), split: bool = False, foreach: bool = True):
 defaults = {"lr": lr, "betas": betas, "shampoo_beta": shampoo_beta, "eps": eps, "weight_decay": weight_decay,
 "precondition_frequency": precondition_frequency, "max_precond_dim": max_precond_dim,
 "merge_dims": merge_dims, "precondition_1d": precondition_1d, "normalize_grads": normalize_grads,
@@ -68,14 +67,12 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 p_list, grad, grad_projected, exp_avg, exp_avg_sq = zip(*vals)
 beta1, beta2 = group["betas"]

-old_debiased1 = beta_debias(beta1, step)
 old_debiased2 = beta_debias(beta2, step)

 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-
-denom = exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, group['eps'])
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
@@ -89,8 +86,7 @@ class PrecondScheduleForeachSOAP(StatefulOptimizer):
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
 set_(d, project(exp_avg_projected / d, state['Q'], True))

-update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-update_precond)
+update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)

 # Why does this have to be rebiased here?
 step_size = -group["lr"] * min(step / group['warmup_steps'], 1)
heavyball/precond_schedule_palm_foreach_soap.py
@@ -2,7 +2,7 @@ import random

 import torch

-from .utils import init_preconditioner, update_preconditioner, project, beta_debias,
+from .utils import init_preconditioner, update_preconditioner, project, beta_debias, exp_avg_, update_param_, \
 precond_schedule, set_, split_p_and_g_in_group, StatefulOptimizer


@@ -81,9 +81,9 @@ class PrecondSchedulePaLMForeachSOAP(StatefulOptimizer):

 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-torch.
-torch.
-denom =
+beta2 = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(beta2)
+step_tensor = torch.empty((), dtype=torch.int32, device=p_list[0].device).fill_(step)
+denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step_tensor)

 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)
 for p, g, ea, d in zip(p_list, grad, exp_avg, denom):
heavyball/precond_schedule_sfpsoap.py
@@ -2,8 +2,19 @@ import random

 import torch

-from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
-
+from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+beta_debias, schedule_free_, warmup, ScheduleFree, precond_schedule, split_p_and_g_in_group, copy_stochastic_list_, \
+promote
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+torch._foreach_div_(gp32, denom)
+
+copy_stochastic_list_(exp_avg_sq, eas32)
+copy_stochastic_list_(grad_projected, gp32)


 class PrecondScheduleSFPaLMSOAP(ScheduleFree):
@@ -40,8 +51,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
 weight_decay: float = 0.01, precondition_frequency: int = 2, max_precond_dim: int = 2048, #
 merge_dims: bool = True, precondition_1d: bool = False, normalize_grads: bool = False,
 data_format: str = "channels_first", correct_bias: bool = True, warmup_steps: int = 1, r=0.0,
-weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9),
-
+weight_lr_power=2.0, gradient_clip_val: float = 0.1, precond_scheduler=(1 / 3, 9), betas=(None, None),
+split: bool = False, foreach: bool = True):
 if betas[0] is not None:
 beta = betas[0]
 defaults = {"lr": lr, "beta": beta, "beta2_scale": beta2_scale, "eps": eps, "weight_decay": weight_decay,
@@ -103,8 +114,8 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):

 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-
-
+old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(old_debiased2)
+_compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])

 update_precond = precond_schedule(step, group['precond_scheduler'], self.rng)

@@ -114,13 +125,12 @@ class PrecondScheduleSFPaLMSOAP(ScheduleFree):
 # to the original space
 set_(gp, project(gp, state['Q'], back=True))

-update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2,
-update_precond)
+update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)

 # Weight decay calculated at y
 if group["weight_decay"] > 0:
 torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

 lr = warmup(group['lr'], step, group['warmup_steps'])
-group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-
+group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+z, grad_projected, group['r'], step)
heavyball/psgd_kron.py
@@ -9,7 +9,7 @@ from typing import Optional
 import torch

 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
-split_p_and_g_in_group, line_to_triu, triu_to_line,
+split_p_and_g_in_group, line_to_triu, triu_to_line, promote, stochastic_lerp_, beta_debias


 class ForeachPSGDKron(PSGDBase):
@@ -39,7 +39,8 @@ class ForeachPSGDKron(PSGDBase):
 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True,
-foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+foreach: bool = True, q_dtype='float32', stochastic_schedule: bool = True,
+storage_dtype: str = 'float32', #
 # expert parameters
 precond_init_scale=1.0, precond_lr=0.1):
 if not 0.0 <= lr:
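The new `storage_dtype` argument controls the dtype of the optimizer's own momentum buffer (`exp_avg`), separately from `q_dtype`, which governs the preconditioner. A minimal sketch of passing it, assuming `ForeachPSGDKron` is exported at the package root (the tests' use of `heavyball.__all__` suggests the optimizer classes are):

import torch
from torch import nn

import heavyball

model = nn.Linear(256, 256)
# Keep the momentum buffer in bfloat16 while Q stays in float32; the argument
# names mirror the signature shown in the hunk above.
opt = heavyball.ForeachPSGDKron(model.parameters(), lr=1e-3,
                                q_dtype='float32', storage_dtype='bfloat16')

loss = model(torch.randn(8, 256)).square().mean()
loss.backward()
opt.step()
opt.zero_grad()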
@@ -56,7 +57,7 @@ class ForeachPSGDKron(PSGDBase):
 min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
 momentum_into_precond_update=momentum_into_precond_update, precond_lr=precond_lr,
 precond_init_scale=precond_init_scale, step=0, warmup_steps=warmup_steps, merge_dims=merge_dims,
-split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype)
+split=split, store_triu_as_line=store_triu_as_line, q_dtype=q_dtype, storage_dtype=storage_dtype)
 super().__init__(params, defaults, foreach, stochastic_schedule, clip_fn, preconditioner_update_probability)

 def _step(self, group):
@@ -72,14 +73,15 @@ class ForeachPSGDKron(PSGDBase):
 beta = group['beta']
 store_triu_as_line = group['store_triu_as_line']
 q_dtype = getattr(torch, group['q_dtype'])
+storage_dtype = getattr(torch, group['storage_dtype'])

 vals = []

-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)

 if 'Q' not in state:
-state["exp_avg"] = torch.zeros_like(g)
+state["exp_avg"] = torch.zeros_like(g, dtype=storage_dtype)
 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
 memory_save_mode, dtype=q_dtype)
 state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
@@ -94,9 +96,14 @@ class ForeachPSGDKron(PSGDBase):

 group["step"] += 1

-
+beta = beta_debias(beta, group["step"])
+beta = torch.empty((), dtype=torch.float32, device=grad_list[0].device).fill_(1 - beta)
+stochastic_lerp_(exp_avg_list, grad_list, 1 - beta)

 grad_list, Q_list, exp_avg_list = list(grad_list), list(Q_list), list(exp_avg_list)
+
+lr = -warmup(lr, group['step'], group['warmup_steps'])
+
 for i, (p, g) in enumerate(zip(p_list, grad_list)):
 q_orig = Q_list.pop(0)
 ea = exp_avg_list.pop(0)
@@ -106,9 +113,5 @@ class ForeachPSGDKron(PSGDBase):
 q32 = [promote(q_) for q_ in q]
 self.do_update(group, [p], [ea if momentum_into_precond_update else g], [q32], precond_lr, [q_orig],
 store_triu_as_line)
-
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+g = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
heavyball/pure_psgd.py
@@ -70,7 +70,7 @@ class ForeachPurePSGD(PSGDBase):

 vals = []

-for p, g in split_p_and_g_in_group(group):
+for p, g in split_p_and_g_in_group(group, should_promote=False):
 state = self.state_(p)

 if 'Q' not in state:
@@ -89,6 +89,7 @@ class ForeachPurePSGD(PSGDBase):
 group["step"] += 1

 Q_list = list(Q_list)
+lr = -warmup(lr, group['step'], group['warmup_steps'])
 for i, (p, g) in enumerate(zip(p_list, grad_list)):
 q_orig = Q_list.pop(0)
 q = line_to_triu(q_orig) if store_triu_as_line else q_orig
@@ -97,8 +98,4 @@ class ForeachPurePSGD(PSGDBase):
 q32 = [promote(q_) for q_ in q]
 self.do_update(group, [p], [g], [q32], precond_lr, [q_orig], store_triu_as_line)
 psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
-
-grad_list = self.clip_fn(grad_list)
-
-lr = -warmup(lr, group['step'], group['warmup_steps'])
-update_param_(p_list, grad_list, lr, weight_decay)
+update_param_([p], self.clip_fn([g]), lr, weight_decay)
heavyball/schedule_free_palm_foreach_soap.py
@@ -2,8 +2,18 @@ import random

 import torch

-from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, \
-
+from .utils import init_preconditioner, update_preconditioner, project, set_, adaptive_gradient_clipping_, exp_avg_sq_, \
+beta_debias, schedule_free_, warmup, ScheduleFree, split_p_and_g_in_group, copy_stochastic_list_, promote
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased2, eps):
+eas32, gp32 = [list(map(promote, x)) for x in (exp_avg_sq, grad_projected)]
+denom = exp_avg_sq_(eas32, gp32, old_debiased2, eps)
+torch._foreach_div_(gp32, denom)
+
+copy_stochastic_list_(exp_avg_sq, eas32)
+copy_stochastic_list_(grad_projected, gp32)


 class SFPaLMForeachSOAP(ScheduleFree):
@@ -95,8 +105,8 @@ class SFPaLMForeachSOAP(ScheduleFree):

 # Decay the first and second moment running average coefficient
 # In-place operations to update the averages at the same time
-
-
+old_debiased_tensor = torch.empty((), dtype=torch.float32, device=p_list[0].device).fill_(new_debiased2)
+_compilable_exp_avg_sq_(exp_avg_sq, grad_projected, old_debiased_tensor, group["eps"])

 update_precond = group['step'] > 0 and group['step'] % group['precondition_frequency'] == 0

@@ -107,13 +117,12 @@ class SFPaLMForeachSOAP(ScheduleFree):
 # CANT DO /= HERE AS EXP_AVG MAY POINT TO THE BUFFER
 set_(gp, project(gp, state['Q'], back=True))

-update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2,
-update_precond)
+update_preconditioner(g, state, max_precond_dim, precondition_1d, 1 - new_debiased2, update_precond)

 # Weight decay calculated at y
 if group["weight_decay"] > 0:
 torch._foreach_add_(grad, p_list, alpha=group["weight_decay"])

 lr = warmup(group['lr'], step, group['warmup_steps'])
-group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'],
-
+group['weight_sum'] = schedule_free_(lr, group['weight_lr_power'], group['weight_sum'], group['beta'], p_list,
+z, grad_projected, group['r'], step)
heavyball/utils.py
@@ -3,7 +3,7 @@ import gc
 import math
 import random
 import string
-from typing import List, Optional, Tuple, Callable
+from typing import List, Optional, Tuple, Callable, Union

 import numpy as np
 import torch
@@ -141,6 +141,7 @@ def beta_debias(beta, step):
 return 1 - (1 - beta) / (1 - beta ** step)


+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def exp_avg_sq_(state, grad, beta2, eps, out=None):
 if isinstance(state, torch.Tensor):
 state.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
@@ -327,6 +328,36 @@ def get_orthogonal_matrix(mat):
 return final


+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+for x_, y_ in zip(x, y):
+x32 = promote(x_)
+y32 = promote(y_)
+x32.lerp_(y32, a)
+copy_stochastic_(x_, x32)
+
+
+def stochastic_lerp_(x: List[torch.Tensor], y: List[torch.Tensor], a: Union[float, int, torch.Tensor]):
+if not isinstance(a, torch.Tensor):
+a = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(a)
+_compilable_stochastic_lerp_(x, y, a)
+
+
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+for x_, y_ in zip(x, y):
+x32 = promote(x_)
+y32 = promote(y_)
+x32.add_(y32, alpha=alpha)
+copy_stochastic_(x_, x32)
+
+
+def stochastic_add_(x: List[torch.Tensor], y: List[torch.Tensor], alpha: Union[float, int, torch.Tensor]):
+if not isinstance(alpha, torch.Tensor):
+alpha = torch.empty((), dtype=torch.float32, device=x[0].device).fill_(alpha)
+_compilable_stochastic_add_(x, y, alpha)
+
+
 @decorator
 def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
 if grad.dim() == 1 and (not precondition_1d or grad.shape[0] > max_precond_dim):
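The new helpers run the arithmetic in float32 (via `promote`) and write the result back to the original, possibly bfloat16, tensors with stochastic rounding (`copy_stochastic_`). A minimal usage sketch of `stochastic_lerp_`, assuming it is imported from `heavyball.utils` the same way the optimizer modules in this diff import it:

import torch
from heavyball.utils import stochastic_lerp_

params = [torch.randn(64, 64)]
ema = [p.bfloat16() for p in params]  # low-precision buffer, for illustration only

# ema <- ema + 0.01 * (param - ema): computed in float32, written back to
# bfloat16 with stochastic rounding
stochastic_lerp_(ema, params, 0.01)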
@@ -409,9 +440,12 @@ def project(grad, Q, back: bool):


 class StatefulOptimizer(torch.optim.Optimizer):
-
+ema_decay: float = 0.001
+
+def __init__(self, params, defaults, foreach: bool = True, use_ema: bool = False):
 super().__init__(params, {**defaults, 'foreach': foreach})
 self.fake_groups = {}
+self.use_ema = use_ema

 def key(self, param: torch.Tensor):
 return (param.data_ptr(), tuple(param.shape))
@@ -445,6 +479,54 @@ class StatefulOptimizer(torch.optim.Optimizer):
 def _step(self, group):
 raise NotImplementedError

+def ema_update(self):
+with torch.no_grad():
+for top_group in self.param_groups:
+for group in self.get_groups(top_group):
+active_p = [p for p in group['params']]
+
+if not active_p:
+return
+
+k = group['ema_step'] = group.get('ema_step', -1) + 1
+
+for p in active_p:
+if 'param_ema' not in self.state_(p):
+self.state_(p)['param_ema'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+
+y, param_ema = zip(*[(p.data, self.state_(p)['param_ema']) for p in active_p])
+torch._foreach_lerp_(param_ema, y, weight=beta_debias(1 - self.ema_decay, k + 1))
+
+def copy_emas_to_params(self):
+with torch.no_grad():
+for top_group in self.param_groups:
+for group in self.get_groups(top_group):
+active_p = [p for p in group['params']]
+
+if not active_p:
+return
+
+for p in active_p:
+if 'param_ema' in self.state_(p):
+p_clone = p.data.clone()
+set_(p.data, self.state_(p)['param_ema'])
+set_(self.state_(p)['param_ema'], p_clone)
+
+def copy_params_to_emas(self):
+with torch.no_grad():
+for top_group in self.param_groups:
+for group in self.get_groups(top_group):
+active_p = [p for p in group['params']]
+
+if not active_p:
+return
+
+for p in active_p:
+if 'param_ema' in self.state_(p):
+ema_clone = self.state_(p)['param_ema'].data.clone()
+set_(self.state_(p)['param_ema'], p.data)
+set_(p.data, ema_clone)
+
 def step(self, closure: Optional[Callable] = None):
 if closure is None:
 loss = None
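These hooks maintain a debiased exponential moving average of every parameter in `state['param_ema']` and let callers swap it in and out of the model. A usage sketch modeled on the new `test/test_ema.py` further down, assuming `ForeachSOAP` is exported at the package root:

import torch
from torch import nn

import heavyball

model = nn.Linear(64, 64)
opt = heavyball.ForeachSOAP(model.parameters(), lr=1e-3)

for _ in range(10):
    loss = model(torch.randn(32, 64)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    opt.ema_update()           # fold the current weights into the EMA buffers

opt.copy_emas_to_params()      # evaluate with the averaged weights
eval_loss = model(torch.randn(32, 64)).square().mean()
opt.copy_params_to_emas()      # swap back; this undoes the previous swap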
@@ -455,6 +537,8 @@ class StatefulOptimizer(torch.optim.Optimizer):
 for top_group in self.param_groups:
 for group in self.get_groups(top_group):
 self._step(group)
+if self.use_ema:
+self.ema_update(group)
 return loss

@@ -497,6 +581,32 @@ def copy_stochastic_list_(target: List[torch.Tensor], source: List[torch.Tensor]
 copy_stochastic_(t, s)


+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step):
+beta1 = beta_debias(beta1, step)
+beta2 = beta_debias(beta2, step)
+
+g32, gp32, exp_avg_sq32 = [list(map(promote, x)) for x in [grad, grad_projected, exp_avg_sq]]
+
+stochastic_lerp_(exp_avg, g32, 1 - beta1)
+denom = exp_avg_sq_(exp_avg_sq32, gp32, beta2, 1e-8)
+
+copy_stochastic_list_(exp_avg_sq, exp_avg_sq32)
+return denom
+
+
+def exp_avg_(exp_avg: List[torch.Tensor], exp_avg_sq: List[torch.Tensor], grad: List[torch.Tensor],
+grad_projected: List[torch.Tensor], beta1: float, beta2: float, step: int):
+if isinstance(beta1, float):
+beta1 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta1)
+if isinstance(beta2, float):
+beta2 = torch.empty((), dtype=torch.float32, device=exp_avg[0].device).fill_(beta2)
+if isinstance(step, int):
+step = torch.empty((), dtype=torch.int32, device=exp_avg[0].device).fill_(step)
+denom = _compilable_exp_avg_(exp_avg, exp_avg_sq, grad, grad_projected, beta1, beta2, step)
+return denom
+
+
 # this can be dynamic for most optimizers - just not for PSGD. So, it's disabled for all
 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True)
 def _compilable_copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
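`exp_avg_` is the compiled replacement for the separate first/second-moment updates the SOAP variants used before: it debiases `beta1`/`beta2` with `beta_debias`, lerps the gradient into `exp_avg`, updates `exp_avg_sq` from the projected gradient, and returns the denominator. A minimal sketch of calling it directly, reusing the raw gradient as the projected gradient purely for illustration:

import torch
from heavyball.utils import exp_avg_

grad = [torch.randn(32, 32)]
exp_avg = [torch.zeros(32, 32)]
exp_avg_sq = [torch.zeros(32, 32)]

# beta1/beta2/step may be Python scalars; the wrapper promotes them to tensors
denom = exp_avg_(exp_avg, exp_avg_sq, grad, grad, 0.9, 0.99, 1)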
@@ -523,23 +633,26 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):


 @torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
-def
-
-u32 =
+def _compilable_update_(p, u, decay, add_fn, lr):
+u = [u_.view_as(p_) for u_, p_ in zip(u, p)]
+p32, u32 = [list(map(promote, x)) for x in [p, u]]
+
 if decay > 0:
-
-
-
-
-
-
+torch._foreach_mul_(p32, 1 - decay * lr)
+
+for p32_, u32_ in zip(p32, u32):  # lr is data-dependent -> can't compile a foreach
+if add_fn is None:
+p32_.add_(u32_, alpha=lr)
+else:
+add_fn(p32_, u32_, lr)
+
+copy_stochastic_list_(p, p32)


 def update_param_(param: List[torch.Tensor], update: List[torch.Tensor], lr: float, decay: float,
 add_fn: callable = None):
 lr_tensor = torch.empty((), dtype=torch.float32, device=param[0].device).fill_(lr)
-
-_compilable_update_one_(p, u, decay, add_fn, lr_tensor)
+_compilable_update_(param, update, decay, add_fn, lr_tensor)


 def precond_schedule(step, precond_scheduler, rng):
@@ -638,12 +751,13 @@ def psgd_balance_Q(Q_in):
 torch._foreach_mul_(Q_in, list(norms))


-def psgd_calc_A_and_conjB(exprA, G, Q
-md = min_dtype(Q)
-A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md))
+def psgd_calc_A_and_conjB(exprA, G, Q):
+md = min_dtype(Q + [G])
+A = torch.einsum(exprA, *[q.to(md) for q in Q], G.to(md)).to(G.dtype)
 order = G.dim()
 p = list(range(order))
-conjB = torch.
+conjB = torch.randn(G.shape[1:] + G.shape[:1], dtype=promote(G.dtype), device=G.device)
+Q = [promote(q) for q in Q]
 for i, q in enumerate(Q):
 if q.dim() <= 1:
 conjB /= q
@@ -651,7 +765,7 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):
 unsqueeze = conjB.dim() <= 1
 if unsqueeze:
 conjB = conjB.unsqueeze(0)
-conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False
+conjB = torch.linalg.solve_triangular(q, conjB, upper=True, left=False)
 if unsqueeze:
 conjB = conjB.squeeze(0)
 if i < order - 1:
@@ -661,33 +775,29 @@ def psgd_calc_A_and_conjB(exprA, G, Q, V):

 def psgd_lb(A, max_abs):
 A /= max_abs
-
-
-
-
-
-
-x
-x =
-
-x = torch.where(comp, x, x.T)
-torch.matmul(x, torch.where(comp, A, A.T), out=x.view(1, -1))
-x /= torch.linalg.vector_norm(x)
-torch.matmul(x, torch.where(comp, ah, ah.T), out=x.view(1, -1))
-x = torch.linalg.vector_norm(x)
+a0 = torch.einsum('ij,ij->j', A, A)
+i = torch.argmax(a0)
+
+x = torch.index_select(A, 1, i).flatten().contiguous()
+
+x = torch.einsum('i,ij->j', x, A)
+x /= x.norm()
+x = torch.einsum('j,kj->k', x, A)
+x = x.norm()
 x *= max_abs
 return x


-
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
+def psgd_update_precond(Q, exprs, G, precond_lr, tiny, oq, store_triu_as_line):
 """Update Kronecker product preconditioner Q with pair (V, G)."""
 exprA, exprGs, _ = exprs

-A, conjB = psgd_calc_A_and_conjB(exprA, G, Q
+A, conjB = psgd_calc_A_and_conjB(exprA, G, Q)

-for q, exprG in zip(Q, exprGs):
-term1 = torch.einsum(exprG, A, A
-term2 = torch.einsum(exprG, conjB
+for q, exprG, o in zip(Q, exprGs, oq):
+term1 = promote(torch.einsum(exprG, A, A))
+term2 = promote(torch.einsum(exprG, conjB, conjB))

 term2 += term1 # a + b
 term1 *= 2 # 2a
@@ -696,18 +806,22 @@ def psgd_update_precond(Q, exprs, V, G, step, tiny):
 else:
 term1 = term1 - term2

-term1 *=
+term1 *= precond_lr
 norm = term2.norm(float('inf'))
 if q.dim() < 2:
-term1 *= q
-
+term1 *= q.to(term1.dtype)
+term1 /= norm.clamp_(min=tiny)
 else:
 torch.triu(term1, out=term1)
-term1 /=
-
+term1 /= psgd_lb(term2, norm).clamp_(tiny)
+torch.matmul(term1, q, out=term1)
+if store_triu_as_line:
+term1 = triu_to_line([term1])[0][1]
+o = o[1]
+stochastic_add_([o], [term1], -1)


-@
+@torch.compile(mode='max-autotune-no-cudagraphs', fullgraph=True, dynamic=True)
 def psgd_precond_grad(Q, exprs, G, inplace: bool = False):
 """Precondition gradient G with preconditioner Q."""
 md = min_dtype(Q)
@@ -838,18 +952,9 @@ class PSGDBase(StatefulOptimizer):
 group[name] = cumulative_prob + prob
 return int(group[name]) > int(cumulative_prob)

-def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q:
-
-
-if store_triu_as_line:
-update_fn = update_triu_
-else:
-update_fn = copy_stochastic_list_
-else:
-update_fn = lambda x, y: None
-for i, (p, grad, Q, oq) in enumerate(zip(p_list, grad_list, q_list, original_q)):
-psgd_update_precond(Q, self.state_(p)["exprs"], torch.randn_like(grad), grad, precond_lr, self._tiny)
-update_fn(oq, Q)
+def do_update(self, group, p_list, grad_list, q_list, precond_lr, original_q: List, store_triu_as_line=False):
+for p, grad, Q, oq in zip(p_list, grad_list, q_list, original_q):
+psgd_update_precond(Q, self.state_(p)["exprs"], grad, precond_lr, self._tiny, oq, store_triu_as_line)

 if self.should_update(group, self.balance_probability, "balance_prob"):
 for g, q in zip(grad_list, original_q if original_q else q_list):
@@ -896,13 +1001,19 @@ def merge_group(group, *tensors):
 return out


-def split_p_and_g_in_group(group: dict, skip_none: bool = True):
+def split_p_and_g_in_group(group: dict, skip_none: bool = True, should_promote: bool = True):
 for p in group["params"]:
 if skip_none and p.grad is None:
 continue

-
-
+if p.grad is None:
+grad = None
+else:
+if should_promote:
+grad = promote(p.grad)
+else:
+grad = p.grad
+p.grad = None

 p_views = merge_group(group, p)
 if grad is not None:
heavyball.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: heavyball
-Version: 0.19.0
+Version: 0.21.0
 Summary: Efficient optimizers
 Home-page: https://github.com/clashluke/heavyball
 Author: Lucas Nestler
@@ -32,7 +32,7 @@ A simple package of efficient optimizers
 The goal is not to thrive for completeness, full maintenance or abstraction, but instead to provide a simple
 largely static alternative to `torch.optim` with more and better optimizers.

-Currently (2024-11-22, 0.
+Currently (2024-11-22, 0.21.0), the recommended stable optimizer is `PrecondSchedulePaLMSOAP` (see below). The
 recommended experimental optimizer is `DelayedPSGDKron` ([tuning guide](docs/psgd_efficiency.md)).

 ## Features
test/test_bf16_params.py
@@ -20,10 +20,11 @@ def get_memory():

 @pytest.mark.parametrize("opt", heavyball.__all__)
 @pytest.mark.parametrize("size,depth", [(256, 2)])
-def test_foreach(opt, size, depth: int, iterations: int =
+def test_foreach(opt, size, depth: int, iterations: int = 512, outer_iterations: int = 3):
 set_torch()
 opt = getattr(heavyball, opt)

+
 peaks = []
 losses = []

test/test_ema.py (new file)
@@ -0,0 +1,61 @@
+import pytest
+import torch
+from torch import nn
+from torch._dynamo import config
+
+import heavyball
+import heavyball.utils
+from benchmark.utils import get_optim
+from heavyball.utils import clean, set_torch
+
+config.cache_size_limit = 128
+
+
+def get_memory():
+clean()
+torch.cuda.synchronize()
+clean()
+torch.cuda.synchronize()
+return torch.cuda.memory_allocated()
+
+
+@pytest.mark.parametrize("opt", heavyball.__all__)
+@pytest.mark.parametrize("size,depth", [(256, 2)])
+def test_foreach(opt, size, depth: int, iterations: int = 128, outer_iterations: int = 3):
+set_torch()
+opt = getattr(heavyball, opt)
+
+peaks = []
+losses = []
+
+for do_ema in [True, False]:
+torch.manual_seed(0x2131290)
+peaks.append([])
+losses.append([])
+
+for i in range(outer_iterations):
+model = nn.Sequential(*[nn.Linear(size, size) for _ in range(depth)]).cuda()
+o = get_optim(opt, model.parameters(), lr=1e-3)
+
+for _ in range(iterations):
+loss = model(torch.randn((1024, size), device='cuda')).square().mean()
+loss.backward()
+o.step()
+o.zero_grad()
+if do_ema:
+o.ema_update()
+o.copy_emas_to_params()
+o.copy_params_to_emas()
+losses[-1].append(loss.detach())
+
+if do_ema:
+o.copy_emas_to_params()
+loss = model(torch.randn((1024, size), device='cuda')).square().mean()
+losses[-1].append(loss.detach())
+
+del model, o
+clean()
+
+for i, (l0, l1) in enumerate(zip(*losses)):
+print(i, l0.item(), l1.item())
+assert l0.float() <= l1.float()