ista-daslab-optimizers 1.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ista_daslab_optimizers/__init__.py +6 -0
- ista_daslab_optimizers/acdc/__init__.py +5 -0
- ista_daslab_optimizers/acdc/acdc.py +387 -0
- ista_daslab_optimizers/acdc/wd_scheduler.py +31 -0
- ista_daslab_optimizers/dense_mfac/__init__.py +5 -0
- ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +164 -0
- ista_daslab_optimizers/dense_mfac/dense_mfac.py +93 -0
- ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
- ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
- ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
- ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
- ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
- ista_daslab_optimizers/micro_adam/__init__.py +5 -0
- ista_daslab_optimizers/micro_adam/micro_adam.py +402 -0
- ista_daslab_optimizers/sparse_mfac/__init__.py +7 -0
- ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +226 -0
- ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +87 -0
- ista_daslab_optimizers/tools.py +218 -0
- ista_daslab_optimizers/utils/dct.py +45 -0
- ista_daslab_optimizers/utils/global_cache.py +45 -0
- ista_daslab_optimizers/utils/matrix_storage.py +58 -0
- ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
- ista_daslab_optimizers/utils/quantizers.py +71 -0
- ista_daslab_optimizers/utils/schedulers.py +41 -0
- ista_daslab_optimizers-1.1.8.dist-info/METADATA +333 -0
- ista_daslab_optimizers-1.1.8.dist-info/RECORD +29 -0
- ista_daslab_optimizers-1.1.8.dist-info/WHEEL +5 -0
- ista_daslab_optimizers-1.1.8.dist-info/licenses/LICENSE +201 -0
- ista_daslab_optimizers-1.1.8.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _get_autotune_configs():
    """
    Build the autotuning search space shared by the Newton-Schulz Triton kernels.

    Enumerates every combination of M/N/K block sizes and (num_stages, num_warps)
    pairs, dropping tile shapes whose M/N aspect ratio (by integer division)
    exceeds 2. Every config fixes GROUP_SIZE_M=8 and LOWER_UPPER=1.

    :return: list of triton.Config objects
    """
    block_m_choices = (64, 128)
    block_n_choices = (64, 128, 256)
    block_k_choices = (64, 128)
    pipeline_choices = ((3, 4), (3, 8), (4, 4))

    configs = []
    for block_m in block_m_choices:
        for block_n in block_n_choices:
            # Skip strongly rectangular tiles; the condition only depends on
            # (block_m, block_n), so checking it here keeps the original ordering.
            if block_m // block_n > 2 or block_n // block_m > 2:
                continue
            for block_k in block_k_choices:
                for num_stages, num_warps in pipeline_choices:
                    meta = {
                        "BLOCK_SIZE_M": block_m,
                        "BLOCK_SIZE_N": block_n,
                        "BLOCK_SIZE_K": block_k,
                        "GROUP_SIZE_M": 8,
                        "LOWER_UPPER": 1,
                    }
                    configs.append(
                        triton.Config(meta, num_stages=num_stages, num_warps=num_warps)
                    )
    return configs
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@triton.jit
def _pid_to_block(
    pid,
    M,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """
    Helper function to map Triton program ID to (batch, row, col) of the output matrix.

    Each batch element's (M, M) output is tiled into blocks of
    (BLOCK_SIZE_M, BLOCK_SIZE_N). Program IDs are laid out batch-major: the
    first num_pid_m * num_pid_n IDs belong to batch element 0, the next ones
    to batch element 1, and so on. Within one batch element the tile
    coordinates are reordered by tl.swizzle2d in groups of GROUP_SIZE_M rows
    (presumably for cache locality -- not measured here).

    Returns:
        batch_idx: index of the matrix within the batch.
        m_idx: first row of the output tile.
        n_idx: first column of the output tile.
    """
    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)

    # Map PID to a single matrix in batch
    batch_idx = pid // (num_pid_m * num_pid_n)
    pid = pid % (num_pid_m * num_pid_n)

    # Map PID to 2D grid of blocks (row-major), then regroup with swizzle2d
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)

    # Convert tile coordinates to element offsets
    m_idx = pid_m * BLOCK_SIZE_M
    n_idx = pid_n * BLOCK_SIZE_N

    return batch_idx, m_idx, n_idx
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@triton.autotune(
    configs=_get_autotune_configs(),
    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
)
@triton.jit
def ns_line_1_kernel(
    A_ptr,
    C_ptr,
    M,
    K,
    a_stride_b,
    a_stride_r,
    a_stride_c,
    c_stride_b,
    c_stride_r,
    c_stride_c,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    LOWER_UPPER: tl.constexpr,
):
    """
    Input A has shape (M, K)
    Output C has shape (M, M)
    Compute C = A @ A.T

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C for one
    batch element. Because C is symmetric, only tiles on one side of the
    diagonal are computed: the lower side when LOWER_UPPER != 0, the upper
    side when LOWER_UPPER == 0. Each computed tile is stored both in place
    and transposed across the diagonal.

    *_stride_b are batch strides (ns_line_1 passes 0 for unbatched tensors);
    *_stride_r / *_stride_c are row / column strides.
    """

    pid = tl.program_id(axis=0)
    batch_idx, m_idx, n_idx = _pid_to_block(
        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
    )

    # Skip blocks that don't need to be computed: tiles lying entirely on the
    # other side of the diagonal are filled by another tile's mirrored store.
    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
    if skip_block_below_diag or skip_block_above_diag:
        return

    # Index into one matrix of batch
    A_ptr += batch_idx * a_stride_b
    C_ptr += batch_idx * c_stride_b

    # Create pointer arrays for A and A.T.
    # Row indices >= M wrap back into range so loads stay in bounds; results
    # for those wrapped rows are discarded by the store masks below.
    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
    # at_ptrs[k, n] addresses A[n, k], i.e. a (BLOCK_SIZE_K, BLOCK_SIZE_N) tile of A.T
    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)

    # Accumulate in float32 regardless of the input dtype
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Accumulate over blocks of K
    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Columns past K in the last (partial) K-block are loaded as 0
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, at, accumulator)
        # Both tiles advance along A's column (K) dimension
        a_ptrs += BLOCK_SIZE_K * a_stride_c
        at_ptrs += BLOCK_SIZE_K * a_stride_c

    # Cast back to C's element type before storing
    out_dtype = C_ptr.dtype.element_ty
    output = accumulator.to(out_dtype)

    # Store block of C
    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, output, mask=c_mask)

    # Store block of C mirrored across the diagonal.
    # Tiles straddling the diagonal overlap their own mirror; the overlapping
    # writes hold mathematically equal entries of the symmetric product.
    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def ns_line_1(A: Tensor, *, out: Tensor = None):
    """
    Launch Triton kernel to compute C = A @ A.T

    :param A: input of shape (M, K) or (batch, M, K)
    :param out: optional preallocated output of shape (M, M) or (batch, M, M);
        its leading (batch) dims must match A's. When omitted it is allocated
        with A's dtype and device.
    :return: the output tensor ``out``
    :raises ValueError: if A is not 2D/3D, or out's batch dims differ from A's
    :raises AssertionError: if out's trailing dims are not (M, M)
    """
    if A.ndim > 3 or A.ndim < 2:
        raise ValueError(f"Input tensor must be 2D or 3D, but got {A.ndim}D tensor.")

    M, K = A.shape[-2:]

    if out is None:
        out = torch.empty((*A.shape[:-1], M), device=A.device, dtype=A.dtype)
    assert out.size(-2) == M, "Output matrix has incorrect shape"
    assert out.size(-1) == M, "Output matrix has incorrect shape"
    # Without this check a 3D input with a 2D output would launch every batch
    # element against the same output matrix (batch stride 0), and a 2D input
    # with a 3D output would silently fill only out[0].
    if out.shape[:-2] != A.shape[:-2]:
        raise ValueError(
            f"Output batch shape {tuple(out.shape[:-2])} does not match "
            f"input batch shape {tuple(A.shape[:-2])}"
        )

    batch_size = A.size(0) if A.ndim == 3 else 1
    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
    output_batch_stride = out.stride(0) if out.ndim == 3 else 0

    # One program per (batch element, output tile)
    grid = lambda meta: (
        batch_size
        * triton.cdiv(M, meta["BLOCK_SIZE_M"])
        * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
    )
    ns_line_1_kernel[grid](
        A_ptr=A,
        C_ptr=out,
        M=M,
        K=K,
        a_stride_b=input_batch_stride,
        a_stride_r=A.stride(-2),
        a_stride_c=A.stride(-1),
        c_stride_b=output_batch_stride,
        c_stride_r=out.stride(-2),
        c_stride_c=out.stride(-1),
    )

    return out
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@triton.autotune(
    configs=_get_autotune_configs(),
    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
)
@triton.jit
def ns_line_2_kernel(
    A_ptr,
    C_ptr,
    M,
    a_stride_b,
    a_stride_r,
    a_stride_c,
    c_stride_b,
    c_stride_r,
    c_stride_c,
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    LOWER_UPPER: tl.constexpr,
):
    """
    Input A is square matrix with shape (M, M)
    Output C has shape (M, M)
    Compute C = alpha * A @ A.T + beta * A

    Tiling, diagonal skipping and mirrored stores follow ns_line_1_kernel: only
    tiles on one side of the diagonal are computed and each one is also written
    transposed.

    NOTE(review): the mirrored store is only correct when the result is
    symmetric, i.e. when A itself is symmetric; ns_line_2 only verifies that A
    is square.
    """

    pid = tl.program_id(axis=0)
    batch_idx, m_idx, n_idx = _pid_to_block(
        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
    )

    # Skip blocks that don't need to be computed: tiles lying entirely on the
    # other side of the diagonal are filled by another tile's mirrored store.
    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
    if skip_block_below_diag or skip_block_above_diag:
        return

    # Index into one matrix of batch
    A_ptr += batch_idx * a_stride_b
    C_ptr += batch_idx * c_stride_b

    # Create pointer arrays for A and A.T.
    # Row indices >= M wrap back into range so loads stay in bounds; results
    # for those wrapped rows are discarded by the store masks below.
    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
    # at_ptrs[k, n] addresses A[n, k], i.e. a (BLOCK_SIZE_K, BLOCK_SIZE_N) tile of A.T
    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)

    # Accumulate in float32 regardless of the input dtype
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Accumulate over blocks of K (the inner dimension is M since A is square)
    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
        # Columns past M in the last (partial) block are loaded as 0
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
        at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, at, accumulator)
        a_ptrs += BLOCK_SIZE_K * a_stride_c
        at_ptrs += BLOCK_SIZE_K * a_stride_c

    # Load block of A to add (corresponds to the current block of C)
    offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
    a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
    a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M)
    a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32)

    # Apply alpha and beta (in float32)
    accumulator *= alpha
    accumulator += a_add * beta

    # Cast back to C's element type before storing
    out_dtype = C_ptr.dtype.element_ty
    output = accumulator.to(out_dtype)

    # Store block of C
    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, output, mask=c_mask)

    # Store block of C mirrored across the diagonal
    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor = None):
    """
    Launch Triton kernel to compute C = alpha * A @ A.T + beta * A

    :param A: square (symmetric) input of shape (M, M) or (batch, M, M); only
        squareness is checked here
    :param alpha: scale of the A @ A.T term
    :param beta: scale of the A term
    :param out: optional preallocated output of shape (M, M) or (batch, M, M);
        its leading (batch) dims must match A's. When omitted it is allocated
        with A's dtype and device.
    :return: the output tensor ``out``
    :raises ValueError: if A is not 2D/3D, not square, or out's batch dims
        differ from A's
    :raises AssertionError: if out's trailing dims are not (M, M)
    """
    if A.ndim > 3 or A.ndim < 2:
        raise ValueError(f"Input tensor must be 2D or 3D, but got {A.ndim}D tensor.")

    M, K = A.shape[-2:]
    if M != K:
        raise ValueError(
            f"Input must be symmetric square matrix, but got shape {A.shape}"
        )

    if out is None:
        out = torch.empty((*A.shape[:-1], M), device=A.device, dtype=A.dtype)
    assert out.size(-2) == M, "Output matrix has incorrect shape"
    assert out.size(-1) == M, "Output matrix has incorrect shape"
    # Without this check a 3D input with a 2D output would launch every batch
    # element against the same output matrix (batch stride 0), and a 2D input
    # with a 3D output would silently fill only out[0].
    if out.shape[:-2] != A.shape[:-2]:
        raise ValueError(
            f"Output batch shape {tuple(out.shape[:-2])} does not match "
            f"input batch shape {tuple(A.shape[:-2])}"
        )

    batch_size = A.size(0) if A.ndim == 3 else 1
    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
    output_batch_stride = out.stride(0) if out.ndim == 3 else 0

    # One program per (batch element, output tile)
    grid = lambda meta: (
        batch_size
        * triton.cdiv(M, meta["BLOCK_SIZE_M"])
        * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
    )
    ns_line_2_kernel[grid](
        A_ptr=A,
        C_ptr=out,
        M=M,
        a_stride_b=input_batch_stride,
        a_stride_r=A.stride(-2),
        a_stride_c=A.stride(-1),
        c_stride_b=output_batch_stride,
        c_stride_r=out.stride(-2),
        c_stride_c=out.stride(-1),
        alpha=alpha,
        beta=beta,
    )

    return out
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@torch.compile(dynamic=False, fullgraph=True)
def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7):
    """
    Reference implementation of Newton-Schulz without Triton.

    Runs five quintic Newton-Schulz steps X <- a*X + (b*S + c*S@S) @ X, with
    S = X @ X.mT, in bfloat16 on the last two dims of G. Inputs with more rows
    than columns are processed transposed and transposed back at the end.

    :param G: input tensor of shape (..., m, n)
    :param epsilon: added to the Frobenius norm before normalizing
    :return: bfloat16 tensor with the same shape as G
    """
    # Per-step (a, b, c) coefficients of the quintic polynomial
    coefficients = (
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    )

    is_tall = G.size(-2) > G.size(-1)
    X = G.to(dtype=torch.bfloat16)
    if is_tall:
        X = X.mT

    # The Frobenius norm bounds the spectral norm, so this keeps it at most 1
    frobenius = X.norm(dim=(-2, -1), keepdim=True)
    X = X / (frobenius + epsilon)

    for a, b, c in coefficients:
        gram = X @ X.mT
        poly = b * gram + c * (gram @ gram)
        X = a * X + poly @ X

    return X.mT if is_tall else X
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
@torch.compile(dynamic=False, fullgraph=True)
def newton_schulz_triton(G: Tensor, epsilon: float = 1e-7):
    """
    Triton implementation of Newton-Schulz iteration

    Same iteration as the pure-PyTorch reference: five quintic steps
    X <- a*X + (b*A + c*A@A) @ X with A = X @ X.mT, in bfloat16. Here the two
    symmetric products use the Triton kernels (ns_line_1, ns_line_2), and the
    final update uses addmm/baddbmm writing into a preallocated buffer.

    :param G: input tensor of shape (..., m, n). The Triton launchers only accept
        2D/3D tensors, so any leading batch dims are flattened into one and
        restored at the end.
    :param epsilon: added to the Frobenius norm before normalizing
    :return: bfloat16 tensor with the same shape as G
    """
    # Newton-Schulz constants
    ns_consts = [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ]

    # Work on the wide orientation so the Gram matrix uses the smaller side
    transposed = G.size(-2) > G.size(-1)
    X = G.to(dtype=torch.bfloat16)
    if transposed:
        X = X.mT

    # Ensure spectral norm is at most 1 (Frobenius norm >= spectral norm)
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon)

    # The kernels need a contiguous 2D or 3D tensor: flatten extra batch dims
    X = X.contiguous()
    batch_shape = X.shape[:-2]
    if X.ndim > 3:
        X = X.reshape(-1, X.size(-2), X.size(-1))

    # Allocate buffers
    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
    B = torch.empty_like(A)
    C = torch.empty_like(X)

    ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm

    # Perform the NS iterations
    for a, b, c in ns_consts:
        ns_line_1(X, out=A)  # A = X @ X.mT
        ns_line_2(A, alpha=c, beta=b, out=B)  # B = b * A + c * A @ A
        ns_line_3(X, B, X, beta=a, out=C)  # C = a * X + B @ X
        X, C = C, X  # Swap references to avoid unnecessary copies

    # Restore the original leading batch dims
    if len(batch_shape) > 1:
        X = X.reshape(*batch_shape, X.size(-2), X.size(-1))

    if transposed:
        X = X.mT
    return X
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
class Quantizer4bit:
    """
    Bucketed min-max quantizer storing each value in 4 bits (two values per byte).

    The flattened input is split into ``n_buckets`` contiguous buckets of
    ``bucket_size`` elements. Each bucket keeps its own minimum and maximum
    (stored in ``dtype``). Every element is rounded to the nearest of 16 evenly
    spaced levels between them.
    """

    def __init__(self, shape, device, dtype, bucket_size):
        """
        :param shape: shape of the tensors that will be quantized
        :param device: device holding the quantized state
        :param dtype: dtype of the per-bucket min/max statistics
        :param bucket_size: number of consecutive elements sharing one min/max pair
        """
        numel = int(np.prod(shape))
        assert numel % bucket_size == 0
        # Two 4-bit codes are packed per byte, so the element count must be even
        assert numel % 2 == 0
        self.shape = shape
        self.device = device
        self.bucket_size = bucket_size
        self.numel = numel

        self.n_buckets = self.numel // self.bucket_size

        self.xq = torch.zeros(self.numel // 2, dtype=torch.uint8, device=self.device)
        self.x_min = torch.zeros(self.n_buckets, 1, dtype=dtype, device=self.device)
        self.x_max = torch.zeros(self.n_buckets, 1, dtype=dtype, device=self.device)

    def quantize(self, x):
        """
        Quantize ``x`` into the internal buffers (``xq``, ``x_min``, ``x_max``).

        :param x: tensor with ``numel`` elements; reshaped to (n_buckets, bucket_size)
        """
        N, B = self.n_buckets, self.bucket_size
        buckets = x.reshape(N, B)
        self.x_min.copy_(buckets.min(dim=1, keepdim=True).values)
        self.x_max.copy_(buckets.max(dim=1, keepdim=True).values)
        u = (self.x_max - self.x_min) / 15.
        # A constant bucket has zero range; a unit step avoids 0/0 -> NaN (whose
        # uint8 cast is undefined). Such buckets dequantize to x_min anyway.
        u_safe = torch.where(u == 0, torch.ones_like(u), u)
        # Clamp so rounding noise can never produce a code outside [0, 15], which
        # would corrupt the neighbouring nibble when packing.
        codes = ((buckets - self.x_min) / u_safe + 0.5).floor().clamp(0, 15)
        codes = codes.to(torch.uint8).view(-1, 2)
        # Element 2i goes to the high nibble, element 2i+1 to the low nibble
        self.xq.copy_((codes[:, 0] << 4) | codes[:, 1])

    def quantize_inv(self):
        """
        Dequantize the internal buffers.

        :return: tensor of shape ``shape`` in the dtype of the min/max statistics
        """
        N, B = self.n_buckets, self.bucket_size
        u = (self.x_max - self.x_min) / 15.
        high = (self.xq & 0xF0) >> 4
        low = self.xq & 0x0F
        # Interleave back into the original element order: (high_0, low_0, high_1, ...).
        # Stacking along dim=1 interleaves; hstack of 1-D tensors would concatenate.
        codes = torch.stack((high, low), dim=1).view(N, B)
        x = codes * u + self.x_min
        return x.view(*self.shape)
|
|
42
|
+
|
|
43
|
+
class Quantizer8bit:
    """
    Bucketed min-max quantizer storing each value in one byte.

    The flattened input is split into ``n_buckets`` contiguous buckets of
    ``bucket_size`` elements. Each bucket keeps its own minimum and maximum
    (stored in ``dtype``). Every element is rounded to the nearest of 256 evenly
    spaced levels between them.
    """

    # Number of quantization steps between x_min and x_max (256 levels).
    # Previously 15, which wasted half of every byte.
    _MAX_CODE = 255

    def __init__(self, shape, device, dtype, bucket_size):
        """
        :param shape: shape of the tensors that will be quantized
        :param device: device holding the quantized state
        :param dtype: dtype of the per-bucket min/max statistics
        :param bucket_size: number of consecutive elements sharing one min/max pair
        """
        numel = int(np.prod(shape))
        assert numel % bucket_size == 0
        self.shape = shape
        self.device = device
        self.bucket_size = bucket_size
        self.numel = numel

        self.n_buckets = self.numel // self.bucket_size

        self.xq = torch.zeros(self.numel, dtype=torch.uint8, device=self.device)
        self.x_min = torch.zeros(self.n_buckets, 1, dtype=dtype, device=self.device)
        self.x_max = torch.zeros(self.n_buckets, 1, dtype=dtype, device=self.device)

    def quantize(self, x):
        """
        Quantize ``x`` into the internal buffers (``xq``, ``x_min``, ``x_max``).

        :param x: tensor with ``numel`` elements; reshaped to (n_buckets, bucket_size)
        """
        N, B = self.n_buckets, self.bucket_size
        buckets = x.reshape(N, B)
        self.x_min.copy_(buckets.min(dim=1, keepdim=True).values)
        self.x_max.copy_(buckets.max(dim=1, keepdim=True).values)
        u = (self.x_max - self.x_min) / float(self._MAX_CODE)
        # A constant bucket has zero range; a unit step avoids 0/0 -> NaN (whose
        # uint8 cast is undefined). Such buckets dequantize to x_min anyway.
        u_safe = torch.where(u == 0, torch.ones_like(u), u)
        # Clamp so rounding noise can never push a code outside the uint8 range
        codes = ((buckets - self.x_min) / u_safe + 0.5).floor().clamp(0, self._MAX_CODE)
        self.xq.copy_(codes.to(torch.uint8).view(-1))

    def quantize_inv(self):
        """
        Dequantize the internal buffers.

        :return: tensor of shape ``shape`` in the dtype of the min/max statistics
        """
        N, B = self.n_buckets, self.bucket_size
        u = (self.x_max - self.x_min) / float(self._MAX_CODE)
        x = self.xq.view(N, B) * u + self.x_min
        return x.view(*self.shape)
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
def ema_standard_schedule(m, g, beta):
    """
    Update ``m`` in place with a plain exponential moving average.

    Computes m <- beta * m + (1 - beta) * g via ``Tensor.lerp_``.

    :param m: momentum buffer, modified in place
    :param g: gradient
    :param beta: EMA coefficient (weight kept on the old momentum)
    """
    weight_on_gradient = 1 - beta
    m.lerp_(g, weight_on_gradient)
|
|
11
|
+
|
|
12
|
+
def ema_delayed_decay_schedule(m, g, beta, beta_prev, t, T_decay, alpha):
    """
    EMA whose gradient weight restarts periodically and then decays slowly.

    Proposed by Mher Safaryan in June 2025 while a postdoc @ ISTA. Writing w_t
    for the weight placed on the gradient at step t (the value returned here):

        if t == 1 or t % T_decay == 0:      # restart step
            w_t = 1 - beta
        else:
            w_t = (alpha + w_{t-1}) / (1 + alpha + w_{t-1})
        m_t = (1 - w_t) * m_{t-1} + w_t * g

    :param m: momentum buffer, updated in place
    :param g: gradient buffer
    :param beta: EMA coefficient (weight on m) used on restart steps
    :param beta_prev: previous gradient weight (w_{t-1} above)
    :param t: current step counter
    :param T_decay: decay interval
    :param alpha: slope (use values between 0.001 and 0.007)
    :return: w_t, the gradient weight used at this step
    """
    is_restart = t == 1 or t % T_decay == 0
    if is_restart:
        ema_standard_schedule(m, g, beta)
        return 1 - beta

    grad_weight = (alpha + beta_prev) / (1 + alpha + beta_prev)
    # ema_standard_schedule takes the weight on m, i.e. 1 - w_t
    ema_standard_schedule(m, g, 1 - grad_weight)
    return grad_weight
|