turboadam 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- turboadam/__init__.py +10 -0
- turboadam/costate.py +464 -0
- turboadam/oneq.py +77 -0
- turboadam/optimizer.py +292 -0
- turboadam/quantize.py +299 -0
- turboadam/triton_kernels.py +516 -0
- turboadam/utils.py +66 -0
- turboadam-0.1.0.dist-info/METADATA +272 -0
- turboadam-0.1.0.dist-info/RECORD +12 -0
- turboadam-0.1.0.dist-info/WHEEL +5 -0
- turboadam-0.1.0.dist-info/licenses/LICENSE +21 -0
- turboadam-0.1.0.dist-info/top_level.txt +1 -0
turboadam/__init__.py
ADDED
turboadam/costate.py
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
"""CoState — first moment (m) compression.
|
|
2
|
+
|
|
3
|
+
Gradient-residual decomposition: m = α·g + δ
|
|
4
|
+
- α = (m·g) / (g·g) — scalar per parameter tensor
|
|
5
|
+
- δ = m - α·g — residual orthogonal to current gradient
|
|
6
|
+
|
|
7
|
+
Residual δ is partitioned into 128-element blocks and classified:
|
|
8
|
+
- Null costate (r < τ₀): store 1 bit in bitmap
|
|
9
|
+
- Phase costate (τ₀ ≤ r < τ₁): store 1-bit sign per element
|
|
10
|
+
- Amplitude costate (r ≥ τ₁): store 1-bit sign + fp16 block scale
|
|
11
|
+
|
|
12
|
+
Adaptive thresholds: τ₀ = P_10(r), τ₁ = P_90(r) per parameter tensor per step.
|
|
13
|
+
No warmup required — EMA error-washing handles cold-start.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from turboadam.utils import pad_to_blocks, BLOCK_SIZE
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def decompose(m: torch.Tensor, g: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Split the momentum into a gradient-parallel part and a residual.

    The momentum is written as ``m = alpha * g + delta`` where
    ``alpha = <m, g> / <g, g>`` is the least-squares projection coefficient
    computed over the flattened tensors.

    Args:
        m: First moment tensor; any shape, treated as a flat vector.
        g: Gradient tensor with the same shape as ``m``.

    Returns:
        ``(alpha, delta)``: ``alpha`` is a 0-dim tensor (0 when ``<g, g>`` is
        not positive, e.g. an all-zero gradient) and ``delta = m - alpha * g``.
    """
    flat_g = g.reshape(-1)
    sq_norm_g = flat_g.dot(flat_g)
    projection = m.reshape(-1).dot(flat_g) / sq_norm_g
    # Select the zero fallback on-device; no .item() host sync is needed.
    alpha = torch.where(sq_norm_g > 0, projection, sq_norm_g.new_zeros(()))
    return alpha, m - alpha * g
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# ---------------------------------------------------------------------------
|
|
47
|
+
# Block ratio computation
|
|
48
|
+
# ---------------------------------------------------------------------------
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def compute_block_ratios(
    delta: torch.Tensor,
    m: torch.Tensor,
    block_size: int = BLOCK_SIZE,
) -> torch.Tensor:
    """Return the per-block ratio ``||delta_block|| / ||m_block||``.

    Both tensors are flattened, cast to fp32 and padded with ``pad_to_blocks``
    to a whole number of ``block_size`` chunks before the L2 norms are taken.

    Args:
        delta: Residual tensor (any shape).
        m: First moment tensor with the same number of elements as ``delta``.
        block_size: Elements per block (default: BLOCK_SIZE).

    Returns:
        1-D float32 tensor with one ratio per block. Blocks whose ``m`` norm is
        exactly zero get ratio 0.
    """
    residual_padded, _ = pad_to_blocks(delta.reshape(-1).float(), block_size)
    moment_padded, _ = pad_to_blocks(m.reshape(-1).float(), block_size)

    # Explicit block count: reshape(-1, block_size) is ambiguous for 0 elements.
    n_blocks = residual_padded.shape[0] // block_size
    residual_norms = residual_padded.reshape(n_blocks, block_size).norm(dim=1)
    moment_norms = moment_padded.reshape(n_blocks, block_size).norm(dim=1)

    is_zero = moment_norms == 0.0
    # Divide by 1 where the moment norm vanishes, then force those ratios to 0.
    denom = torch.where(is_zero, torch.ones_like(moment_norms), moment_norms)
    return torch.where(
        is_zero, torch.zeros_like(residual_norms), residual_norms / denom
    )
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# ---------------------------------------------------------------------------
|
|
90
|
+
# Threshold computation
|
|
91
|
+
# ---------------------------------------------------------------------------
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def compute_thresholds(
    ratios: torch.Tensor,
    null_pct: float = 0.10,
    amp_pct: float = 0.90,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute adaptive thresholds as order statistics of the block ratios.

    With ``k_lo = floor(null_pct * n)`` and ``k_hi = floor(amp_pct * n)``
    (clamped to valid indices), ``tau0 = sorted(ratios)[k_lo]`` and
    ``tau1 = sorted(ratios)[k_hi]``. For distinct ratios this places exactly
    ``k_lo`` blocks strictly below ``tau0`` (null) and ``n - k_hi`` blocks at or
    above ``tau1`` (amplitude). The defaults therefore give ~10% null, ~80%
    phase and ~10% amplitude.

    Uses a sort plus direct indexing instead of ``torch.quantile``. Results
    stay on-device as 0-dim tensors (no ``.item()`` sync).

    Args:
        ratios: 1-D float tensor of per-block ratios.
        null_pct: Fraction for the null/phase boundary. Default: 0.10 (P10).
        amp_pct: Fraction for the phase/amplitude boundary. Default: 0.90 (P90).

    Returns:
        (tau0, tau1) as 0-dim tensors with the dtype/device of ``ratios``.
        For an empty ``ratios`` both thresholds are zero.
    """
    n = ratios.shape[0]
    if n == 0:
        # Nothing to classify; avoid indexing into an empty tensor.
        return ratios.new_zeros(()), ratios.new_zeros(())

    sorted_r = ratios.sort().values
    # floor(pct * n) for BOTH indices keeps the null and amplitude classes
    # symmetric. Subtracting 1 from the low index would shrink the null class by
    # one block (and leave it empty for n = 10 at the default P10).
    idx_lo = min(n - 1, max(0, int(null_pct * n)))
    idx_hi = min(n - 1, max(0, int(amp_pct * n)))
    return sorted_r[idx_lo], sorted_r[idx_hi]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# ---------------------------------------------------------------------------
|
|
122
|
+
# Block classification
|
|
123
|
+
# ---------------------------------------------------------------------------
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def classify_blocks(
    ratios: torch.Tensor,
    tau0: float,
    tau1: float,
) -> torch.Tensor:
    """Label every block with its costate based on its residual ratio.

    Labels (uint8):
        0 (null)      : r < tau0
        1 (phase)     : tau0 <= r < tau1
        2 (amplitude) : r >= tau1

    The amplitude test takes precedence: any ratio with ``r >= tau1`` is
    labelled 2, even if it is below ``tau0``. Comparisons with NaN are false,
    so NaN ratios get label 0.

    Args:
        ratios: 1-D float tensor of per-block ratios.
        tau0: Lower threshold (Python float or 0-dim tensor).
        tau1: Upper threshold (Python float or 0-dim tensor).

    Returns:
        uint8 tensor of labels, same length as ratios.
    """
    null = torch.zeros_like(ratios, dtype=torch.uint8)
    phase = torch.ones_like(null)
    amplitude = torch.full_like(null, 2)
    return torch.where(
        ratios >= tau1,
        amplitude,
        torch.where(ratios >= tau0, phase, null),
    )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# ---------------------------------------------------------------------------
|
|
153
|
+
# Encoding and decoding
|
|
154
|
+
# ---------------------------------------------------------------------------
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _pack_signs(values: torch.Tensor) -> torch.Tensor:
|
|
158
|
+
"""Pack sign bits (1 if negative, 0 if non-negative) into uint8, 8 bits/byte.
|
|
159
|
+
|
|
160
|
+
Args:
|
|
161
|
+
values: 1-D float tensor of arbitrary length.
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
uint8 tensor of length ceil(len(values) / 8).
|
|
165
|
+
"""
|
|
166
|
+
n = values.shape[0]
|
|
167
|
+
pad = (8 - n % 8) % 8
|
|
168
|
+
if pad > 0:
|
|
169
|
+
values = torch.cat([values, values.new_zeros(pad)])
|
|
170
|
+
sign_bits = (values < 0).to(torch.uint8).reshape(-1, 8)
|
|
171
|
+
# Pack via bitwise shifts (avoids creating multiplier tensor each call)
|
|
172
|
+
packed = (
|
|
173
|
+
(sign_bits[:, 0] << 7)
|
|
174
|
+
| (sign_bits[:, 1] << 6)
|
|
175
|
+
| (sign_bits[:, 2] << 5)
|
|
176
|
+
| (sign_bits[:, 3] << 4)
|
|
177
|
+
| (sign_bits[:, 4] << 3)
|
|
178
|
+
| (sign_bits[:, 5] << 2)
|
|
179
|
+
| (sign_bits[:, 6] << 1)
|
|
180
|
+
| sign_bits[:, 7]
|
|
181
|
+
)
|
|
182
|
+
return packed
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _unpack_signs(packed: torch.Tensor, n: int) -> torch.Tensor:
|
|
186
|
+
"""Unpack sign bits from uint8 bytes back to +1/-1 float values.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
packed: uint8 tensor of packed sign bytes.
|
|
190
|
+
n: Number of elements to unpack (may be less than len(packed)*8).
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
float32 tensor of length n with values +1 or -1.
|
|
194
|
+
"""
|
|
195
|
+
# Ensure uint8 (PyTorch load_state_dict _cast may change dtype).
|
|
196
|
+
# Use vectorized bit extraction — no Python loop, stays on the original device.
|
|
197
|
+
# MPS doesn't support integer bitwise ops, so fall back to CPU for MPS only.
|
|
198
|
+
orig_device = packed.device
|
|
199
|
+
is_mps = orig_device.type == "mps"
|
|
200
|
+
work_device = torch.device("cpu") if is_mps else orig_device
|
|
201
|
+
|
|
202
|
+
packed_int = packed.to(dtype=torch.int32, device=work_device)
|
|
203
|
+
shifts = torch.tensor(
|
|
204
|
+
[7, 6, 5, 4, 3, 2, 1, 0], dtype=torch.int32, device=work_device
|
|
205
|
+
)
|
|
206
|
+
# (num_bytes, 8): extract all 8 bits per byte in one shot
|
|
207
|
+
bits = ((packed_int.unsqueeze(1) >> shifts.unsqueeze(0)) & 1).float()
|
|
208
|
+
bits = bits.reshape(-1)[:n]
|
|
209
|
+
signs = 1.0 - 2.0 * bits
|
|
210
|
+
return signs.to(orig_device)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def encode_blocks(
    delta: torch.Tensor,
    labels: torch.Tensor,
    block_size: int = BLOCK_SIZE,
) -> dict:
    """Pack a residual tensor into the per-block CoState representation.

    Args:
        delta: Residual tensor (any shape); flattened and cast to fp32.
        labels: uint8 costate labels, one per block (stored unchanged).
        block_size: Elements per block.

    Returns:
        dict with:
            labels (uint8)       : the ``labels`` argument, unchanged
            sign_packed (uint8)  : packed sign bits of the un-padded elements,
                                   ceil(numel / 8) bytes
            block_norms (float32): L2 norm of each padded delta block
            scales (float16)     : block_norm / sqrt(block_size), the uniform
                                   per-element magnitude used when an amplitude
                                   block is decoded as ``scale * sign``
    """
    residual = delta.reshape(-1).float()
    padded, _ = pad_to_blocks(residual, block_size)
    n_blocks = padded.shape[0] // block_size

    norms = padded.reshape(n_blocks, block_size).norm(dim=1).float()
    # block_norm / sqrt(block_size) is the RMS magnitude of the block. Amplitude
    # decoding reads this fp16 copy, while phase decoding recomputes the same
    # value from the fp32 ``block_norms``.
    per_element = (norms / math.sqrt(block_size)).to(torch.float16)

    return {
        "labels": labels,
        "sign_packed": _pack_signs(residual),
        "block_norms": norms,
        "scales": per_element,
    }
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def decode_blocks(
    encoded: dict,
    alpha,
    g: torch.Tensor,
    block_size: int = BLOCK_SIZE,
    original_numel: int = None,
) -> torch.Tensor:
    """Reconstruct approximated m = alpha*g + delta_hat from encoded representation.

    Per-costate delta_hat reconstruction:
        Null (0)      : delta_hat_block = 0
        Phase (1)     : delta_hat_block = (norm(delta_block)/sqrt(block_size)) * sign(delta_block)
        Amplitude (2) : delta_hat_block = fp16_scale * sign(delta_block)
    Any other label value is treated like null.

    Args:
        encoded: dict returned by encode_blocks.
        alpha: Projection coefficient, either a Python float or a tensor that
            broadcasts against the flattened gradient (e.g. 0-dim or shape (1,)).
        g: Gradient tensor (same original shape as delta).
        block_size: Elements per block.
        original_numel: Number of elements in the original delta (before
            padding). Defaults to ``g.numel()`` when None.

    Returns:
        Reconstructed m tensor (float32) with the same shape as g.
    """
    g_flat = g.reshape(-1).float()
    if original_numel is None:
        original_numel = g_flat.shape[0]

    device = g_flat.device
    labels = encoded["labels"].to(dtype=torch.uint8, device=device)
    # Move the packed signs to g's device like every other field. _unpack_signs
    # returns its result on its input's device, so a mismatch here would break
    # the scale * sign product below.
    sign_packed = encoded["sign_packed"].to(device=device)
    block_norms = encoded["block_norms"].to(dtype=torch.float32, device=device)
    scales = encoded["scales"].to(dtype=torch.float32, device=device)
    num_blocks = labels.shape[0]

    # Unpack sign bits and reshape them to the block layout.
    signs_flat = _unpack_signs(sign_packed, original_numel)
    signs_padded, _ = pad_to_blocks(signs_flat, block_size)
    signs_blocks = signs_padded.reshape(num_blocks, block_size)

    # Per-block scale: null -> 0, phase -> fp32 norm / sqrt(bs), amp -> fp16 scale.
    phase_scales = block_norms / math.sqrt(block_size)
    block_scales = torch.where(
        labels == 2,
        scales,
        torch.where(labels == 1, phase_scales, torch.zeros_like(phase_scales)),
    )

    # (num_blocks, 1) * (num_blocks, block_size), then trim the padding.
    delta_hat = (block_scales.unsqueeze(1) * signs_blocks).reshape(-1)[:original_numel]
    result = alpha * g_flat + delta_hat
    return result.reshape(g.shape)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
# ---------------------------------------------------------------------------
|
|
328
|
+
# CoStateManager — stateful per-step update loop (spec section 4.5)
|
|
329
|
+
# ---------------------------------------------------------------------------
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
class CoStateManager:
    """Stateful manager for CoState first moment compression.

    Implements the per-step update procedure from spec section 4.5:
      1. Load compressed δ̂ and costate bitmap from memory (if prior state exists)
      2. Reconstruct m̃ = α · g + decompress(δ̂)   [skip on first call, use m̃ = 0]
      3. Compute EMA update: m_new = β₁ · m̃ + (1 - β₁) · g
      4. Compute new projection: α_new = (m_new · g) / (g · g)
      5. Compute new residual: δ_new = m_new - α_new · g
      6. Classify blocks into costates using adaptive thresholds
      7. Compress and store δ̂_new according to costate assignments
      8. Store updated costate bitmap and α_new

    With ``error_feedback=True``, the encoding error of each step
    (δ_new minus its decoded reconstruction) is kept as an EMA with decay β₁.
    That EMA is added to the next step's gradient before the EMA update.

    Usage:
        mgr = CoStateManager(block_size=128)
        m = mgr.update(g, beta1=0.9)  # call each optimizer step
    """

    # Keys of the encoded dict that are persisted between steps.
    _ENCODED_KEYS = ("labels", "sign_packed", "block_norms", "scales")

    # Result of the one-time probe for the optional Triton kernels, shared by
    # all managers. A failed import is not cached in sys.modules, so importing
    # inside update() would repeat the module search on every call.
    _triton_probed = False
    _triton_kernels = None  # (decode, encode) tuple once successfully probed

    def __init__(
        self,
        block_size: int = BLOCK_SIZE,
        error_feedback: bool = False,
        null_pct: float = 0.10,
        amp_pct: float = 0.90,
    ) -> None:
        self.block_size = block_size
        self._null_pct = null_pct
        self._amp_pct = amp_pct
        self._has_state: bool = False
        self._alpha = 0.0  # becomes a scalar tensor after first update
        self._encoded: dict | None = None
        self._original_numel: int = 0
        self._error_feedback = error_feedback
        self._ef_residual: torch.Tensor | None = None

    @classmethod
    def _triton_codec(cls):
        """Return the ``(decode, encode)`` Triton kernels, or None if unavailable.

        ``turboadam.triton_kernels`` is imported at most once. An ImportError
        selects the pure-PyTorch path for the rest of the process.
        """
        if not cls._triton_probed:
            try:
                from turboadam.triton_kernels import (
                    triton_costate_decode,
                    triton_costate_encode,
                )
            except ImportError:
                kernels = None
            else:
                kernels = (triton_costate_decode, triton_costate_encode)
            cls._triton_kernels = kernels
            cls._triton_probed = True
        return cls._triton_kernels

    def update(self, g: torch.Tensor, beta1: float) -> torch.Tensor:
        """Run one step of the CoState update procedure.

        Args:
            g: Current gradient tensor (any shape).
            beta1: EMA decay for first moment (e.g. 0.9).

        Returns:
            m_new: Updated first moment tensor (float32), same shape as g.
        """
        # Cast the gradient to fp32. The CoState accumulators are fp32, and
        # fp16/bf16 gradients would cause dtype mismatches in the dot products.
        g = g.float()

        # Triton kernels are used only for CUDA tensors, and only if they import.
        kernels = self._triton_codec() if g.is_cuda else None
        if kernels is None:
            decode, encode = decode_blocks, encode_blocks
        else:
            decode, encode = kernels

        # Steps 1-2: reconstruct m̃ from the compressed prior state (zeros at first).
        if self._has_state:
            m_hat = decode(
                self._encoded,
                self._alpha,
                g,
                self.block_size,
                self._original_numel,
            )
        else:
            m_hat = torch.zeros_like(g, dtype=torch.float32)

        # Error feedback: compensate for the previous steps' encoding loss.
        if self._error_feedback and self._ef_residual is not None:
            g_corrected = g + self._ef_residual
        else:
            g_corrected = g

        # Step 3: EMA update.
        m_new = beta1 * m_hat + (1.0 - beta1) * g_corrected

        # Steps 4-6: decompose, compute block ratios, classify.
        alpha_new, delta_new = decompose(m_new, g)
        ratios = compute_block_ratios(delta_new, m_new, self.block_size)
        tau0, tau1 = compute_thresholds(ratios, self._null_pct, self._amp_pct)
        labels = classify_blocks(ratios, tau0, tau1)

        # Steps 7-8: compress and store.
        encoded_new = encode(delta_new, labels, self.block_size)
        if self._error_feedback:
            self._accumulate_encoding_error(decode, encoded_new, delta_new, g, beta1)
        self._store_state(alpha_new, encoded_new, m_new.numel())

        return m_new

    def _accumulate_encoding_error(self, decode, encoded, delta, g, beta1) -> None:
        """Fold this step's encoding error (delta - decoded delta) into the EF EMA."""
        zero_alpha = g.new_zeros(1)
        # With alpha = 0 the decoder returns 0 * g + δ̂, i.e. just the residual.
        delta_hat = decode(encoded, zero_alpha, g, self.block_size, delta.numel())
        ef_error = (delta - delta_hat).detach()
        if self._ef_residual is None:
            self._ef_residual = ef_error
        else:
            self._ef_residual = beta1 * self._ef_residual + (1.0 - beta1) * ef_error

    def _store_state(self, alpha_new, encoded_new: dict, numel: int) -> None:
        """Persist alpha and the encoded residual for the next step.

        The first call allocates the stored tensors by cloning. Later calls
        copy into them in place so their storage stays fixed across steps
        (graph-stable buffers).
        """
        if self._encoded is None:
            self._alpha = (
                alpha_new.clone() if isinstance(alpha_new, torch.Tensor) else alpha_new
            )
            self._encoded = {
                key: encoded_new[key].clone() for key in self._ENCODED_KEYS
            }
        else:
            if isinstance(self._alpha, torch.Tensor):
                self._alpha.copy_(alpha_new)
            else:
                self._alpha = alpha_new
            for key in self._ENCODED_KEYS:
                self._encoded[key].copy_(encoded_new[key])
        self._original_numel = numel
        self._has_state = True
|
turboadam/oneq.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""1Q — second moment (v) compression.
|
|
2
|
+
|
|
3
|
+
N-bit log-scale quantization for all parameters.
|
|
4
|
+
Compress-every-step architecture.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from turboadam.utils import pad_to_blocks, BLOCK_SIZE
|
|
10
|
+
from turboadam.quantize import quantize_logscale_nbits, dequantize_logscale_nbits
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def compress_v_logscale(
    v: torch.Tensor,
    n_bits: int = 3,
    block_size: int = BLOCK_SIZE,
    stochastic_round: bool = False,
) -> dict:
    """Compress a second-moment tensor with n-bit log-scale quantization.

    Parameters
    ----------
    v:
        Second-moment tensor of any shape. Values are converted to fp32 and
        padded to complete quantization blocks before encoding.
    n_bits:
        Number of bits per element, which determines the number of log-scale
        buckets in each block.
    block_size:
        Number of elements per independent quantization block.
    stochastic_round:
        Whether to use stochastic rounding when assigning bucket indices.

    Returns
    -------
    dict
        Compressed representation containing quantized indices, per-block
        scales, bit width, original shape, original length, and block size.
    """
    original_shape = v.shape
    v_flat = v.reshape(-1).float()
    # Padding uses the tensor's smallest value, floored at 1e-38 (presumably so
    # the pad stays positive for the log-scale quantizer). ``min()`` raises on an
    # empty tensor, so empty inputs fall back to the floor value.
    # NOTE: ``.item()`` forces a host sync when v is on an accelerator.
    v_min = v_flat.min().item() if v_flat.numel() > 0 else 0.0
    pad_value = max(v_min, 1e-38)
    v_padded, original_length = pad_to_blocks(v_flat, block_size, pad_value=pad_value)
    indices, scales, _ = quantize_logscale_nbits(
        v_padded,
        n_bits=n_bits,
        block_size=block_size,
        stochastic_round=stochastic_round,
    )
    return {
        "indices": indices,
        "scales": scales,
        "n_bits": n_bits,
        "original_shape": original_shape,
        "original_length": original_length,
        "block_size": block_size,
    }
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def decompress_v(compressed: dict) -> torch.Tensor:
    """Rebuild the second moment from the output of ``compress_v_logscale``.

    Args:
        compressed: Dict produced by compress_v_logscale. Its ``indices``,
            ``scales``, ``n_bits``, ``block_size``, ``original_length`` and
            ``original_shape`` entries are read.

    Returns:
        The dequantized tensor, reshaped to ``compressed["original_shape"]``.
    """
    flat = dequantize_logscale_nbits(
        compressed["indices"],
        compressed["scales"],
        n_bits=compressed["n_bits"],
        block_size=compressed["block_size"],
        original_numel=compressed["original_length"],
    )
    target_shape = compressed["original_shape"]
    return flat.reshape(target_shape)
|