VeloxQuant-MLX 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. mlx_kv_quant/__init__.py +41 -0
  2. mlx_kv_quant/__main__.py +28 -0
  3. mlx_kv_quant/artifacts/__init__.py +6 -0
  4. mlx_kv_quant/artifacts/base.py +5 -0
  5. mlx_kv_quant/artifacts/memory_store.py +93 -0
  6. mlx_kv_quant/artifacts/npy_store.py +114 -0
  7. mlx_kv_quant/benchmarks/__init__.py +0 -0
  8. mlx_kv_quant/benchmarks/attend_benchmark.py +184 -0
  9. mlx_kv_quant/cache/__init__.py +17 -0
  10. mlx_kv_quant/cache/base.py +273 -0
  11. mlx_kv_quant/cache/polar_cache.py +134 -0
  12. mlx_kv_quant/cache/qjl_cache.py +118 -0
  13. mlx_kv_quant/cache/sliding_window_cache.py +113 -0
  14. mlx_kv_quant/cache/turboquant_cache.py +312 -0
  15. mlx_kv_quant/cli/__init__.py +1 -0
  16. mlx_kv_quant/cli/benchmark.py +89 -0
  17. mlx_kv_quant/cli/precompute.py +31 -0
  18. mlx_kv_quant/codebooks/__init__.py +19 -0
  19. mlx_kv_quant/codebooks/base.py +79 -0
  20. mlx_kv_quant/codebooks/precompute.py +120 -0
  21. mlx_kv_quant/codebooks/scalar_codebook.py +107 -0
  22. mlx_kv_quant/codebooks/strategies.py +145 -0
  23. mlx_kv_quant/core/__init__.py +71 -0
  24. mlx_kv_quant/core/abstractions.py +369 -0
  25. mlx_kv_quant/core/constants.py +42 -0
  26. mlx_kv_quant/core/context.py +164 -0
  27. mlx_kv_quant/core/exceptions.py +17 -0
  28. mlx_kv_quant/core/registry.py +89 -0
  29. mlx_kv_quant/dsa/__init__.py +18 -0
  30. mlx_kv_quant/dsa/avl_tree.py +290 -0
  31. mlx_kv_quant/dsa/bit_pack.py +210 -0
  32. mlx_kv_quant/dsa/dag.py +169 -0
  33. mlx_kv_quant/dsa/heap.py +194 -0
  34. mlx_kv_quant/dsa/ring_buffer.py +136 -0
  35. mlx_kv_quant/handlers/__init__.py +21 -0
  36. mlx_kv_quant/handlers/base.py +6 -0
  37. mlx_kv_quant/handlers/bit_pack_handler.py +62 -0
  38. mlx_kv_quant/handlers/normalization.py +52 -0
  39. mlx_kv_quant/handlers/outlier_split.py +71 -0
  40. mlx_kv_quant/handlers/polar_handler.py +62 -0
  41. mlx_kv_quant/handlers/qjl_residual_handler.py +63 -0
  42. mlx_kv_quant/handlers/rotation_handler.py +43 -0
  43. mlx_kv_quant/handlers/scalar_quant_handler.py +47 -0
  44. mlx_kv_quant/handlers/value_quant_handler.py +54 -0
  45. mlx_kv_quant/integration/__init__.py +5 -0
  46. mlx_kv_quant/integration/mlx_lm_patch.py +83 -0
  47. mlx_kv_quant/math/__init__.py +14 -0
  48. mlx_kv_quant/math/distributions.py +111 -0
  49. mlx_kv_quant/math/lloyd_max.py +122 -0
  50. mlx_kv_quant/math/rotation.py +103 -0
  51. mlx_kv_quant/observers/__init__.py +14 -0
  52. mlx_kv_quant/observers/base.py +32 -0
  53. mlx_kv_quant/observers/distortion.py +179 -0
  54. mlx_kv_quant/observers/latency.py +55 -0
  55. mlx_kv_quant/observers/memory.py +46 -0
  56. mlx_kv_quant/outlier/__init__.py +5 -0
  57. mlx_kv_quant/outlier/detector.py +89 -0
  58. mlx_kv_quant/preconditioners/__init__.py +12 -0
  59. mlx_kv_quant/preconditioners/base.py +67 -0
  60. mlx_kv_quant/preconditioners/jl_sketch.py +131 -0
  61. mlx_kv_quant/preconditioners/rotation.py +117 -0
  62. mlx_kv_quant/quantizers/__init__.py +17 -0
  63. mlx_kv_quant/quantizers/base.py +66 -0
  64. mlx_kv_quant/quantizers/composite.py +111 -0
  65. mlx_kv_quant/quantizers/polarquant.py +163 -0
  66. mlx_kv_quant/quantizers/qjl.py +111 -0
  67. mlx_kv_quant/quantizers/turboquant_mse.py +139 -0
  68. mlx_kv_quant/quantizers/turboquant_prod.py +214 -0
  69. mlx_kv_quant/tests/__init__.py +1 -0
  70. mlx_kv_quant/tests/cache/__init__.py +1 -0
  71. mlx_kv_quant/tests/cache/test_sliding_window.py +63 -0
  72. mlx_kv_quant/tests/cache/test_turboquant_cache.py +268 -0
  73. mlx_kv_quant/tests/conftest.py +61 -0
  74. mlx_kv_quant/tests/dsa/__init__.py +1 -0
  75. mlx_kv_quant/tests/dsa/test_avl_tree.py +116 -0
  76. mlx_kv_quant/tests/dsa/test_bit_pack.py +74 -0
  77. mlx_kv_quant/tests/dsa/test_dag.py +106 -0
  78. mlx_kv_quant/tests/dsa/test_heap.py +90 -0
  79. mlx_kv_quant/tests/dsa/test_ring_buffer.py +95 -0
  80. mlx_kv_quant/tests/handlers/__init__.py +1 -0
  81. mlx_kv_quant/tests/handlers/test_pipeline.py +119 -0
  82. mlx_kv_quant/tests/integration/__init__.py +1 -0
  83. mlx_kv_quant/tests/integration/test_distortion_bounds.py +111 -0
  84. mlx_kv_quant/tests/math/__init__.py +1 -0
  85. mlx_kv_quant/tests/math/test_distributions.py +81 -0
  86. mlx_kv_quant/tests/math/test_lloyd_max.py +97 -0
  87. mlx_kv_quant/tests/quantizers/__init__.py +1 -0
  88. mlx_kv_quant/tests/quantizers/test_polar.py +76 -0
  89. mlx_kv_quant/tests/quantizers/test_qjl.py +69 -0
  90. mlx_kv_quant/tests/quantizers/test_turboquant_mse.py +63 -0
  91. mlx_kv_quant/tests/quantizers/test_turboquant_prod.py +88 -0
  92. mlx_kv_quant/transforms/__init__.py +5 -0
  93. mlx_kv_quant/transforms/base.py +5 -0
  94. mlx_kv_quant/transforms/polar.py +144 -0
  95. mlx_kv_quant/weight/__init__.py +4 -0
  96. mlx_kv_quant/weight/model_quantizer.py +170 -0
  97. mlx_kv_quant/weight/quantized_linear.py +152 -0
  98. veloxquant_mlx-0.2.0.dist-info/METADATA +448 -0
  99. veloxquant_mlx-0.2.0.dist-info/RECORD +103 -0
  100. veloxquant_mlx-0.2.0.dist-info/WHEEL +5 -0
  101. veloxquant_mlx-0.2.0.dist-info/entry_points.txt +3 -0
  102. veloxquant_mlx-0.2.0.dist-info/licenses/LICENSE +21 -0
  103. veloxquant_mlx-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,41 @@
1
+ """mlx_kv_quant — KV cache quantization for Apple Silicon MLX.
2
+
3
+ Implements TurboQuant, PolarQuant, and QJL for production LLM inference.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ from mlx_kv_quant.cache.base import KVCacheBuilder, KVCacheConfig, KVCacheFactory
8
+ from mlx_kv_quant.core.abstractions import (
9
+ ArtifactStore,
10
+ KVCache,
11
+ Quantizer,
12
+ QuantizationObserver,
13
+ )
14
+ from mlx_kv_quant.core.context import EncodedVector, QuantizationContext, TransformResult
15
+ from mlx_kv_quant.core.exceptions import (
16
+ ArtifactNotFoundError,
17
+ CodebookDimensionMismatch,
18
+ CyclicPipelineError,
19
+ QuantizerConfigError,
20
+ )
21
+ from mlx_kv_quant.quantizers.base import QuantizerFactory
22
+
23
+ __all__ = [
24
+ "KVCacheBuilder",
25
+ "KVCacheConfig",
26
+ "KVCacheFactory",
27
+ "ArtifactStore",
28
+ "KVCache",
29
+ "Quantizer",
30
+ "QuantizationObserver",
31
+ "EncodedVector",
32
+ "QuantizationContext",
33
+ "TransformResult",
34
+ "ArtifactNotFoundError",
35
+ "CodebookDimensionMismatch",
36
+ "CyclicPipelineError",
37
+ "QuantizerConfigError",
38
+ "QuantizerFactory",
39
+ ]
40
+
41
+ __version__ = "0.2.0"
@@ -0,0 +1,28 @@
1
+ """Entry point for `python -m mlx_kv_quant <command>`."""
2
+ from __future__ import annotations
3
+
4
+ import sys
5
+
6
+
7
def main() -> None:
    """Dispatch ``python -m mlx_kv_quant <command>`` to the matching CLI module.

    The first command-line argument selects the sub-command (``precompute``
    or ``benchmark``).  Before the sub-command's own ``main`` runs,
    ``sys.argv`` is rewritten so the sub-command's parser sees only its own
    arguments.

    Raises:
        SystemExit: With status 1 when no sub-command is given or when the
            sub-command is not recognised.  The diagnostic is written to
            stderr so it never mixes with regular stdout output.
    """
    if len(sys.argv) < 2:
        print("Usage: veloxquant {precompute|benchmark}", file=sys.stderr)
        sys.exit(1)

    command = sys.argv[1]
    # Validate before touching sys.argv so a rejected command leaves the
    # process arguments untouched.
    if command not in ("precompute", "benchmark"):
        print(
            f"Unknown command: {command!r}. Choices: precompute, benchmark",
            file=sys.stderr,
        )
        sys.exit(1)

    # Remove the subcommand so sub-parsers see argv correctly
    sys.argv = [f"mlx_kv_quant {command}"] + sys.argv[2:]

    # Imports stay lazy so only the selected sub-command's module is loaded.
    if command == "precompute":
        from mlx_kv_quant.cli.precompute import main as _main
    else:
        from mlx_kv_quant.cli.benchmark import main as _main
    _main()
25
+
26
+
27
+ if __name__ == "__main__":
28
+ main()
@@ -0,0 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+ from mlx_kv_quant.artifacts.memory_store import InMemoryArtifactStore
4
+ from mlx_kv_quant.artifacts.npy_store import NpyArtifactStore
5
+
6
+ __all__ = ["InMemoryArtifactStore", "NpyArtifactStore"]
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+ from mlx_kv_quant.core.abstractions import ArtifactStore
4
+
5
+ __all__ = ["ArtifactStore"]
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from mlx_kv_quant.core.abstractions import ArtifactStore
8
+ from mlx_kv_quant.core.exceptions import ArtifactNotFoundError
9
+
10
+
11
class InMemoryArtifactStore(ArtifactStore):
    """Dictionary-backed artifact store that never touches the filesystem.

    Intended for tests.  Each artifact family (rotation matrices, codebooks,
    JL matrices) lives in its own dict keyed by a descriptor tuple.  Values
    are kept as float16 numpy arrays and converted to MLX arrays on load.

    Args:
        None.
    """

    def __init__(self) -> None:
        self._rotations: Dict[Tuple, np.ndarray] = {}
        self._codebooks: Dict[Tuple, np.ndarray] = {}
        self._jls: Dict[Tuple, np.ndarray] = {}

    # ------------------------------------------------------------------
    # Rotation matrix
    # ------------------------------------------------------------------

    def load_rotation_matrix(self, d: int, seed: int) -> Any:
        """Return the rotation matrix stored under ``(d, seed)`` as an MLX array.

        Raises:
            ArtifactNotFoundError: If nothing was saved under that key.
        """
        stored = self._rotations.get((d, seed))
        if stored is None:
            raise ArtifactNotFoundError(
                f"InMemoryArtifactStore: rotation d={d} seed={seed} not found."
            )
        # Imported only after the lookup succeeds.
        import mlx.core as mx

        return mx.array(stored)

    def save_rotation_matrix(self, Pi: Any, d: int, seed: int) -> None:
        """Store ``Pi`` under ``(d, seed)`` as a float16 numpy array."""
        as_half = np.array(Pi, dtype=np.float16)
        self._rotations[(d, seed)] = as_half

    # ------------------------------------------------------------------
    # Codebook
    # ------------------------------------------------------------------

    def load_codebook(self, distribution: str, b: int, d: int) -> Any:
        """Return the codebook stored under ``(distribution, b, d)`` as an MLX array.

        Raises:
            ArtifactNotFoundError: If nothing was saved under that key.
        """
        stored = self._codebooks.get((distribution, b, d))
        if stored is None:
            raise ArtifactNotFoundError(
                f"InMemoryArtifactStore: codebook dist={distribution} b={b} d={d} not found."
            )
        import mlx.core as mx

        return mx.array(stored)

    def save_codebook(self, cb: Any, distribution: str, b: int, d: int) -> None:
        """Store ``cb`` under ``(distribution, b, d)`` as a float16 numpy array."""
        as_half = np.array(cb, dtype=np.float16)
        self._codebooks[(distribution, b, d)] = as_half

    # ------------------------------------------------------------------
    # JL matrix
    # ------------------------------------------------------------------

    def load_jl_matrix(self, d: int, m: int, seed: int) -> Any:
        """Return the JL matrix stored under ``(d, m, seed)`` as an MLX array.

        Raises:
            ArtifactNotFoundError: If nothing was saved under that key.
        """
        stored = self._jls.get((d, m, seed))
        if stored is None:
            raise ArtifactNotFoundError(
                f"InMemoryArtifactStore: JL d={d} m={m} seed={seed} not found."
            )
        import mlx.core as mx

        return mx.array(stored)

    def save_jl_matrix(self, S: Any, d: int, m: int, seed: int) -> None:
        """Store ``S`` under ``(d, m, seed)`` as a float16 numpy array."""
        as_half = np.array(S, dtype=np.float16)
        self._jls[(d, m, seed)] = as_half

    # ------------------------------------------------------------------
    # Existence check
    # ------------------------------------------------------------------

    def exists(self, artifact_type: str, **kwargs: Any) -> bool:
        """Report whether an artifact of ``artifact_type`` is stored.

        ``kwargs`` must carry the descriptor fields for the requested type:
        ``d``/``seed`` for ``"rotation"``, ``distribution``/``b``/``d`` for
        ``"codebook"`` and ``d``/``m``/``seed`` for ``"jl"``.  Any other
        ``artifact_type`` yields ``False``.
        """
        if artifact_type == "rotation":
            key = (kwargs["d"], kwargs["seed"])
            return key in self._rotations
        if artifact_type == "codebook":
            key = (kwargs["distribution"], kwargs["b"], kwargs["d"])
            return key in self._codebooks
        if artifact_type == "jl":
            key = (kwargs["d"], kwargs["m"], kwargs["seed"])
            return key in self._jls
        return False

    def __repr__(self) -> str:
        counts = (
            f"rotations={len(self._rotations)}, "
            f"codebooks={len(self._codebooks)}, "
            f"jls={len(self._jls)}"
        )
        return f"InMemoryArtifactStore({counts})"
@@ -0,0 +1,114 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+ from mlx_kv_quant.core.abstractions import ArtifactStore
10
+ from mlx_kv_quant.core.exceptions import ArtifactNotFoundError
11
+
12
+
13
class NpyArtifactStore(ArtifactStore):
    """Artifact store that reads and writes ``.npy`` files from a local directory.

    File naming conventions:
        rotation_d{d}_seed{seed}.npy
        codebook_{distribution}_b{b}_d{d}.npy
        jl_d{d}_m{m}_seed{seed}.npy

    Arrays are written as float16 and returned as MLX arrays built from the
    float16 data.

    Args:
        root_dir: Path to the directory where artifacts are stored.
            Created (including parents) when the store is constructed, and
            re-created on save if it has been removed in the meantime.
    """

    def __init__(self, root_dir: str | Path) -> None:
        self._root = Path(root_dir)
        self._root.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _load_array(self, path: Path, missing_message: str) -> Any:
        """Load ``path`` as a float16 MLX array or raise ``ArtifactNotFoundError``."""
        if not path.exists():
            raise ArtifactNotFoundError(missing_message)
        # Imported only once the file is known to exist.
        import mlx.core as mx

        return mx.array(np.load(path).astype(np.float16))

    def _save_array(self, path: Path, array: Any) -> None:
        """Write ``array`` to ``path`` as float16, ensuring the root directory exists."""
        self._root.mkdir(parents=True, exist_ok=True)
        np.save(path, np.array(array, dtype=np.float16))

    # ------------------------------------------------------------------
    # Rotation matrix
    # ------------------------------------------------------------------

    def _rotation_path(self, d: int, seed: int) -> Path:
        return self._root / f"rotation_d{d}_seed{seed}.npy"

    def load_rotation_matrix(self, d: int, seed: int) -> Any:
        """Load the rotation matrix for ``(d, seed)``.

        Raises:
            ArtifactNotFoundError: If the corresponding file does not exist.
        """
        path = self._rotation_path(d, seed)
        return self._load_array(
            path,
            f"Rotation matrix not found at {path}. "
            f"Run `python -m mlx_kv_quant precompute --head_dim {d}` first.",
        )

    def save_rotation_matrix(self, Pi: Any, d: int, seed: int) -> None:
        """Save ``Pi`` as the rotation matrix for ``(d, seed)``."""
        self._save_array(self._rotation_path(d, seed), Pi)

    # ------------------------------------------------------------------
    # Codebook
    # ------------------------------------------------------------------

    def _codebook_path(self, distribution: str, b: int, d: int) -> Path:
        return self._root / f"codebook_{distribution}_b{b}_d{d}.npy"

    def load_codebook(self, distribution: str, b: int, d: int) -> Any:
        """Load the codebook for ``(distribution, b, d)``.

        Raises:
            ArtifactNotFoundError: If the corresponding file does not exist.
        """
        path = self._codebook_path(distribution, b, d)
        return self._load_array(
            path,
            f"Codebook not found at {path}. "
            f"Run `python -m mlx_kv_quant precompute --head_dim {d} --bits {b}` first.",
        )

    def save_codebook(self, cb: Any, distribution: str, b: int, d: int) -> None:
        """Save ``cb`` as the codebook for ``(distribution, b, d)``."""
        self._save_array(self._codebook_path(distribution, b, d), cb)

    # ------------------------------------------------------------------
    # JL matrix
    # ------------------------------------------------------------------

    def _jl_path(self, d: int, m: int, seed: int) -> Path:
        return self._root / f"jl_d{d}_m{m}_seed{seed}.npy"

    def load_jl_matrix(self, d: int, m: int, seed: int) -> Any:
        """Load the JL matrix for ``(d, m, seed)``.

        Raises:
            ArtifactNotFoundError: If the corresponding file does not exist.
        """
        path = self._jl_path(d, m, seed)
        return self._load_array(
            path,
            f"JL matrix not found at {path}. "
            f"Run `python -m mlx_kv_quant precompute --head_dim {d} --jl_dim {m}` first.",
        )

    def save_jl_matrix(self, S: Any, d: int, m: int, seed: int) -> None:
        """Save ``S`` as the JL matrix for ``(d, m, seed)``."""
        self._save_array(self._jl_path(d, m, seed), S)

    # ------------------------------------------------------------------
    # Existence check
    # ------------------------------------------------------------------

    def exists(self, artifact_type: str, **kwargs: Any) -> bool:
        """Report whether the file for the described artifact exists.

        ``kwargs`` must carry the descriptor fields for the requested type:
        ``d``/``seed`` for ``"rotation"``, ``distribution``/``b``/``d`` for
        ``"codebook"`` and ``d``/``m``/``seed`` for ``"jl"``.  Any other
        ``artifact_type`` yields ``False``.
        """
        if artifact_type == "rotation":
            return self._rotation_path(kwargs["d"], kwargs["seed"]).exists()
        if artifact_type == "codebook":
            return self._codebook_path(
                kwargs["distribution"], kwargs["b"], kwargs["d"]
            ).exists()
        if artifact_type == "jl":
            return self._jl_path(
                kwargs["d"], kwargs["m"], kwargs["seed"]
            ).exists()
        return False

    def __repr__(self) -> str:
        return f"NpyArtifactStore(root={self._root!r})"
File without changes
@@ -0,0 +1,184 @@
1
+ """Attend latency and memory benchmark across sequence lengths.
2
+
3
+ Compares four configurations:
4
+ baseline — no optimizations
5
+ vectorized — enable_vectorized_attend=True
6
+ fused — enable_vectorized_attend=True + enable_fused_query_dot=True
7
+ all_opts — fused + enable_outlier_two_stream=True
8
+
9
+ Usage::
10
+
11
+ python -m mlx_kv_quant.benchmarks.attend_benchmark
12
+ python -m mlx_kv_quant.benchmarks.attend_benchmark --method turboquant_mse --bits 2
13
+ python -m mlx_kv_quant.benchmarks.attend_benchmark --seq_lens 64 256 1024 4096
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import time
19
+ from typing import List
20
+
21
+ import numpy as np
22
+
23
+
24
+ _SEQ_LENS: List[int] = [128, 512, 1000, 2048]
25
+ _N_ATTEND_CALLS = 20 # per measurement
26
+ _N_OUTLIER_CHANNELS = 4
27
+ _N_CALIB_TOKENS = 50 # short so calibration completes inside every seq_len
28
+
29
+
30
def _build(
    method: str,
    d: int,
    bits: int,
    jl_dim: int,
    seed: int,
    *,
    vectorized: bool = False,
    fused: bool = False,
    outlier: bool = False,
):
    """Build a KV cache for one benchmark configuration.

    The three keyword-only flags are passed to the builder's
    ``with_vectorized_attend``, ``with_fused_query_dot`` and
    ``with_outlier_two_stream`` setters.  The outlier-channel count and the
    calibration-token count come from the module-level constants.

    Returns:
        Whatever ``KVCacheBuilder.build()`` returns for these settings.
    """
    from mlx_kv_quant.cache.base import KVCacheBuilder

    # Each setter's return value is threaded into the next call, exactly as
    # a single chained expression would do.
    builder = KVCacheBuilder()
    builder = builder.with_method(method)
    builder = builder.with_head_dim(d)
    builder = builder.with_bit_width(inlier=bits)
    builder = builder.with_jl_dim(jl_dim)
    builder = builder.with_seed(seed)
    builder = builder.with_vectorized_attend(vectorized)
    builder = builder.with_fused_query_dot(fused)
    builder = builder.with_outlier_two_stream(outlier)
    builder = builder.with_n_outlier_channels(_N_OUTLIER_CHANNELS)
    builder = builder.with_n_calib_tokens(_N_CALIB_TOKENS)
    return builder.build()
57
+
58
+
59
def _fill(cache, keys, vals) -> None:
    """Append every key and its same-index value to ``cache``, in order.

    The number of appends is driven by ``keys``; ``vals`` is indexed by
    position, so it must be at least as long as ``keys``.
    """
    for position, key in enumerate(keys):
        cache.append(key, vals[position])
62
+
63
+
64
def _measure_attend_ms(cache, q, n_calls: int) -> float:
    """Return the mean wall-clock time of ``cache.attend(q)`` in milliseconds.

    One untimed call is made first as a warm-up.  Each call, including the
    warm-up, is passed through ``mx.eval``.  The timed loop then runs
    ``n_calls`` times and the total elapsed time is divided by ``n_calls``.
    """
    import mlx.core as mx

    # Warm-up call, excluded from the measurement.
    mx.eval(cache.attend(q))

    started = time.perf_counter()
    for _ in range(n_calls):
        mx.eval(cache.attend(q))
    elapsed_s = time.perf_counter() - started
    return elapsed_s * 1_000.0 / n_calls
74
+
75
+
76
def _correctness_check(cache_a, cache_b, q, label_a: str, label_b: str) -> None:
    """Report whether two caches produce numerically close attend outputs.

    Both caches are queried with ``q`` and the outputs are compared with
    ``np.testing.assert_allclose`` (``rtol=5e-3``, ``atol=5e-3``).  The
    outcome is printed as ``OK`` with the maximum absolute difference, or as
    ``FAIL`` with the assertion message.  A mismatch is never re-raised, so
    the surrounding benchmark keeps running.

    Args:
        cache_a: First cache; must provide ``attend(q)``.
        cache_b: Second cache; must provide ``attend(q)``.
        q: Query passed unchanged to both ``attend`` calls.
        label_a: Name printed for ``cache_a``.
        label_b: Name printed for ``cache_b``.
    """
    import mlx.core as mx

    result_a = cache_a.attend(q)
    result_b = cache_b.attend(q)
    # Evaluate both outputs explicitly before converting them; an
    # argument-less mx.eval() after the conversion would do nothing.
    mx.eval(result_a, result_b)
    out_a = np.array(result_a)
    out_b = np.array(result_b)

    try:
        np.testing.assert_allclose(out_a, out_b, rtol=5e-3, atol=5e-3)
    except AssertionError as e:
        print(f" Correctness {label_a} vs {label_b}: FAIL — {e}")
    else:
        print(f" Correctness {label_a} vs {label_b}: OK (max_diff={np.max(np.abs(out_a-out_b)):.5f})")
88
+
89
+
90
def run(
    method: str,
    d: int,
    bits: int,
    jl_dim: int,
    seed: int,
    seq_lens: List[int],
    n_calls: int,
    correctness: bool,
) -> None:
    """Run the attend-latency sweep and print one report per sequence length.

    For every entry in ``seq_lens`` the function draws random float16
    keys, values and a query from a generator seeded with ``seed``.  It then
    builds and fills one cache per configuration and measures the mean
    attend latency over ``n_calls`` calls.  The printed report holds a table
    row, the speedup of each configuration relative to ``baseline``, and
    each cache's ``memory_bytes()``.  When ``correctness`` is true, the
    first two configurations are also compared.

    For ``method == "turboquant_mse"`` the ``fused`` configuration is left
    out and ``all_opts`` runs with ``fused=False``.
    """
    import mlx.core as mx

    rng = np.random.default_rng(seed)

    def _random_vector():
        # One length-d standard-normal sample, cast to float16.
        return mx.array(rng.standard_normal(d).astype(np.float16))

    # turboquant_mse doesn't have a fused path; skip 'fused' for it and keep
    # 'all_opts' un-fused.
    if method == "turboquant_mse":
        configs = {
            "baseline": dict(vectorized=False, fused=False, outlier=False),
            "vectorized": dict(vectorized=True, fused=False, outlier=False),
            "all_opts": dict(vectorized=True, fused=False, outlier=True),
        }
    else:
        configs = {
            "baseline": dict(vectorized=False, fused=False, outlier=False),
            "vectorized": dict(vectorized=True, fused=False, outlier=False),
            "fused": dict(vectorized=True, fused=True, outlier=False),
            "all_opts": dict(vectorized=True, fused=True, outlier=True),
        }
    config_names = list(configs)

    col_w = 14
    header = f"{'seq_len':>8} " + " ".join(f"{label:>{col_w}}" for label in config_names)
    print(f"\n=== attend latency (ms/call) — method={method}, d={d}, bits={bits} ===")
    print(header)
    print("-" * len(header))

    for seq_len in seq_lens:
        # Draw order (keys, then values, then query) fixes the random stream.
        keys = [_random_vector() for _ in range(seq_len)]
        vals = [_random_vector() for _ in range(seq_len)]
        q = _random_vector()

        caches = {}
        latencies = {}
        for label in config_names:
            cache = _build(method, d, bits, jl_dim, seed, **configs[label])
            _fill(cache, keys, vals)
            caches[label] = cache
            latencies[label] = _measure_attend_ms(cache, q, n_calls)

        cells = [f"{latencies[label]:>{col_w}.3f}" for label in config_names]
        print(f"{seq_len:>8} " + " ".join(cells))

        # Speedup summary
        base_ms = latencies["baseline"]
        for label, ms in latencies.items():
            if label == "baseline":
                continue
            print(f" {label:>12}: {base_ms/max(ms, 1e-9):.2f}× speedup vs baseline")

        # Optional correctness check (first two configurations only)
        if correctness and len(caches) >= 2:
            first, second = config_names[0], config_names[1]
            _correctness_check(caches[first], caches[second], q, first, second)

        # Memory footprint
        footprint = ", ".join(
            f"{label}={cache.memory_bytes()}" for label, cache in caches.items()
        )
        print(" memory (bytes): " + footprint)
        print()
155
+
156
+
157
def main() -> None:
    """Parse the benchmark's command-line options and start the sweep.

    Every parsed option is forwarded to ``run`` by keyword; ``--head_dim``
    is passed as ``d``.
    """
    cli = argparse.ArgumentParser(description="TurboQuant attend latency sweep")
    cli.add_argument(
        "--method",
        default="turboquant_prod",
        choices=["turboquant_prod", "turboquant_mse"],
    )
    cli.add_argument("--head_dim", type=int, default=128)
    cli.add_argument("--bits", type=int, default=3)
    cli.add_argument("--jl_dim", type=int, default=128)
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--seq_lens", type=int, nargs="*", default=_SEQ_LENS)
    cli.add_argument("--n_calls", type=int, default=_N_ATTEND_CALLS)
    cli.add_argument(
        "--correctness",
        action="store_true",
        help="Run cross-config correctness checks at each seq_len",
    )
    opts = cli.parse_args()

    sweep_kwargs = dict(
        method=opts.method,
        d=opts.head_dim,
        bits=opts.bits,
        jl_dim=opts.jl_dim,
        seed=opts.seed,
        seq_lens=opts.seq_lens,
        n_calls=opts.n_calls,
        correctness=opts.correctness,
    )
    run(**sweep_kwargs)
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+ from mlx_kv_quant.cache.base import KVCacheBuilder, KVCacheConfig, KVCacheFactory
4
+ from mlx_kv_quant.cache.polar_cache import PolarQuantKVCache
5
+ from mlx_kv_quant.cache.qjl_cache import QJLKVCache
6
+ from mlx_kv_quant.cache.sliding_window_cache import SlidingWindowKVCache
7
+ from mlx_kv_quant.cache.turboquant_cache import TurboQuantKVCache
8
+
9
+ __all__ = [
10
+ "KVCacheBuilder",
11
+ "KVCacheConfig",
12
+ "KVCacheFactory",
13
+ "PolarQuantKVCache",
14
+ "QJLKVCache",
15
+ "SlidingWindowKVCache",
16
+ "TurboQuantKVCache",
17
+ ]