turboadam 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
turboadam/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from importlib.metadata import version, PackageNotFoundError
2
+
3
# Look up the version of the installed "turboadam" distribution. When no
# distribution metadata is found, fall back to the placeholder "unknown"
# so that importing the package still succeeds.
try:
    __version__ = version("turboadam")
except PackageNotFoundError:
    __version__ = "unknown"
7
+
8
+ from turboadam.optimizer import TurboAdam
9
+
10
+ __all__ = ["TurboAdam"]
turboadam/costate.py ADDED
@@ -0,0 +1,464 @@
1
+ """CoState — first moment (m) compression.
2
+
3
+ Gradient-residual decomposition: m = α·g + δ
4
+ - α = (m·g) / (g·g) — scalar per parameter tensor
5
+ - δ = m - α·g — residual orthogonal to current gradient
6
+
7
+ Residual δ is partitioned into 128-element blocks and classified:
8
+ - Null costate (r < τ₀): store 1 bit in bitmap
9
+ - Phase costate (τ₀ ≤ r < τ₁): store 1-bit sign per element
10
+ - Amplitude costate (r ≥ τ₁): store 1-bit sign + fp16 block scale
11
+
12
+ Adaptive thresholds: τ₀ = P_10(r), τ₁ = P_90(r) per parameter tensor per step.
13
+ No warmup required — EMA error-washing handles cold-start.
14
+ """
15
+
16
+ import math
17
+ import torch
18
+
19
+ from turboadam.utils import pad_to_blocks, BLOCK_SIZE
20
+
21
+
22
def decompose(m: torch.Tensor, g: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Split the momentum into its projection onto the gradient plus a residual.

    The decomposition is ``m = alpha * g + delta`` with
    ``alpha = (m . g) / (g . g)``. When ``g . g`` is not positive, ``alpha``
    is set to zero.

    Args:
        m: First moment tensor; flattened only for the dot products.
        g: Gradient tensor with the same shape as ``m``.

    Returns:
        ``(alpha, delta)`` where ``alpha`` is a 0-dim tensor and ``delta`` has
        the same shape as ``m``.
    """
    moment_vec = m.reshape(-1)
    grad_vec = g.reshape(-1)
    grad_sq = torch.dot(grad_vec, grad_vec)
    projection = torch.dot(moment_vec, grad_vec) / grad_sq
    # torch.where selects the guarded value on-device, so alpha stays a tensor
    # and no .item() call is needed to pick the branch.
    alpha = torch.where(grad_sq > 0, projection, torch.zeros_like(grad_sq))
    return alpha, m - alpha * g
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Block ratio computation
48
+ # ---------------------------------------------------------------------------
49
+
50
+
51
def compute_block_ratios(
    delta: torch.Tensor,
    m: torch.Tensor,
    block_size: int = BLOCK_SIZE,
) -> torch.Tensor:
    """Return, for every block, ``norm(delta_block) / norm(m_block)``.

    Both inputs are flattened, converted to float32 and padded to a whole
    number of blocks with ``pad_to_blocks`` before the per-block L2 norms are
    taken.

    Args:
        delta: Residual tensor (any shape).
        m: First moment tensor (same shape as ``delta``).
        block_size: Elements per block (default: ``BLOCK_SIZE``).

    Returns:
        1-D float32 tensor with one ratio per block. A block whose ``m`` norm
        is exactly zero gets ratio 0.
    """
    residual_padded, _ = pad_to_blocks(delta.reshape(-1).float(), block_size)
    moment_padded, _ = pad_to_blocks(m.reshape(-1).float(), block_size)

    n_blocks = residual_padded.shape[0] // block_size
    residual_norms = residual_padded.reshape(n_blocks, block_size).norm(dim=1)
    moment_norms = moment_padded.reshape(n_blocks, block_size).norm(dim=1)

    # Divide by 1 instead of 0 for empty-momentum blocks, then force the
    # ratio of exactly those blocks to 0.
    zero_moment = moment_norms == 0.0
    denominators = torch.where(
        zero_moment, torch.ones_like(moment_norms), moment_norms
    )
    return torch.where(
        zero_moment,
        torch.zeros_like(residual_norms),
        residual_norms / denominators,
    )
87
+
88
+
89
+ # ---------------------------------------------------------------------------
90
+ # Threshold computation
91
+ # ---------------------------------------------------------------------------
92
+
93
+
94
def compute_thresholds(
    ratios: torch.Tensor,
    null_pct: float = 0.10,
    amp_pct: float = 0.90,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute adaptive thresholds as order statistics of the block ratios.

    The ratios are sorted ascending and two entries are picked by index:
    ``tau0 = sorted[int(null_pct * n) - 1]`` and
    ``tau1 = sorted[int(amp_pct * n)]``, with both indices clamped to
    ``[0, n - 1]``. With the defaults this targets roughly 10% null,
    80% phase and 10% amplitude blocks; for small ``n`` the shares are
    necessarily coarser.

    Sorting and indexing is used instead of ``torch.quantile``, and the
    thresholds are returned as scalar tensors so no ``.item()`` sync occurs.

    Args:
        ratios: 1-D float tensor of per-block ratios.
        null_pct: Fraction for the null/phase boundary. Default: 0.10.
        amp_pct: Fraction for the phase/amplitude boundary. Default: 0.90.

    Returns:
        ``(tau0, tau1)`` as 0-dim tensors on the same device as ``ratios``.
        For an empty ``ratios`` both thresholds are zero.
    """
    n = ratios.shape[0]
    if n == 0:
        # Nothing to rank: indexing an empty sorted tensor would raise, so
        # hand back neutral thresholds with the input's dtype and device.
        return ratios.new_zeros(()), ratios.new_zeros(())

    sorted_r = ratios.sort().values
    last = n - 1
    # Clamp on both sides so out-of-range fractions cannot index past the
    # end or wrap around through a negative index.
    idx_lo = min(last, max(0, int(null_pct * n) - 1))
    idx_hi = max(0, min(last, int(amp_pct * n)))
    return sorted_r[idx_lo], sorted_r[idx_hi]
119
+
120
+
121
+ # ---------------------------------------------------------------------------
122
+ # Block classification
123
+ # ---------------------------------------------------------------------------
124
+
125
+
126
def classify_blocks(
    ratios: torch.Tensor,
    tau0: float,
    tau1: float,
) -> torch.Tensor:
    """Label every block with its costate according to its ratio.

    Labels:
        0 (null)      : r < tau0
        1 (phase)     : tau0 <= r < tau1
        2 (amplitude) : r >= tau1

    The amplitude test is applied last, so it wins whenever both
    comparisons hold.

    Args:
        ratios: 1-D float tensor of per-block ratios.
        tau0: Lower threshold (null/phase boundary).
        tau1: Upper threshold (phase/amplitude boundary).

    Returns:
        uint8 tensor of labels, same length as ``ratios``.
    """
    null_label = torch.zeros(
        ratios.shape[0], dtype=torch.uint8, device=ratios.device
    )
    phase_label = torch.ones_like(null_label)
    amplitude_label = torch.full_like(null_label, 2)

    # First decide null vs. phase, then let the amplitude test override.
    at_least_phase = torch.where(ratios >= tau0, phase_label, null_label)
    return torch.where(ratios >= tau1, amplitude_label, at_least_phase)
150
+
151
+
152
+ # ---------------------------------------------------------------------------
153
+ # Encoding and decoding
154
+ # ---------------------------------------------------------------------------
155
+
156
+
157
def _pack_signs(values: torch.Tensor) -> torch.Tensor:
    """Pack sign bits (1 if negative, 0 otherwise) into uint8, 8 bits per byte.

    The first element of each group of eight lands in the most significant
    bit. Input whose length is not a multiple of 8 is zero-padded, so the
    padding contributes 0 bits.

    Args:
        values: 1-D float tensor of arbitrary length (may be empty).

    Returns:
        uint8 tensor of length ``ceil(len(values) / 8)``.
    """
    n = values.shape[0]
    n_bytes = (n + 7) // 8
    pad = n_bytes * 8 - n
    if pad > 0:
        values = torch.cat([values, values.new_zeros(pad)])

    # Use an explicit row count: reshape(-1, 8) cannot infer the first
    # dimension for a zero-element tensor and raises.
    sign_bits = (values < 0).to(torch.uint8).reshape(n_bytes, 8)

    # OR the eight columns together with bitwise shifts; no multiplier
    # tensor has to be allocated per call.
    packed = sign_bits[:, 0] << 7
    for column in range(1, 8):
        packed = packed | (sign_bits[:, column] << (7 - column))
    return packed
183
+
184
+
185
def _unpack_signs(packed: torch.Tensor, n: int) -> torch.Tensor:
    """Expand packed sign bytes back into +1/-1 float values.

    Bits are read most-significant first; a set bit decodes to -1 and a
    cleared bit to +1.

    Args:
        packed: Tensor of packed sign bytes. It is cast to int32 before the
            bit extraction, so a dtype changed by state-dict loading is
            tolerated.
        n: Number of elements to unpack (may be less than len(packed) * 8).

    Returns:
        float32 tensor of length ``n`` with values +1 or -1, on the device
        ``packed`` came from.
    """
    source_device = packed.device
    # Per the original author's note, MPS lacks integer bitwise ops, so the
    # extraction runs on CPU for MPS tensors and in place everywhere else.
    if source_device.type == "mps":
        compute_device = torch.device("cpu")
    else:
        compute_device = source_device

    byte_values = packed.to(dtype=torch.int32, device=compute_device)
    bit_offsets = torch.arange(
        7, -1, -1, dtype=torch.int32, device=compute_device
    )
    # (num_bytes, 8): every byte is shifted by all eight offsets at once.
    bit_matrix = (byte_values[:, None] >> bit_offsets[None, :]) & 1
    flat_bits = bit_matrix.float().reshape(-1)[:n]
    return (1.0 - 2.0 * flat_bits).to(source_device)
211
+
212
+
213
def encode_blocks(
    delta: torch.Tensor,
    labels: torch.Tensor,
    block_size: int = BLOCK_SIZE,
) -> dict:
    """Build the per-block compressed representation of ``delta``.

    Args:
        delta: Residual tensor (any shape, flattened internally).
        labels: uint8 costate labels, one per block; stored as given.
        block_size: Elements per block.

    Returns:
        dict with:
            labels      (uint8)  : costate label per block
            sign_packed (uint8)  : packed sign bits of the un-padded elements
            block_norms (float32): L2 norm of each delta block
            scales      (float16): block_norm / sqrt(block_size), i.e. the
                                   uniform per-element magnitude used when a
                                   block is rebuilt from signs alone
    """
    flat = delta.reshape(-1).float()

    padded, _ = pad_to_blocks(flat, block_size)
    per_block = padded.reshape(padded.shape[0] // block_size, block_size)
    norms = per_block.norm(dim=1).float()  # (num_blocks,)

    # The amplitude scale is the same quantity phase decoding derives from
    # block_norms, but kept explicitly in fp16.
    uniform_magnitude = (norms / math.sqrt(block_size)).to(torch.float16)

    # Signs come from the original elements only, not from the padding.
    packed_signs = _pack_signs(flat)

    encoded = {
        "labels": labels,
        "sign_packed": packed_signs,
        "block_norms": norms,
        "scales": uniform_magnitude,
    }
    return encoded
259
+
260
+
261
def decode_blocks(
    encoded: dict,
    alpha,
    g: torch.Tensor,
    block_size: int = BLOCK_SIZE,
    original_numel: int | None = None,
) -> torch.Tensor:
    """Reconstruct the approximation ``m = alpha * g + delta_hat``.

    Per-costate delta_hat reconstruction:
        Null (0)      : delta_hat_block = 0
        Phase (1)     : delta_hat_block = (block_norm / sqrt(block_size)) * sign
        Amplitude (2) : delta_hat_block = stored_fp16_scale * sign

    Args:
        encoded: dict returned by encode_blocks.
        alpha: Projection coefficient from decompose() (float or scalar
            tensor).
        g: Gradient tensor (same original shape as delta).
        block_size: Elements per block.
        original_numel: Number of elements in the original delta (before
            padding). Defaults to ``g.numel()``.

    Returns:
        Reconstructed float32 m tensor with the same shape as ``g``.
    """
    g_flat = g.reshape(-1).float()
    if original_numel is None:
        original_numel = g_flat.shape[0]

    # Everything is brought onto the gradient's device before combining.
    device = g_flat.device
    labels = encoded["labels"].to(dtype=torch.uint8, device=device)
    block_norms = encoded["block_norms"].to(dtype=torch.float32, device=device)
    scales = encoded["scales"].to(dtype=torch.float32, device=device)
    num_blocks = labels.shape[0]

    # _unpack_signs returns its result on the device of the packed bytes;
    # move it like the other encoded tensors so a state stored on another
    # device does not cause a device mismatch below.
    signs_flat = _unpack_signs(encoded["sign_packed"], original_numel).to(device)
    signs_padded, _ = pad_to_blocks(signs_flat, block_size)
    signs_blocks = signs_padded.reshape(num_blocks, block_size)

    # Per-block scale: null -> 0, phase -> norm / sqrt(block_size),
    # amplitude -> stored scale. torch.where selects element-wise without
    # boolean-mask gathers, whose output size depends on the mask contents.
    phase_scales = block_norms / math.sqrt(block_size)  # (num_blocks,)
    block_scales = torch.where(
        labels == 2,
        scales,
        torch.where(labels == 1, phase_scales, torch.zeros_like(phase_scales)),
    )

    # (num_blocks, 1) * (num_blocks, block_size), then drop the padding.
    delta_hat = (block_scales.unsqueeze(1) * signs_blocks).reshape(-1)
    delta_hat = delta_hat[:original_numel]
    return (alpha * g_flat + delta_hat).reshape(g.shape)
325
+
326
+
327
+ # ---------------------------------------------------------------------------
328
+ # CoStateManager — stateful per-step update loop (spec section 4.5)
329
+ # ---------------------------------------------------------------------------
330
+
331
+
332
class CoStateManager:
    """Stateful manager for CoState first moment compression.

    Each call to :meth:`update` performs one step:
        1. Reconstruct the previous moment from the stored compressed state
           (zeros on the first call).
        2. Apply the EMA update: m_new = beta1 * m_prev + (1 - beta1) * g.
        3. Decompose m_new into alpha_new * g + delta_new.
        4. Classify the blocks of delta_new with adaptive thresholds.
        5. Encode delta_new and store it together with alpha_new.

    Usage:
        mgr = CoStateManager(block_size=128)
        m = mgr.update(g, beta1=0.9)  # call each optimizer step
    """

    def __init__(
        self,
        block_size: int = BLOCK_SIZE,
        error_feedback: bool = False,
        null_pct: float = 0.10,
        amp_pct: float = 0.90,
    ) -> None:
        """Create an empty manager.

        Args:
            block_size: Elements per compression block.
            error_feedback: When True, the encoding error of each step is
                accumulated and added to the gradient of the next step.
            null_pct: Fraction passed to compute_thresholds for the
                null/phase boundary.
            amp_pct: Fraction passed to compute_thresholds for the
                phase/amplitude boundary.
        """
        self.block_size = block_size
        self._null_pct = null_pct
        self._amp_pct = amp_pct
        self._has_state: bool = False
        self._alpha = 0.0  # becomes a scalar tensor after first update
        self._encoded: dict | None = None
        self._original_numel: int = 0
        self._error_feedback = error_feedback
        self._ef_residual: torch.Tensor | None = None

    def update(self, g: torch.Tensor, beta1: float) -> torch.Tensor:
        """Run one step of the CoState update procedure.

        Args:
            g: Current gradient tensor (any shape).
            beta1: EMA decay for first moment (e.g. 0.9).

        Returns:
            m_new: Updated first moment tensor, same shape as g (float32).
        """
        # Cast gradient to fp32: the accumulators are fp32, and fp16/bf16
        # gradients would cause dtype mismatches in dot products.
        g = g.float()

        # Use Triton kernels if available and on CUDA. The import list is
        # kept as it was, including the name that is not used here, so a
        # kernels module missing any of them still selects the PyTorch path.
        try:
            from turboadam.triton_kernels import (
                triton_costate_decode,
                triton_costate_encode,
                triton_decompose_ratios,  # noqa: F401
            )

            _use_triton = g.is_cuda
        except ImportError:
            _use_triton = False

        _decode = triton_costate_decode if _use_triton else decode_blocks
        _encode = triton_costate_encode if _use_triton else encode_blocks

        # Reconstruct the previous moment from the compressed prior state
        # (or zeros on the first call).
        if self._has_state:
            m_hat = _decode(
                self._encoded,
                self._alpha,
                g,
                self.block_size,
                self._original_numel,
            )
        else:
            m_hat = torch.zeros_like(g, dtype=torch.float32)

        # Error feedback: compensate for previous step's encoding loss.
        if self._error_feedback and self._ef_residual is not None:
            g_corrected = g + self._ef_residual
        else:
            g_corrected = g

        # EMA update.
        m_new = beta1 * m_hat + (1.0 - beta1) * g_corrected

        # Decompose against the uncorrected gradient, then rank and classify
        # the residual blocks.
        alpha_new, delta_new = decompose(m_new, g)
        ratios = compute_block_ratios(delta_new, m_new, self.block_size)
        tau0, tau1 = compute_thresholds(ratios, self._null_pct, self._amp_pct)
        labels = classify_blocks(ratios, tau0, tau1)

        # Compress the new residual.
        encoded_new = _encode(delta_new, labels, self.block_size)

        # Error feedback: measure what the encoding lost, accumulate for the
        # next step. Decoding with alpha = 0 yields delta_hat alone.
        if self._error_feedback:
            zero_alpha = g.new_zeros(1)
            delta_hat = _decode(
                encoded_new, zero_alpha, g, self.block_size, m_new.numel()
            )
            ef_error = (delta_new - delta_hat).detach()
            if self._ef_residual is None:
                self._ef_residual = ef_error
            else:
                self._ef_residual = (
                    beta1 * self._ef_residual + (1.0 - beta1) * ef_error
                )

        # Graph-stable buffer management: on first call allocate by cloning,
        # on subsequent calls copy data in-place to keep tensor addresses
        # stable.
        if self._encoded is None:
            self._alpha = (
                alpha_new.clone()
                if isinstance(alpha_new, torch.Tensor)
                else alpha_new
            )
            self._encoded = {
                "labels": encoded_new["labels"].clone(),
                "sign_packed": encoded_new["sign_packed"].clone(),
                "block_norms": encoded_new["block_norms"].clone(),
                "scales": encoded_new["scales"].clone(),
            }
        else:
            if isinstance(self._alpha, torch.Tensor):
                self._alpha.copy_(alpha_new)
            else:
                self._alpha = alpha_new
            self._encoded["labels"].copy_(encoded_new["labels"])
            self._encoded["sign_packed"].copy_(encoded_new["sign_packed"])
            self._encoded["block_norms"].copy_(encoded_new["block_norms"])
            self._encoded["scales"].copy_(encoded_new["scales"])
        self._original_numel = m_new.numel()
        self._has_state = True

        return m_new
turboadam/oneq.py ADDED
@@ -0,0 +1,77 @@
1
+ """1Q — second moment (v) compression.
2
+
3
+ N-bit log-scale quantization for all parameters.
4
+ Compress-every-step architecture.
5
+ """
6
+
7
+ import torch
8
+
9
+ from turboadam.utils import pad_to_blocks, BLOCK_SIZE
10
+ from turboadam.quantize import quantize_logscale_nbits, dequantize_logscale_nbits
11
+
12
+
13
def compress_v_logscale(
    v: torch.Tensor,
    n_bits: int = 3,
    block_size: int = BLOCK_SIZE,
    stochastic_round: bool = False,
) -> dict:
    """Compress a second-moment tensor with n-bit log-scale quantization.

    Args:
        v: Second-moment tensor of any shape. It is flattened, converted to
            fp32 and padded to complete quantization blocks before encoding.
        n_bits: Number of bits per element, forwarded to the quantizer.
        block_size: Number of elements per quantization block.
        stochastic_round: Forwarded to the quantizer to enable stochastic
            rounding.

    Returns:
        dict holding the quantized ``indices``, the per-block ``scales``,
        ``n_bits``, ``original_shape``, ``original_length`` and
        ``block_size``.
    """
    shape = v.shape
    flat = v.reshape(-1).float()

    # Pad with the smallest value present (floored at 1e-38) rather than 0.
    # Note that .item() copies the minimum to the host.
    smallest = flat.min().item()
    fill = max(smallest, 1e-38)
    padded, n_real = pad_to_blocks(flat, block_size, pad_value=fill)

    indices, scales, _ = quantize_logscale_nbits(
        padded,
        n_bits=n_bits,
        block_size=block_size,
        stochastic_round=stochastic_round,
    )
    compressed = {
        "indices": indices,
        "scales": scales,
        "n_bits": n_bits,
        "original_shape": shape,
        "original_length": n_real,
        "block_size": block_size,
    }
    return compressed
59
+
60
+
61
def decompress_v(compressed: dict) -> torch.Tensor:
    """Rebuild v from the dict produced by compress_v_logscale.

    Args:
        compressed: Dict with the keys ``indices``, ``scales``, ``n_bits``,
            ``block_size``, ``original_length`` and ``original_shape``.

    Returns:
        Dequantized tensor reshaped to the original shape of v.
    """
    indices = compressed["indices"]
    scales = compressed["scales"]
    flat = dequantize_logscale_nbits(
        indices,
        scales,
        n_bits=compressed["n_bits"],
        block_size=compressed["block_size"],
        original_numel=compressed["original_length"],
    )
    target_shape = compressed["original_shape"]
    return flat.reshape(target_shape)