ista-daslab-optimizers 1.1.8 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. ista_daslab_optimizers/__init__.py +6 -0
  2. ista_daslab_optimizers/acdc/__init__.py +5 -0
  3. ista_daslab_optimizers/acdc/acdc.py +387 -0
  4. ista_daslab_optimizers/acdc/wd_scheduler.py +31 -0
  5. ista_daslab_optimizers/dense_mfac/__init__.py +5 -0
  6. ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +164 -0
  7. ista_daslab_optimizers/dense_mfac/dense_mfac.py +93 -0
  8. ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
  9. ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
  10. ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
  11. ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
  12. ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
  13. ista_daslab_optimizers/micro_adam/__init__.py +5 -0
  14. ista_daslab_optimizers/micro_adam/micro_adam.py +402 -0
  15. ista_daslab_optimizers/sparse_mfac/__init__.py +7 -0
  16. ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +226 -0
  17. ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +87 -0
  18. ista_daslab_optimizers/tools.py +218 -0
  19. ista_daslab_optimizers/utils/dct.py +45 -0
  20. ista_daslab_optimizers/utils/global_cache.py +45 -0
  21. ista_daslab_optimizers/utils/matrix_storage.py +58 -0
  22. ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
  23. ista_daslab_optimizers/utils/quantizers.py +71 -0
  24. ista_daslab_optimizers/utils/schedulers.py +41 -0
  25. ista_daslab_optimizers-1.1.8.dist-info/METADATA +333 -0
  26. ista_daslab_optimizers-1.1.8.dist-info/RECORD +29 -0
  27. ista_daslab_optimizers-1.1.8.dist-info/WHEEL +5 -0
  28. ista_daslab_optimizers-1.1.8.dist-info/licenses/LICENSE +201 -0
  29. ista_daslab_optimizers-1.1.8.dist-info/top_level.txt +1 -0
ista_daslab_optimizers/fft_low_rank/trion.py @@ -0,0 +1,242 @@
+ import torch
+ import torch.distributed as dist
+ import math
+ import wandb
+
+ from ista_daslab_optimizers.utils.dct import dct3_matrix, dct_type2_makhoul
+ from ista_daslab_optimizers.utils.global_cache import GlobalCache
+ from ista_daslab_optimizers.utils.newton_schulz_triton import newton_schulz_triton
+
+ def adam_update(grad, buf1, buf2, step, betas, eps):
+     buf1.lerp_(grad, 1 - betas[0])
+     buf2.lerp_(grad.square(), 1 - betas[1])
+     buf1c = buf1 / (1 - betas[0] ** step)
+     buf2c = buf2 / (1 - betas[1] ** step)
+     return buf1c / (buf2c.sqrt() + eps)
+
+ def zeropower_via_newtonschulz5(G, steps: int):
+     assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+     a, b, c = (3.4445, -4.7750, 2.0315)
+     X = G.bfloat16()
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+
+     # Ensure spectral norm is at most 1
+     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+     # Perform the NS iterations
+     for _ in range(steps):
+         A = X @ X.mT
+         B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+         X = a * X + B @ X
+
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+     return X
+
+ def trion_update(G, M, rank, ns_type, mu, use_makhoul, ns_steps, out_ortho, out_indices):
+     # formerly called dct_dion_low_rank_muon_update
+     M.add_(G)
+
+     R, C = G.shape
+     is_right_proj = (R >= C)
+     # DCT = get_dct_matrix(size=min(R, C), key=(min(R, C), rank), device=G.device, dtype=G.dtype)
+
+     size = min(R, C)
+     key = (size, rank)
+     if GlobalCache.contains(category='ortho', key=key):
+         DCT = GlobalCache.get(category='ortho', key=key)
+     else:
+         DCT = dct3_matrix(min(R, C), device=G.device, dtype=G.dtype)
+         GlobalCache.add(category='ortho', key=key, item=DCT.T if use_makhoul else DCT)
+
+     if use_makhoul:
+         if is_right_proj:  # R >= C, I have to flip it for Makhoul
+             inputM = M
+         else:
+             inputM = M.T
+         # force the input to have more columns than rows for Makhoul
+         # fatM = M if R <= C else M.T  # input fat/wide matrix to Makhoul: R < C
+         S = dct_type2_makhoul(inputM)
+
+         if is_right_proj:
+             norms = S.norm(p=1, dim=0)  # dim = 0 computes norm of columns (over all rows) to rank columns
+         else:
+             ### case 1: transpose S to be able to use dim=1
+             S = S.T  # account for the transposition in inputM because Makhoul computes DCT per rows by default
+             norms = S.norm(p=1, dim=1)  # dim = 1 computes norm of rows (over all columns) to rank rows
+
+             ### case 2: to avoid transposing S, use dim=0 instead of dim=1 and it should be the same
+             # norms = S.norm(p=1, dim=0)  # dim = 0 computes norm of columns (over all rows) to rank columns
+     else:  # use matmul
+         # ranking: compute similarities
+         if is_right_proj:
+             S = M @ DCT  # (R, C) @ (C, C) = (R, C)
+             norms = S.norm(p=1, dim=0)  # dim = 0 computes norm of columns (over all rows)
+         else:
+             S = DCT @ M  # (R, R) @ (R, C) = (R, C)
+             norms = S.norm(p=1, dim=1)
+
+     # ranking: determine indices of most significant rows/columns
+     indices = torch.topk(input=norms, k=rank, sorted=False).indices
+
+     # create Q_r
+     if is_right_proj:
+         Q = DCT[:, indices]  # (C, r)
+         m = S[:, indices]  # (R, r)
+         M.add_(m @ Q.T, alpha=-(1 - mu))
+     else:
+         Q = DCT[indices, :]  # (r, R)
+         m = S[indices, :]  # (r, C)
+         M.add_(Q.T @ m, alpha=-(1 - mu))
+
+     if ns_type == 'torch':
+         ortho_m = zeropower_via_newtonschulz5(m, steps=ns_steps).to(dtype=M.dtype)
+     elif ns_type == 'triton':
+         ortho_m = newton_schulz_triton(m).to(dtype=M.dtype)
+     else:
+         raise RuntimeError(f'Unknown ns_type: {ns_type}')
+
+     out_ortho.copy_(ortho_m)
+     out_indices.copy_(indices)
+
+ class Trion(torch.optim.Optimizer):
+     def __init__(self, param_groups):
+         for group in param_groups:
+             assert "use_muon" in group
+             if group["use_muon"]:
+
+                 group["params"] = sorted(group["params"], key=lambda x: x.size(), reverse=True)
+                 # defaults
+                 group["lr"] = group.get("lr", 0.02)
+                 group["momentum"] = group.get("momentum", 0.95)
+                 group["weight_decay"] = group.get("weight_decay", 0)
+                 group["step"] = 0
+                 # assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
+             else:
+                 # defaults
+                 group["lr"] = group.get("lr", 3e-4)
+                 group["betas"] = group.get("betas", (0.9, 0.95))
+                 group["eps"] = group.get("eps", 1e-10)
+                 group["weight_decay"] = group.get("weight_decay", 0)
+                 # assert set(group.keys()) == set(["params", "lr", "betas", "eps", "weight_decay", "use_muon"])
+         super().__init__(param_groups, dict())
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         muon_param_index = 0
+         for group in self.param_groups:
+             if group["use_muon"]:
+                 group["step"] += 1
+
+                 params = group["params"]
+
+                 if ('lowrank_updates' not in group) and ('lowrank_indices' not in group):
+                     group["lowrank_updates"] = []
+                     group["lowrank_indices"] = []
+                     for p in params:
+                         R, C = p.shape
+                         is_right_proj = (R >= C)
+                         o_shape = (R, group["rank"]) if is_right_proj else (group["rank"], C)
+                         group["lowrank_updates"].append(torch.zeros(o_shape, dtype=p.dtype, device=p.device))
+                         group["lowrank_indices"].append(torch.zeros(group["rank"], dtype=torch.int32, device=p.device))
+                 lowrank_updates = group["lowrank_updates"]
+                 lowrank_indices = group["lowrank_indices"]
+
+                 pad_size = dist.get_world_size() - len(params) % dist.get_world_size()
+                 # params_pad = params + [torch.empty_like(params[-1])] * pad_size
+                 lowrank_updates_pad = lowrank_updates + [torch.empty_like(lowrank_updates[-1])] * pad_size
+                 lowrank_indices_pad = lowrank_indices + [torch.empty_like(lowrank_indices[-1])] * pad_size
+
+                 ##### compute low-rank updates only on one GPU
+                 for pi in range(len(params))[::dist.get_world_size()]:
+                     idx = pi + dist.get_rank()
+                     if idx < len(params):
+                         p = params[idx]
+                         lowrank_u = lowrank_updates[idx]  # low-rank update
+                         lowrank_idx = lowrank_indices[idx]  # row/column indices
+
+                         if p.grad is None:
+                             # continue
+                             p.grad = torch.zeros_like(p)  # Force synchronization
+
+                         state = self.state[p]
+
+                         if len(state) == 0:
+                             state["momentum_buffer"] = torch.zeros_like(p)
+                             state["param_id"] = muon_param_index
+                             muon_param_index += 1
+
+                         trion_update(
+                             G=p.grad,
+                             M=state["momentum_buffer"],
+                             rank=group["rank"],
+                             ns_type=group["ns_type"],
+                             mu=group["momentum"],
+                             use_makhoul=group.get("use_makhoul", False),
+                             ns_steps=5,
+                             out_ortho=lowrank_u,
+                             out_indices=lowrank_idx)
+                     # end if
+
+                     # all-gather for low-rank updates
+                     dist.all_gather(
+                         tensor_list=lowrank_updates_pad[pi:pi + dist.get_world_size()],
+                         tensor=lowrank_updates_pad[idx])
+
+                     # all-gather for row/column indices
+                     dist.all_gather(
+                         tensor_list=lowrank_indices_pad[pi:pi + dist.get_world_size()],
+                         tensor=lowrank_indices_pad[idx])
+                 # end for pi
+
+                 for pi in range(len(params)):
+                     p = params[pi]
+                     R, C = p.shape
+                     indices = lowrank_indices[pi]
+                     ot = lowrank_updates[pi]
+                     # DCT = get_dct_matrix(size=min(R, C), key=(min(R, C), group["rank"]), device=p.device, dtype=p.dtype)
+                     # get_dct_matrix is not defined in this module, so reuse the DCT basis cached by trion_update:
+                     key = (min(R, C), group["rank"])
+                     if GlobalCache.contains(category='ortho', key=key):
+                         DCT = GlobalCache.get(category='ortho', key=key)
+                     else:
+                         DCT = dct3_matrix(min(R, C), device=p.device, dtype=p.dtype)
+                         GlobalCache.add(category='ortho', key=key, item=DCT)
+
+                     is_right_proj = (R >= C)
+                     # print(f'R: {R}, C:{C}, Q: {tuple(Q.shape)}, ot: {ot.shape}')
+                     if is_right_proj:
+                         Q = DCT[:, indices]  # (C, r)
+                         update = ot @ Q.T  # (R, r) @ (r, C) = (R, C)
+                     else:
+                         Q = DCT[indices, :]  # (r, R)
+                         update = Q.T @ ot  # (R, r) @ (r, C) = (R, C)
+
+                     # R, C = p.shape
+                     scaling_type = group["scaling_type"]
+                     if scaling_type == 'kj':
+                         scaling = max(1, R / C) ** 0.5
+                     elif scaling_type == 'none':
+                         scaling = 1
+                     elif scaling_type == 'kimi':
+                         scaling = 0.2 * math.sqrt(max(R, C))
+                     elif scaling_type == 'dion':
+                         scaling = (R / C) ** 0.5
+                     else:
+                         raise RuntimeError(f'Unknown scaling_type: {scaling_type}')
+
+                     p.mul_(1 - group["lr"] * group["weight_decay"]).add_(update.reshape(p.shape), alpha=-group["lr"] * scaling)
+                 # end for pi
+             else:
+                 for p in group["params"]:
+                     if p.grad is None:
+                         # continue
+                         p.grad = torch.zeros_like(p)  # Force synchronization
+                     state = self.state[p]
+                     if len(state) == 0:
+                         state["exp_avg"] = torch.zeros_like(p)
+                         state["exp_avg_sq"] = torch.zeros_like(p)
+                         state["step"] = 0
+                     state["step"] += 1
+                     update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
+                                          state["step"], group["betas"], group["eps"])
+                     p.mul_(1 - group["lr"] * group["weight_decay"])
+                     p.add_(update, alpha=-group["lr"])
+
+         return loss
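For context, the following is a minimal single-process usage sketch of the Trion optimizer above; it is not part of the package. The group keys use_muon, rank, ns_type and scaling_type are the ones read by __init__() and step(), while the toy model, the rank value, the gloo process group and the import path ista_daslab_optimizers.fft_low_rank.trion are illustrative assumptions. A torch.distributed process group must be initialized even for a single process, because step() calls dist.all_gather.

import os
import torch
import torch.distributed as dist
from ista_daslab_optimizers.fft_low_rank.trion import Trion  # assumed import path

# single-process process group so that dist.get_world_size()/all_gather work
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

model = torch.nn.Sequential(torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, 10))
matrix_params = [p for p in model.parameters() if p.ndim == 2]  # 2-D weights -> low-rank DCT branch
other_params = [p for p in model.parameters() if p.ndim != 2]   # biases etc. -> Adam branch

optimizer = Trion([
    dict(params=matrix_params, use_muon=True, rank=8, ns_type='torch', scaling_type='kj'),
    dict(params=other_params, use_muon=False),
])

x, y = torch.randn(16, 64), torch.randint(0, 10, (16,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
dist.destroy_process_group()

Note that rank must not exceed min(R, C) of any 2-D parameter in the use_muon group, since trion_update selects the top-rank DCT components of a min(R, C)-length score vector with torch.topk.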
ista_daslab_optimizers/ista_optimizer/__init__.py @@ -0,0 +1,5 @@
+ from .ista_optimizer import ISTAOptimizer
+
+ __all__ = [
+     'ISTAOptimizer'
+ ]
ista_daslab_optimizers/ista_optimizer/ista_optimizer.py @@ -0,0 +1,36 @@
+ import torch
+
+ class ISTAOptimizer(torch.optim.Optimizer):
+     def __init__(self, params, lr, weight_decay):
+         super().__init__(params, dict(lr=lr, weight_decay=weight_decay))
+         self.lr = lr
+         self.weight_decay = weight_decay
+         self.optim_steps = 0
+
+     def loop_params(self, check_grad=True):
+         for group in self.param_groups:
+             for p in group['params']:
+                 if check_grad:
+                     if p.grad is None: continue
+                 yield group, self.state[p], p
+
+     @torch.no_grad()
+     def init_optimizer_states(self):
+         raise NotImplementedError
+
+     @torch.no_grad()
+     def optimizer_step(self):
+         raise NotImplementedError
+
+     @torch.no_grad()
+     def step(self, closure=None):
+         self.optim_steps += 1
+
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         self.optimizer_step()
+
+         return loss
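ISTAOptimizer is an abstract template: the base step() handles the closure and the optim_steps counter, loop_params() yields (group, state, param) triples for parameters that have gradients, and subclasses are expected to implement init_optimizer_states() and optimizer_step(). As a hypothetical illustration (not part of the package), a subclass implementing plain SGD with decoupled weight decay would only need the two overrides:

import torch
from ista_daslab_optimizers.ista_optimizer import ISTAOptimizer

class PlainSGD(ISTAOptimizer):
    # hypothetical example subclass, for illustration only

    @torch.no_grad()
    def init_optimizer_states(self):
        pass  # vanilla SGD keeps no per-parameter state

    @torch.no_grad()
    def optimizer_step(self):
        for group, state, p in self.loop_params():
            p.mul_(1 - group['lr'] * group['weight_decay'])  # decoupled weight decay
            p.add_(p.grad, alpha=-group['lr'])

# usage: opt = PlainSGD(model.parameters(), lr=0.1, weight_decay=0.01); loss.backward(); opt.step()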
ista_daslab_optimizers/micro_adam/__init__.py @@ -0,0 +1,5 @@
+ from .micro_adam import MicroAdam
+
+ __all__ = [
+     'MicroAdam',
+ ]