heavyball 0.15.0__tar.gz → 0.15.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {heavyball-0.15.0 → heavyball-0.15.1}/PKG-INFO +1 -1
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/__init__.py +24 -2
- heavyball-0.15.1/heavyball/cached_psgd_kron.py +141 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/delayed_psgd.py +8 -7
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/p_adam.py +11 -6
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/psgd_kron.py +8 -6
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/pure_psgd.py +8 -6
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/utils.py +6 -6
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/PKG-INFO +1 -1
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/SOURCES.txt +1 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/setup.py +1 -1
- {heavyball-0.15.0 → heavyball-0.15.1}/LICENSE +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/README.md +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_adamw.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_adopt.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_laprop.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_sfadamw.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/foreach_soap.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/palm_foreach_sfadamw.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/palm_foreach_soap.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/precond_schedule_foreach_soap.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/precond_schedule_palm_foreach_soap.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/precond_schedule_sfpsoap.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball/schedule_free_palm_foreach_soap.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/dependency_links.txt +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/requires.txt +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/heavyball.egg-info/top_level.txt +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/setup.cfg +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/test/test_closure.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/test/test_memory.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/test/test_merge.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/test/test_no_grad.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/test/test_psgd.py +0 -0
- {heavyball-0.15.0 → heavyball-0.15.1}/test/test_soap.py +0 -0

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/__init__.py

```diff
@@ -1,3 +1,4 @@
+from .cached_psgd_kron import ForeachCachedPSGDKron
 from .delayed_psgd import ForeachDelayedPSGD
 from .foreach_adamw import ForeachAdamW
 from .foreach_adopt import ForeachADOPT
@@ -16,7 +17,28 @@ from .schedule_free_palm_foreach_soap import SFPaLMForeachSOAP
 
 PalmForEachSoap = PaLMForeachSOAP
 
+PaLMSOAP = PaLMForeachSOAP
+PaLMSFAdamW = PaLMForeachSFAdamW
+PaLMSFSoap = SFPaLMForeachSOAP
+PaLMForeachSOAP = PaLMForeachSOAP
+PrecondScheduleSFPaLMSOAP = PrecondScheduleSFPaLMSOAP
+SOAP = ForeachSOAP
+SFAdamW = ForeachSFAdamW
+LaProp = ForeachLaProp
+ADOPT = ForeachADOPT
+PrecondScheduleForeachSOAP = PrecondScheduleForeachSOAP
+PrecondSchedulePaLMForeachSOAP = PrecondSchedulePaLMForeachSOAP
+PSGDKron = ForeachPSGDKron
+AdamW = ForeachAdamW
+PurePSGD = ForeachPurePSGD
+PaLMPAdam = ForeachPaLMPAdam
+DelayedPSGD = ForeachDelayedPSGD
+CachedPSGDKron = ForeachCachedPSGDKron
+
 __all__ = ['PalmForEachSoap', 'PaLMForeachSFAdamW', 'PaLMForeachSOAP', 'SFPaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
            'ForeachSOAP', 'ForeachSFAdamW', 'ForeachLaProp', 'ForeachADOPT', 'PrecondScheduleForeachSOAP',
-           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD',
-           '
+           'PrecondSchedulePaLMForeachSOAP', 'ForeachPSGDKron', 'ForeachAdamW', 'ForeachPurePSGD', 'ForeachPaLMPAdam',
+           'ForeachDelayedPSGD', 'ForeachCachedPSGDKron', #
+           'PaLMSOAP', 'PaLMSFAdamW', 'PaLMSFSoap', 'PaLMSFAdamW', 'PaLMForeachSOAP', 'PrecondScheduleSFPaLMSOAP',
+           'SOAP', 'SFAdamW', 'LaProp', 'ADOPT', 'PSGDKron', 'AdamW', 'PurePSGD', 'PaLMPAdam', 'DelayedPSGD',
+           'CachedPSGDKron']
```
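The upshot of the `__init__.py` change is that every optimizer now has a short alias (`PSGDKron`, `AdamW`, `SOAP`, ...) next to its `Foreach*` name, and the new cached Kron optimizer is exported. A minimal usage sketch, assuming the standard `torch.optim`-style `zero_grad()`/`step()` interface the heavyball optimizers expose; the model, data, and hyperparameters are illustrative only:

```python
import torch
from heavyball import ForeachCachedPSGDKron  # equivalently: from heavyball import CachedPSGDKron

# toy model and data, purely illustrative
model = torch.nn.Linear(16, 1)
opt = ForeachCachedPSGDKron(model.parameters(), lr=1e-3, beta=0.9, weight_decay=0.0)

for _ in range(10):
    x, y = torch.randn(32, 16), torch.randn(32, 1)
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
```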

heavyball-0.15.1/heavyball/cached_psgd_kron.py

```diff
@@ -0,0 +1,141 @@
+"""
+Originally from Evan Walters and Omead Pooladzandi, 2024
+Modified under Creative Commons Attribution 4.0 International
+Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5ee8a1a4c29e4780bbf6c521e545189eff9/kron_torch/kron.py
+"""
+
+from typing import Optional
+
+import torch
+from heavyball.utils import einsum_base
+
+from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, trust_region_clip_, PSGDBase, \
+    precond_update_prob_schedule, split_p_and_g_in_group, line_to_triu, triu_to_line, set_, einsum_base
+
+
+class ForeachCachedPSGDKron(PSGDBase):
+    """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch with cached preconditioners.
+
+    Args:
+        params (iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float): Learning rate.
+        b1 (float): Momentum parameter.
+        weight_decay (float): Weight decay (L2 penalty).
+        preconditioner_update_probability (callable or float, optional): Probability of
+            updating the preconditioner. If None, defaults to a schedule that anneals
+            from 1.0 to 0.03 by 4000 steps.
+        max_size_triangular (int): Max size for dim's preconditioner to be triangular.
+        min_ndim_triangular (int): Minimum number of dimensions a layer needs
+            to have triangular preconditioners.
+        memory_save_mode: (string, optional), None, 'one_diag', or 'all_diag', None is default
+            to set all preconditioners to be triangular, 'one_diag' sets the largest
+            or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
+            to be diagonal.
+        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+            update instead of raw gradients.
+    """
+
+    def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
+                 max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
+                 momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
+        if not 0.0 <= lr:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= beta < 1.0:
+            raise ValueError(f"Invalid beta parameter: {beta}")
+        if not 0.0 <= weight_decay:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if preconditioner_update_probability is None:
+            preconditioner_update_probability = precond_update_prob_schedule()
+        if clip_fn is None:
+            clip_fn = lambda x: trust_region_clip_(x, 0.9, 1.5)
+        self.preconditioner_update_probability = preconditioner_update_probability
+        self.clip_fn = clip_fn
+
+        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay, max_size_triangular=max_size_triangular,
+                        min_ndim_triangular=min_ndim_triangular, memory_save_mode=memory_save_mode,
+                        momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
+                        # precond lr hardcoded to 0.1
+                        precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
+        super().__init__(params, defaults)
+
+        self._prob_step = 0
+
+    def _step(self, group):
+        # update preconditioners all together
+        update_prob = self.preconditioner_update_probability
+        if callable(update_prob):
+            update_prob = update_prob(self._prob_step)
+        do_update = self.rng.random() < update_prob
+        self._prob_step += 1
+
+        momentum_into_precond_update = group.get("momentum_into_precond_update", True)
+        precond_init_scale = group['precond_init_scale']
+        max_size_triangular = group['max_size_triangular']
+        min_ndim_triangular = group['min_ndim_triangular']
+        memory_save_mode = group['memory_save_mode']
+        precond_lr = group['precond_lr']
+        weight_decay = group['weight_decay']
+        lr = group['lr']
+        beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
+
+        vals = []
+
+        for p, g in split_p_and_g_in_group(group):
+            state = self.state_(p)
+
+            if 'Q' not in state:
+                state["exp_avg"] = torch.zeros_like(g)
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
+                                                 memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
+                state['Q_cache'] = [torch.empty_like(q) for q in Q]
+
+                expr = [f'{c.upper()}{c}' if q_.ndim == 2 else c for c, q_ in zip(einsum_base, Q)]
+                expr = ','.join(expr)
+                grad_expr = ''.join(c for c, _ in zip(einsum_base, g.shape))
+                out_expr = ''.join(c.upper() if c.upper() in expr else c for c in grad_expr)
+                expr = f'{expr},{grad_expr}->{out_expr}'
+
+                state['cache_expr'] = expr
+
+            vals.append((p, g, state["exp_avg"], state["Q"], state['Q_cache']))
+
+        if not vals:
+            return
+
+        p_list, grad_list, exp_avg_list, Q_list, Q_cache_list = zip(*vals)
+        del vals
+
+        group["step"] += 1
+
+        torch._foreach_lerp_(exp_avg_list, grad_list, (1 - beta) / (1 - beta ** group["step"]))
+
+        grad_list, Q_list, Q_cache_list, exp_avg_list = list(grad_list), list(Q_list), list(Q_cache_list), list(
+            exp_avg_list)
+        for i, (p, g) in enumerate(zip(p_list, grad_list)):
+            cached_q = Q_cache_list.pop(0)
+            q_orig = Q_list.pop(0)
+            ea = exp_avg_list.pop(0)
+
+            if do_update:
+                q = line_to_triu(q_orig) if store_triu_as_line else q_orig
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr,
+                               [q_orig] if store_triu_as_line else None)
+                for c_, q_ in zip(cached_q, q):
+                    if q_.ndim == 2:
+                        torch.matmul(q_.T.conj(), q_, out=c_)
+                    else:
+                        torch.mul(q_.conj(), q_, out=c_)
+
+            set_(g, torch.einsum(self.state_(p)['cache_expr'], *cached_q, ea))
+        grad_list = self.clip_fn(grad_list)
+
+        lr = -warmup(lr, group['step'], group['warmup_steps'])
+        update_param_(p_list, grad_list, lr, weight_decay)
```

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/delayed_psgd.py

```diff
@@ -38,7 +38,7 @@ class ForeachDelayedPSGD(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -58,7 +58,8 @@ class ForeachDelayedPSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
@@ -80,6 +81,7 @@ class ForeachDelayedPSGD(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -90,7 +92,7 @@ class ForeachDelayedPSGD(PSGDBase):
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state["Q"] = triu_to_line(Q)
+                state["Q"] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["exp_avg"], state["Q"]))
 
@@ -108,12 +110,11 @@ class ForeachDelayedPSGD(PSGDBase):
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig)
-            self.balance(do_update, [g], [q])
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
             new = psgd_precond_grad(q, self.state_(p)["exprs"], ea)
-
             if do_update:
-                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
+                self.balance([g], [q])
             set_(g, new)
 
         grad_list = self.clip_fn(grad_list)
```
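The `store_triu_as_line` flag added across all PSGD variants in this release (default `True`) keeps each triangular preconditioner factor between steps as a flat vector of its upper-triangular entries (`triu_to_line`) and expands it back only while it is used (`line_to_triu`), roughly halving the memory those factors occupy. An illustrative stand-alone sketch of the idea; `pack_triu`/`unpack_triu` are hypothetical stand-ins, not heavyball's actual helpers:

```python
import torch

def pack_triu(q: torch.Tensor) -> torch.Tensor:
    # keep only the upper-triangular entries: n*(n+1)/2 values instead of n*n
    idx = torch.triu_indices(q.shape[0], q.shape[1])
    return q[idx[0], idx[1]]

def unpack_triu(line: torch.Tensor, n: int) -> torch.Tensor:
    # rebuild the dense upper-triangular matrix on demand
    q = line.new_zeros(n, n)
    idx = torch.triu_indices(n, n)
    q[idx[0], idx[1]] = line
    return q

q = torch.randn(512, 512).triu()
line = pack_triu(q)                            # roughly half the memory of the dense factor
assert torch.equal(unpack_triu(line, 512), q)
```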

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/p_adam.py

```diff
@@ -5,6 +5,7 @@ Source available at https://github.com/evanatyourservice/kron_torch/blob/97a2b5e
 """
 
 import torch
+from heavyball.utils import triu_to_line, line_to_triu
 
 from .utils import update_param_, warmup, psgd_precond_grad, init_Q_exprs, PSGDBase, precond_update_prob_schedule, \
     exp_avg_sq_, beta_debias, split_p_and_g_in_group
@@ -36,7 +37,8 @@ class ForeachPaLMPAdam(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, betas=(None, None), beta: float = 0.9,
-                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None
+                 beta2_scale: float = 0.8, merge_dims: bool = False, split: bool = False, clip_fn: callable = None,
+                 store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -57,7 +59,7 @@ class ForeachPaLMPAdam(PSGDBase):
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
                         step=0, warmup_steps=warmup_steps, beta=beta, beta2_scale=beta2_scale, merge_dims=merge_dims,
-                        split=split)
+                        split=split, store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
@@ -77,6 +79,7 @@ class ForeachPaLMPAdam(PSGDBase):
         precond_lr = group['precond_lr']
         weight_decay = group['weight_decay']
         lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -86,8 +89,9 @@ class ForeachPaLMPAdam(PSGDBase):
             if 'Q' not in state:
                 state['exp_avg'] = torch.zeros_like(g)
                 state['exp_avg_sq'] = torch.zeros_like(g)
-
+                Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular,
                                                  min_ndim_triangular, memory_save_mode, dtype=g.dtype)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"], state['exp_avg'], state['exp_avg_sq']))
 
@@ -99,15 +103,16 @@ class ForeachPaLMPAdam(PSGDBase):
 
         group["step"] += 1
 
-
+        Q_triu = [line_to_triu(q) if store_triu_as_line else q for q in Q_list]
         if do_update:
-            self.
+            self.balance(grad_list, Q_triu)
+            self.do_update(p_list, grad_list, Q_triu, precond_lr, Q_list if store_triu_as_line else None)
 
         torch._foreach_lerp_(exp_avg, grad_list, 1 - beta_debias(group['beta'], group['step']))
 
         beta2 = 1 - group['step'] ** -group['beta2_scale']
 
-        for p, Q, g, ea, eas in zip(p_list,
+        for p, Q, g, ea, eas in zip(p_list, Q_triu, grad_list, exp_avg, exp_avg_sq):
             psgd_precond_grad(Q, self.state_(p)["exprs"], g, inplace=True)
             ea = psgd_precond_grad(Q, self.state_(p)["exprs"], ea)
             exp_avg_sq_(eas, g, beta_debias(beta2, group['step']), 1e-8, out=g)
```
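The unchanged context lines in this last hunk also show the PaLM-style second-moment schedule, `beta2 = 1 - step ** -beta2_scale`: early steps use a small beta2 and adapt quickly, and beta2 approaches 1 as training progresses. Evaluating the schedule for a few steps with the `beta2_scale=0.8` default from the signature above:

```python
def palm_beta2(step: int, beta2_scale: float = 0.8) -> float:
    # beta2 schedule used by the PaLM-flavoured optimizers: 1 - step**(-beta2_scale)
    return 1 - step ** -beta2_scale

for step in (1, 10, 100, 1000):
    print(step, round(palm_beta2(step), 4))
# 1 -> 0.0, 10 -> 0.8415, 100 -> 0.9749, 1000 -> 0.996
```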

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/psgd_kron.py

```diff
@@ -38,7 +38,7 @@ class ForeachPSGDKron(PSGDBase):
     def __init__(self, params, lr=0.001, beta=0.9, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: Optional[callable] = None):
+                 split: bool = False, clip_fn: Optional[callable] = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= beta < 1.0:
@@ -58,7 +58,8 @@ class ForeachPSGDKron(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
@@ -80,6 +81,7 @@ class ForeachPSGDKron(PSGDBase):
         weight_decay = group['weight_decay']
         lr = group['lr']
         beta = group['beta']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -90,7 +92,7 @@ class ForeachPSGDKron(PSGDBase):
                 state["exp_avg"] = torch.zeros_like(g)
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state['Q'] = triu_to_line(Q)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["exp_avg"], state["Q"]))
 
@@ -108,11 +110,11 @@ class ForeachPSGDKron(PSGDBase):
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
             ea = exp_avg_list.pop(0)
-            q = line_to_triu(q_orig)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            self.balance(do_update, [g], [q])
             if do_update:
-                self.
+                self.balance([g], [q])
+                self.do_update([p], [ea if momentum_into_precond_update else g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
             set_(g, psgd_precond_grad(q, self.state_(p)["exprs"], ea))
 
         grad_list = self.clip_fn(grad_list)
```
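All of these variants gate the expensive preconditioner update behind `do_update = rng.random() < update_prob`, where `update_prob` defaults to `precond_update_prob_schedule()`; the docstring of the new cached optimizer describes that default as annealing from 1.0 to 0.03 by roughly step 4000. A sketch of one schedule with that shape, for intuition only, since heavyball's actual `precond_update_prob_schedule` may be defined differently:

```python
def update_prob_schedule(max_prob: float = 1.0, min_prob: float = 0.03, decay_steps: int = 4000):
    # exponentially anneal the update probability from max_prob to min_prob over decay_steps
    def _schedule(step: int) -> float:
        t = min(1.0, step / decay_steps)
        return max_prob * (min_prob / max_prob) ** t
    return _schedule

prob = update_prob_schedule()
print(prob(0), prob(2000), prob(4000))  # 1.0, ~0.17, 0.03
```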

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/pure_psgd.py

```diff
@@ -36,7 +36,7 @@ class ForeachPurePSGD(PSGDBase):
     def __init__(self, params, lr=0.001, weight_decay=0.0, preconditioner_update_probability=None,
                  max_size_triangular=2048, min_ndim_triangular=2, memory_save_mode=None,
                  momentum_into_precond_update=True, warmup_steps: int = 1, merge_dims: bool = False,
-                 split: bool = False, clip_fn: callable = None):
+                 split: bool = False, clip_fn: callable = None, store_triu_as_line: bool = True):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= weight_decay:
@@ -54,7 +54,8 @@ class ForeachPurePSGD(PSGDBase):
                         momentum_into_precond_update=momentum_into_precond_update, precond_lr=0.1,
                         # precond lr hardcoded to 0.1
                         precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
-                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split
+                        step=0, warmup_steps=warmup_steps, merge_dims=merge_dims, split=split,
+                        store_triu_as_line=store_triu_as_line)
         super().__init__(params, defaults)
 
         self._prob_step = 0
@@ -74,6 +75,7 @@ class ForeachPurePSGD(PSGDBase):
         precond_lr = group['precond_lr']
         weight_decay = group['weight_decay']
         lr = group['lr']
+        store_triu_as_line = group['store_triu_as_line']
 
         vals = []
 
@@ -83,7 +85,7 @@ class ForeachPurePSGD(PSGDBase):
             if 'Q' not in state:
                 Q, state["exprs"] = init_Q_exprs(p, precond_init_scale, max_size_triangular, min_ndim_triangular,
                                                  memory_save_mode, dtype=g.dtype)
-                state['Q'] = triu_to_line(Q)
+                state['Q'] = triu_to_line(Q) if store_triu_as_line else Q
 
             vals.append((p, g, state["Q"]))
 
@@ -98,11 +100,11 @@ class ForeachPurePSGD(PSGDBase):
         Q_list = list(Q_list)
         for i, (p, g) in enumerate(zip(p_list, grad_list)):
             q_orig = Q_list.pop(0)
-            q = line_to_triu(q_orig)
+            q = line_to_triu(q_orig) if store_triu_as_line else q_orig
 
-            self.balance(do_update, [g], [q])
             if do_update:
-                self.
+                self.balance([g], [q])
+                self.do_update([p], [g], [q], precond_lr, [q_orig] if store_triu_as_line else None)
             psgd_precond_grad(q, self.state_(p)["exprs"], g, inplace=True)
 
         grad_list = self.clip_fn(grad_list)
```

{heavyball-0.15.0 → heavyball-0.15.1}/heavyball/utils.py

```diff
@@ -29,7 +29,7 @@ def decorator(func):
     return _fn
 
 
-
+einsum_base = string.ascii_lowercase + string.ascii_uppercase
 
 
 def warmup(lr: float, step: int, warmup_steps: int):
@@ -317,8 +317,8 @@ def compute_ggt(grad, GG, max_precond_dim, precondition_1d, beta):
         for idx, sh in enumerate(grad.shape):
             if sh > max_precond_dim:
                 continue
-            b =
-            g0 =
+            b = einsum_base[idx]
+            g0 = einsum_base[:grad.dim()]
             g1 = g0.replace(b, b.upper())
             outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
             GG[idx].lerp_(promote(outer_product), 1 - beta)
@@ -374,7 +374,7 @@ def project(grad, Q, back: bool):
     :param back: whether to project to Shampoo eigenbases or back to original space
     :return:
     """
-    param =
+    param = einsum_base[:grad.dim()]
     preconditioners = ",".join([(g + g.upper())[::-1 if back else 1] for m, g in zip(Q, param) if len(m) > 0])
     if preconditioners:
         out = ''.join([c.upper() if c.upper() in preconditioners else c for c in param])
@@ -759,8 +759,8 @@ class PSGDBase(StatefulOptimizer):
         self.rng = random.Random(0x1923213)
        self._tiny = torch.finfo(torch.bfloat16).tiny
 
-    def balance(self,
-        if
+    def balance(self, grad_list, Q_list):
+        if self.rng.random() > 0.01:
             return
 
        for g, q in zip(grad_list, Q_list):
```
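The restored lines in `compute_ggt` use one `einsum_base` letter per gradient dimension and upper-case the letter of the dimension being preconditioned, so the contraction yields that dimension's Gram matrix. Expanded for a concrete shape:

```python
import string
import torch

einsum_base = string.ascii_lowercase + string.ascii_uppercase

grad = torch.randn(8, 16, 4)
idx = 1                        # accumulate statistics for the second dimension
b = einsum_base[idx]           # 'b'
g0 = einsum_base[:grad.dim()]  # 'abc'
g1 = g0.replace(b, b.upper())  # 'aBc'
outer_product = torch.einsum(f'{g0},{g1}->{b + b.upper()}', grad, grad)
assert outer_product.shape == (16, 16)  # Gram matrix of dim 1, as accumulated into GG[idx]
```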
All remaining files listed above with +0 -0 are unchanged between 0.15.0 and 0.15.1.