ista_daslab_optimizers-1.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ista_daslab_optimizers/__init__.py +6 -0
- ista_daslab_optimizers/acdc/__init__.py +5 -0
- ista_daslab_optimizers/acdc/acdc.py +387 -0
- ista_daslab_optimizers/acdc/wd_scheduler.py +31 -0
- ista_daslab_optimizers/dense_mfac/__init__.py +5 -0
- ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +164 -0
- ista_daslab_optimizers/dense_mfac/dense_mfac.py +93 -0
- ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
- ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
- ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
- ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
- ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
- ista_daslab_optimizers/micro_adam/__init__.py +5 -0
- ista_daslab_optimizers/micro_adam/micro_adam.py +402 -0
- ista_daslab_optimizers/sparse_mfac/__init__.py +7 -0
- ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +226 -0
- ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +87 -0
- ista_daslab_optimizers/tools.py +218 -0
- ista_daslab_optimizers/utils/dct.py +45 -0
- ista_daslab_optimizers/utils/global_cache.py +45 -0
- ista_daslab_optimizers/utils/matrix_storage.py +58 -0
- ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
- ista_daslab_optimizers/utils/quantizers.py +71 -0
- ista_daslab_optimizers/utils/schedulers.py +41 -0
- ista_daslab_optimizers-1.1.8.dist-info/METADATA +333 -0
- ista_daslab_optimizers-1.1.8.dist-info/RECORD +29 -0
- ista_daslab_optimizers-1.1.8.dist-info/WHEEL +5 -0
- ista_daslab_optimizers-1.1.8.dist-info/licenses/LICENSE +201 -0
- ista_daslab_optimizers-1.1.8.dist-info/top_level.txt +1 -0
+++ ista_daslab_optimizers/fft_low_rank/trion.py
@@ -0,0 +1,242 @@
import torch
import torch.distributed as dist
import math
import wandb

from ista_daslab_optimizers.utils.dct import dct3_matrix, dct_type2_makhoul
from ista_daslab_optimizers.utils.global_cache import GlobalCache
from ista_daslab_optimizers.utils.newton_schulz_triton import newton_schulz_triton

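# adam_update: one Adam step direction. It updates the EMA buffers in place
# (buf1 = first moment, buf2 = second moment) and returns the bias-corrected
# direction m_hat / (sqrt(v_hat) + eps); the caller applies the learning rate.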
def adam_update(grad, buf1, buf2, step, betas, eps):
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
    buf1c = buf1 / (1 - betas[0] ** step)
    buf2c = buf2 / (1 - betas[1] ** step)
    return buf1c / (buf2c.sqrt() + eps)

def zeropower_via_newtonschulz5(G, steps: int):
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

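# trion_update: momentum-accumulated, DCT-based low-rank variant of the Muon update.
# The gradient is accumulated into the momentum buffer M, the `rank` most significant
# rows/columns of M in the DCT basis are selected by L1 norm, the selected slice is
# orthogonalized with Newton-Schulz, and the extracted low-rank component, scaled by
# (1 - mu), is subtracted back from M. The orthogonalized slice and the selected
# indices are written into out_ortho and out_indices.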
def trion_update(G, M, rank, ns_type, mu, use_makhoul, ns_steps, out_ortho, out_indices):
    # formerly called dct_dion_low_rank_muon_update
    M.add_(G)

    R, C = G.shape
    is_right_proj = (R >= C)
    # DCT = get_dct_matrix(size=min(R, C), key=(min(R, C), rank), device=G.device, dtype=G.dtype)

    size = min(R, C)
    key = (size, rank)
    if GlobalCache.contains(category='ortho', key=key):
        DCT = GlobalCache.get(category='ortho', key=key)
    else:
        DCT = dct3_matrix(min(R, C), device=G.device, dtype=G.dtype)
        GlobalCache.add(category='ortho', key=key, item=DCT.T if use_makhoul else DCT)

    if use_makhoul:
        if is_right_proj: # R >= C, I have to flip it for Makhoul
            inputM = M
        else:
            inputM = M.T
        # force the input to have more columns than rows for Makhoul
        # fatM = M if R <= C else M.T # input fat/wide matrix to Makhoul: R < C
        S = dct_type2_makhoul(inputM)

        if is_right_proj:
            norms = S.norm(p=1, dim=0) # dim = 0 computes norm of columns (over all rows) to rank columns
        else:
            ### case 1: transpose S to be able to use dim=1
            S = S.T # account for the transposition in inputM because Makhoul computes DCT per rows by default
            norms = S.norm(p=1, dim=1) # dim = 1 computes norm of rows (over all columns) to rank rows

            ### case 2: to avoid transposing S, use dim=0 instead of dim=1 and it should be the same
            # norms = S.norm(p=1, dim=0) # dim = 0 computes norm of columns (over all rows) to rank columns
    else: # use matmul
        # ranking: compute similarities
        if is_right_proj:
            S = M @ DCT # (R, C) @ (C, C) = (R, C)
            norms = S.norm(p=1, dim=0) # dim = 0 computes norm of columns (over all rows)
        else:
            S = DCT @ M # (R, R) @ (R, C) = (R, C)
            norms = S.norm(p=1, dim=1)

    # ranking: determine indices of most significant rows/columns
    indices = torch.topk(input=norms, k=rank, sorted=False).indices

    # create Q_r
    if is_right_proj:
        Q = DCT[:, indices] # (C, r)
        m = S[:, indices]   # (R, r)
        M.add_(m @ Q.T, alpha=-(1 - mu))
    else:
        Q = DCT[indices, :] # (r, R)
        m = S[indices, :]   # (r, C)
        M.add_(Q.T @ m, alpha=-(1 - mu))

    if ns_type == 'torch':
        ortho_m = zeropower_via_newtonschulz5(m, steps=ns_steps).to(dtype=M.dtype)
    elif ns_type == 'triton':
        ortho_m = newton_schulz_triton(m).to(dtype=M.dtype)
    else:
        raise RuntimeError(f'Unknown ns_type: {ns_type}')

    out_ortho.copy_(ortho_m)
    out_indices.copy_(indices)

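# Trion: Muon-style optimizer with DCT low-rank updates, sharded across data-parallel
# ranks. Every param group must carry a "use_muon" flag; Muon groups additionally read
# "rank", "ns_type" ('torch' or 'triton'), "scaling_type" ('kj', 'none', 'kimi', 'dion')
# and optionally "use_makhoul", while non-Muon groups fall back to the Adam-style update
# above. step() assumes torch.distributed is initialized, since the low-rank work is
# split across ranks and re-assembled with all_gather.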
class Trion(torch.optim.Optimizer):
    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:

                group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                group["step"] = 0
                # assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                # assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        muon_param_index = 0
        for group in self.param_groups:
            if group["use_muon"]:
                group["step"] += 1

                params = group["params"]

                if ('lowrank_updates' not in group) and ('lowrank_indices' not in group):
                    group["lowrank_updates"] = []
                    group["lowrank_indices"] = []
                    for p in params:
                        R, C = p.shape
                        is_right_proj = (R >= C)
                        o_shape = (R, group["rank"]) if is_right_proj else (group["rank"], C)
                        group["lowrank_updates"].append(torch.zeros(o_shape, dtype=p.dtype, device=p.device))
                        group["lowrank_indices"].append(torch.zeros(group["rank"], dtype=torch.int32, device=p.device))
                lowrank_updates = group["lowrank_updates"]
                lowrank_indices = group["lowrank_indices"]

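                # Pad the lists so every rank has a tensor to contribute in each all_gather
                # round, even when len(params) is not divisible by the world size; the
                # padding entries are throwaway buffers and are never applied to params.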
                pad_size = dist.get_world_size() - len(params) % dist.get_world_size()
                # params_pad = params + [torch.empty_like(params[-1])] * pad_size
                lowrank_updates_pad = lowrank_updates + [torch.empty_like(lowrank_updates[-1])] * pad_size
                lowrank_indices_pad = lowrank_indices + [torch.empty_like(lowrank_indices[-1])] * pad_size

                ##### compute low-rank updates only on one GPU
                for pi in range(len(params))[::dist.get_world_size()]:
                    idx = pi + dist.get_rank()
                    if idx < len(params):
                        p = params[idx]
                        lowrank_u = lowrank_updates[idx]   # low-rank update
                        lowrank_idx = lowrank_indices[idx] # row/column indices

                        if p.grad is None:
                            # continue
                            p.grad = torch.zeros_like(p) # Force synchronization

                        state = self.state[p]

                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                            state["param_id"] = muon_param_index
                            muon_param_index += 1

                        trion_update(
                            G=p.grad,
                            M=state["momentum_buffer"],
                            rank=group["rank"],
                            ns_type=group["ns_type"],
                            mu=group["momentum"],
                            use_makhoul=group.get("use_makhoul", False),
                            ns_steps=5,
                            out_ortho=lowrank_u,
                            out_indices=lowrank_idx)
                    # end if

                    # all-gather for low-rank updates
                    dist.all_gather(
                        tensor_list=lowrank_updates_pad[pi:pi + dist.get_world_size()],
                        tensor=lowrank_updates_pad[idx])

                    # all-gather for row/column indices
                    dist.all_gather(
                        tensor_list=lowrank_indices_pad[pi:pi + dist.get_world_size()],
                        tensor=lowrank_indices_pad[idx])
                # end for pi

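                # Every rank now holds all low-rank factors; rebuild the dense update for
                # each parameter from its DCT rows/columns and apply it with the selected
                # scaling rule and decoupled weight decay.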
                for pi in range(len(params)):
                    p = params[pi]
                    R, C = p.shape
                    indices = lowrank_indices[pi]
                    ot = lowrank_updates[pi]
                    # NOTE: `get_dct_matrix` is not defined in this module; reuse the DCT
                    # matrix cached by trion_update, building it on a cache miss as there.
                    key = (min(R, C), group["rank"])
                    if GlobalCache.contains(category='ortho', key=key):
                        DCT = GlobalCache.get(category='ortho', key=key)
                    else:
                        DCT = dct3_matrix(min(R, C), device=p.device, dtype=p.dtype)
                        GlobalCache.add(category='ortho', key=key, item=DCT.T if group.get("use_makhoul", False) else DCT)

                    is_right_proj = (R >= C)
                    # print(f'R: {R}, C:{C}, Q: {tuple(Q.shape)}, ot: {ot.shape}')
                    if is_right_proj:
                        Q = DCT[:, indices] # (C, r)
                        update = ot @ Q.T   # (R, r) @ (r, C) = (R, C)
                    else:
                        Q = DCT[indices, :] # (r, R)
                        update = Q.T @ ot   # (R, r) @ (r, C) = (R, C)

                    # R, C = p.shape
                    scaling_type = group["scaling_type"]
                    if scaling_type == 'kj':
                        scaling = max(1, R / C) ** 0.5
                    elif scaling_type == 'none':
                        scaling = 1
                    elif scaling_type == 'kimi':
                        scaling = 0.2 * math.sqrt(max(R, C))
                    elif scaling_type == 'dion':
                        scaling = (R / C) ** 0.5

                    p.mul_(1 - group["lr"] * group["weight_decay"]).add_(update.reshape(p.shape), alpha=-group["lr"] * scaling)
                # end for pi
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p) # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
                                         state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
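As a usage illustration only (none of the names below ship with the wheel and every value is a placeholder), the optimizer above expects explicit param groups carrying the keys its step() reads:

# Hypothetical setup sketch, not part of the package. torch.distributed must already be
# initialized, since Trion.step() calls dist.get_world_size() / dist.get_rank().
matrix_params = [p for p in model.parameters() if p.ndim == 2]  # `model` is an assumed nn.Module
other_params = [p for p in model.parameters() if p.ndim != 2]
optimizer = Trion(param_groups=[
    dict(params=matrix_params, use_muon=True, lr=0.02, momentum=0.95, weight_decay=0.01,
         rank=64, ns_type='torch', scaling_type='kj'),  # rank must not exceed min(R, C) of any matrix
    dict(params=other_params, use_muon=False, lr=3e-4, betas=(0.9, 0.95), eps=1e-10),
])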
+++ ista_daslab_optimizers/ista_optimizer/ista_optimizer.py
@@ -0,0 +1,36 @@
import torch

class ISTAOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr, weight_decay):
        super().__init__(params, dict(lr=lr, weight_decay=weight_decay))
        self.lr = lr
        self.weight_decay = weight_decay
        self.optim_steps = 0

    def loop_params(self, check_grad=True):
        for group in self.param_groups:
            for p in group['params']:
                if check_grad:
                    if p.grad is None: continue
                yield group, self.state[p], p

    @torch.no_grad()
    def init_optimizer_states(self):
        raise NotImplementedError

    @torch.no_grad()
    def optimizer_step(self):
        raise NotImplementedError

    @torch.no_grad()
    def step(self, closure=None):
        self.optim_steps += 1

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.optimizer_step()

        return loss
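ISTAOptimizer is an abstract base class: subclasses implement optimizer_step() (and init_optimizer_states()), while step() handles the closure and the step counter, and loop_params() yields (group, state, param) for every parameter that has a gradient. A minimal, hypothetical subclass performing plain SGD with decoupled weight decay could look like this (illustrative only, not shipped in the wheel):

class SGDISTA(ISTAOptimizer):  # hypothetical example, not part of the package
    @torch.no_grad()
    def optimizer_step(self):
        for group, state, p in self.loop_params():
            p.mul_(1 - group['lr'] * group['weight_decay'])  # decoupled weight decay
            p.add_(p.grad, alpha=-group['lr'])               # gradient descent step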