VeloxQuant-MLX 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_kv_quant/__init__.py +41 -0
- mlx_kv_quant/__main__.py +28 -0
- mlx_kv_quant/artifacts/__init__.py +6 -0
- mlx_kv_quant/artifacts/base.py +5 -0
- mlx_kv_quant/artifacts/memory_store.py +93 -0
- mlx_kv_quant/artifacts/npy_store.py +114 -0
- mlx_kv_quant/benchmarks/__init__.py +0 -0
- mlx_kv_quant/benchmarks/attend_benchmark.py +184 -0
- mlx_kv_quant/cache/__init__.py +17 -0
- mlx_kv_quant/cache/base.py +273 -0
- mlx_kv_quant/cache/polar_cache.py +134 -0
- mlx_kv_quant/cache/qjl_cache.py +118 -0
- mlx_kv_quant/cache/sliding_window_cache.py +113 -0
- mlx_kv_quant/cache/turboquant_cache.py +312 -0
- mlx_kv_quant/cli/__init__.py +1 -0
- mlx_kv_quant/cli/benchmark.py +89 -0
- mlx_kv_quant/cli/precompute.py +31 -0
- mlx_kv_quant/codebooks/__init__.py +19 -0
- mlx_kv_quant/codebooks/base.py +79 -0
- mlx_kv_quant/codebooks/precompute.py +120 -0
- mlx_kv_quant/codebooks/scalar_codebook.py +107 -0
- mlx_kv_quant/codebooks/strategies.py +145 -0
- mlx_kv_quant/core/__init__.py +71 -0
- mlx_kv_quant/core/abstractions.py +369 -0
- mlx_kv_quant/core/constants.py +42 -0
- mlx_kv_quant/core/context.py +164 -0
- mlx_kv_quant/core/exceptions.py +17 -0
- mlx_kv_quant/core/registry.py +89 -0
- mlx_kv_quant/dsa/__init__.py +18 -0
- mlx_kv_quant/dsa/avl_tree.py +290 -0
- mlx_kv_quant/dsa/bit_pack.py +210 -0
- mlx_kv_quant/dsa/dag.py +169 -0
- mlx_kv_quant/dsa/heap.py +194 -0
- mlx_kv_quant/dsa/ring_buffer.py +136 -0
- mlx_kv_quant/handlers/__init__.py +21 -0
- mlx_kv_quant/handlers/base.py +6 -0
- mlx_kv_quant/handlers/bit_pack_handler.py +62 -0
- mlx_kv_quant/handlers/normalization.py +52 -0
- mlx_kv_quant/handlers/outlier_split.py +71 -0
- mlx_kv_quant/handlers/polar_handler.py +62 -0
- mlx_kv_quant/handlers/qjl_residual_handler.py +63 -0
- mlx_kv_quant/handlers/rotation_handler.py +43 -0
- mlx_kv_quant/handlers/scalar_quant_handler.py +47 -0
- mlx_kv_quant/handlers/value_quant_handler.py +54 -0
- mlx_kv_quant/integration/__init__.py +5 -0
- mlx_kv_quant/integration/mlx_lm_patch.py +83 -0
- mlx_kv_quant/math/__init__.py +14 -0
- mlx_kv_quant/math/distributions.py +111 -0
- mlx_kv_quant/math/lloyd_max.py +122 -0
- mlx_kv_quant/math/rotation.py +103 -0
- mlx_kv_quant/observers/__init__.py +14 -0
- mlx_kv_quant/observers/base.py +32 -0
- mlx_kv_quant/observers/distortion.py +179 -0
- mlx_kv_quant/observers/latency.py +55 -0
- mlx_kv_quant/observers/memory.py +46 -0
- mlx_kv_quant/outlier/__init__.py +5 -0
- mlx_kv_quant/outlier/detector.py +89 -0
- mlx_kv_quant/preconditioners/__init__.py +12 -0
- mlx_kv_quant/preconditioners/base.py +67 -0
- mlx_kv_quant/preconditioners/jl_sketch.py +131 -0
- mlx_kv_quant/preconditioners/rotation.py +117 -0
- mlx_kv_quant/quantizers/__init__.py +17 -0
- mlx_kv_quant/quantizers/base.py +66 -0
- mlx_kv_quant/quantizers/composite.py +111 -0
- mlx_kv_quant/quantizers/polarquant.py +163 -0
- mlx_kv_quant/quantizers/qjl.py +111 -0
- mlx_kv_quant/quantizers/turboquant_mse.py +139 -0
- mlx_kv_quant/quantizers/turboquant_prod.py +214 -0
- mlx_kv_quant/tests/__init__.py +1 -0
- mlx_kv_quant/tests/cache/__init__.py +1 -0
- mlx_kv_quant/tests/cache/test_sliding_window.py +63 -0
- mlx_kv_quant/tests/cache/test_turboquant_cache.py +268 -0
- mlx_kv_quant/tests/conftest.py +61 -0
- mlx_kv_quant/tests/dsa/__init__.py +1 -0
- mlx_kv_quant/tests/dsa/test_avl_tree.py +116 -0
- mlx_kv_quant/tests/dsa/test_bit_pack.py +74 -0
- mlx_kv_quant/tests/dsa/test_dag.py +106 -0
- mlx_kv_quant/tests/dsa/test_heap.py +90 -0
- mlx_kv_quant/tests/dsa/test_ring_buffer.py +95 -0
- mlx_kv_quant/tests/handlers/__init__.py +1 -0
- mlx_kv_quant/tests/handlers/test_pipeline.py +119 -0
- mlx_kv_quant/tests/integration/__init__.py +1 -0
- mlx_kv_quant/tests/integration/test_distortion_bounds.py +111 -0
- mlx_kv_quant/tests/math/__init__.py +1 -0
- mlx_kv_quant/tests/math/test_distributions.py +81 -0
- mlx_kv_quant/tests/math/test_lloyd_max.py +97 -0
- mlx_kv_quant/tests/quantizers/__init__.py +1 -0
- mlx_kv_quant/tests/quantizers/test_polar.py +76 -0
- mlx_kv_quant/tests/quantizers/test_qjl.py +69 -0
- mlx_kv_quant/tests/quantizers/test_turboquant_mse.py +63 -0
- mlx_kv_quant/tests/quantizers/test_turboquant_prod.py +88 -0
- mlx_kv_quant/transforms/__init__.py +5 -0
- mlx_kv_quant/transforms/base.py +5 -0
- mlx_kv_quant/transforms/polar.py +144 -0
- mlx_kv_quant/weight/__init__.py +4 -0
- mlx_kv_quant/weight/model_quantizer.py +170 -0
- mlx_kv_quant/weight/quantized_linear.py +152 -0
- veloxquant_mlx-0.2.0.dist-info/METADATA +448 -0
- veloxquant_mlx-0.2.0.dist-info/RECORD +103 -0
- veloxquant_mlx-0.2.0.dist-info/WHEEL +5 -0
- veloxquant_mlx-0.2.0.dist-info/entry_points.txt +3 -0
- veloxquant_mlx-0.2.0.dist-info/licenses/LICENSE +21 -0
- veloxquant_mlx-0.2.0.dist-info/top_level.txt +1 -0
mlx_kv_quant/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""mlx_kv_quant — KV cache quantization for Apple Silicon MLX.
|
|
2
|
+
|
|
3
|
+
Implements TurboQuant, PolarQuant, and QJL for production LLM inference.
|
|
4
|
+
"""
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from mlx_kv_quant.cache.base import KVCacheBuilder, KVCacheConfig, KVCacheFactory
|
|
8
|
+
from mlx_kv_quant.core.abstractions import (
|
|
9
|
+
ArtifactStore,
|
|
10
|
+
KVCache,
|
|
11
|
+
Quantizer,
|
|
12
|
+
QuantizationObserver,
|
|
13
|
+
)
|
|
14
|
+
from mlx_kv_quant.core.context import EncodedVector, QuantizationContext, TransformResult
|
|
15
|
+
from mlx_kv_quant.core.exceptions import (
|
|
16
|
+
ArtifactNotFoundError,
|
|
17
|
+
CodebookDimensionMismatch,
|
|
18
|
+
CyclicPipelineError,
|
|
19
|
+
QuantizerConfigError,
|
|
20
|
+
)
|
|
21
|
+
from mlx_kv_quant.quantizers.base import QuantizerFactory
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"KVCacheBuilder",
|
|
25
|
+
"KVCacheConfig",
|
|
26
|
+
"KVCacheFactory",
|
|
27
|
+
"ArtifactStore",
|
|
28
|
+
"KVCache",
|
|
29
|
+
"Quantizer",
|
|
30
|
+
"QuantizationObserver",
|
|
31
|
+
"EncodedVector",
|
|
32
|
+
"QuantizationContext",
|
|
33
|
+
"TransformResult",
|
|
34
|
+
"ArtifactNotFoundError",
|
|
35
|
+
"CodebookDimensionMismatch",
|
|
36
|
+
"CyclicPipelineError",
|
|
37
|
+
"QuantizerConfigError",
|
|
38
|
+
"QuantizerFactory",
|
|
39
|
+
]
|
|
40
|
+
|
|
41
|
+
__version__ = "0.2.0"
|
mlx_kv_quant/__main__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Entry point for `python -m mlx_kv_quant <command>`."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def main() -> None:
    """Dispatch ``python -m mlx_kv_quant <command>`` to the matching CLI module.

    The sub-command is read from ``sys.argv[1]``. For a recognised command,
    ``sys.argv`` is rewritten so that the sub-command's own argument parser
    sees only its arguments, then that sub-command's ``main`` is imported
    lazily and run.

    Raises:
        SystemExit: With status 1 when no sub-command is given or when the
            sub-command is not one of ``precompute`` / ``benchmark``.
    """
    if len(sys.argv) < 2:
        # Diagnostics go to stderr so they never pollute piped stdout.
        print("Usage: veloxquant {precompute|benchmark}", file=sys.stderr)
        sys.exit(1)

    command = sys.argv[1]
    if command not in ("precompute", "benchmark"):
        # Validate before touching sys.argv so a rejected command leaves the
        # process arguments exactly as the caller supplied them.
        print(
            f"Unknown command: {command!r}. Choices: precompute, benchmark",
            file=sys.stderr,
        )
        sys.exit(1)

    # Remove the subcommand so sub-parsers see argv correctly
    sys.argv = [f"mlx_kv_quant {command}"] + sys.argv[2:]

    # Import lazily: only the selected sub-command's module is loaded.
    if command == "precompute":
        from mlx_kv_quant.cli.precompute import main as _main
    else:
        from mlx_kv_quant.cli.benchmark import main as _main
    _main()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
if __name__ == "__main__":
|
|
28
|
+
main()
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from mlx_kv_quant.core.abstractions import ArtifactStore
|
|
8
|
+
from mlx_kv_quant.core.exceptions import ArtifactNotFoundError
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class InMemoryArtifactStore(ArtifactStore):
    """Dictionary-backed artifact store intended for tests; touches no disk.

    Rotation matrices, codebooks and JL matrices each live in their own
    plain ``dict`` keyed by a descriptor tuple. Values are converted to
    float16 NumPy arrays when saved and wrapped in an MLX array when loaded.

    The constructor takes no arguments.
    """

    def __init__(self) -> None:
        self._rotations: Dict[Tuple, np.ndarray] = {}
        self._codebooks: Dict[Tuple, np.ndarray] = {}
        self._jls: Dict[Tuple, np.ndarray] = {}

    @staticmethod
    def _fetch(table: Dict[Tuple, np.ndarray], key: Tuple, missing_msg: str) -> Any:
        """Return ``table[key]`` as an MLX array or raise ``ArtifactNotFoundError``."""
        stored = table.get(key)
        if stored is None:
            raise ArtifactNotFoundError(missing_msg)
        # MLX is imported only once a hit is certain, so a miss is reported
        # without requiring MLX to be importable.
        import mlx.core as mx
        return mx.array(stored)

    # ------------------------------------------------------------------
    # Rotation matrix
    # ------------------------------------------------------------------

    def load_rotation_matrix(self, d: int, seed: int) -> Any:
        """Return the rotation stored under ``(d, seed)`` as an MLX array."""
        return self._fetch(
            self._rotations,
            (d, seed),
            f"InMemoryArtifactStore: rotation d={d} seed={seed} not found.",
        )

    def save_rotation_matrix(self, Pi: Any, d: int, seed: int) -> None:
        """Store ``Pi`` under ``(d, seed)`` as a float16 NumPy array."""
        as_half = np.array(Pi, dtype=np.float16)
        self._rotations[(d, seed)] = as_half

    # ------------------------------------------------------------------
    # Codebook
    # ------------------------------------------------------------------

    def load_codebook(self, distribution: str, b: int, d: int) -> Any:
        """Return the codebook stored under ``(distribution, b, d)`` as an MLX array."""
        return self._fetch(
            self._codebooks,
            (distribution, b, d),
            f"InMemoryArtifactStore: codebook dist={distribution} b={b} d={d} not found.",
        )

    def save_codebook(self, cb: Any, distribution: str, b: int, d: int) -> None:
        """Store ``cb`` under ``(distribution, b, d)`` as a float16 NumPy array."""
        as_half = np.array(cb, dtype=np.float16)
        self._codebooks[(distribution, b, d)] = as_half

    # ------------------------------------------------------------------
    # JL matrix
    # ------------------------------------------------------------------

    def load_jl_matrix(self, d: int, m: int, seed: int) -> Any:
        """Return the JL matrix stored under ``(d, m, seed)`` as an MLX array."""
        return self._fetch(
            self._jls,
            (d, m, seed),
            f"InMemoryArtifactStore: JL d={d} m={m} seed={seed} not found.",
        )

    def save_jl_matrix(self, S: Any, d: int, m: int, seed: int) -> None:
        """Store ``S`` under ``(d, m, seed)`` as a float16 NumPy array."""
        as_half = np.array(S, dtype=np.float16)
        self._jls[(d, m, seed)] = as_half

    # ------------------------------------------------------------------
    # Existence check
    # ------------------------------------------------------------------

    def exists(self, artifact_type: str, **kwargs: Any) -> bool:
        """Report whether an artifact with the given descriptor has been saved.

        ``artifact_type`` selects the table (``"rotation"``, ``"codebook"`` or
        ``"jl"``); the descriptor fields are read from ``kwargs``. Any other
        ``artifact_type`` yields ``False``.
        """
        if artifact_type == "rotation":
            table, fields = self._rotations, ("d", "seed")
        elif artifact_type == "codebook":
            table, fields = self._codebooks, ("distribution", "b", "d")
        elif artifact_type == "jl":
            table, fields = self._jls, ("d", "m", "seed")
        else:
            return False
        descriptor = tuple(kwargs[name] for name in fields)
        return descriptor in table

    def __repr__(self) -> str:
        n_rot = len(self._rotations)
        n_cb = len(self._codebooks)
        n_jl = len(self._jls)
        return f"InMemoryArtifactStore(rotations={n_rot}, codebooks={n_cb}, jls={n_jl})"
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from mlx_kv_quant.core.abstractions import ArtifactStore
|
|
10
|
+
from mlx_kv_quant.core.exceptions import ArtifactNotFoundError
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NpyArtifactStore(ArtifactStore):
    """Artifact store that reads and writes ``.npy`` files from a local directory.

    File naming conventions:
        rotation_d{d}_seed{seed}.npy
        codebook_{distribution}_b{b}_d{d}.npy
        jl_d{d}_m{m}_seed{seed}.npy

    Arrays are cast to float16 both when they are saved and when they are
    loaded; loads return MLX arrays.

    Args:
        root_dir: Path to the directory where artifacts are stored.
            Created (including parents) at construction time, and
            re-created on save if it has been removed in the meantime.
    """

    def __init__(self, root_dir: str | Path) -> None:
        self._root = Path(root_dir)
        self._root.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Shared I/O helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_array(path: Path) -> Any:
        """Load ``path`` and return its contents as a float16 MLX array."""
        # Imported lazily so that a missing artifact can be reported by the
        # callers before MLX is needed.
        import mlx.core as mx
        return mx.array(np.load(path).astype(np.float16))

    def _write_array(self, path: Path, values: Any) -> None:
        """Write ``values`` to ``path`` as a float16 ``.npy`` file."""
        # The root may have been deleted since construction; make sure it
        # exists so the save does not fail with FileNotFoundError.
        self._root.mkdir(parents=True, exist_ok=True)
        np.save(path, np.array(values, dtype=np.float16))

    # ------------------------------------------------------------------
    # Rotation matrix
    # ------------------------------------------------------------------

    def _rotation_path(self, d: int, seed: int) -> Path:
        return self._root / f"rotation_d{d}_seed{seed}.npy"

    def load_rotation_matrix(self, d: int, seed: int) -> Any:
        """Load the rotation matrix for ``(d, seed)``.

        Raises:
            ArtifactNotFoundError: If the corresponding file does not exist.
        """
        path = self._rotation_path(d, seed)
        if not path.exists():
            raise ArtifactNotFoundError(
                f"Rotation matrix not found at {path}. "
                f"Run `python -m mlx_kv_quant precompute --head_dim {d}` first."
            )
        return self._read_array(path)

    def save_rotation_matrix(self, Pi: Any, d: int, seed: int) -> None:
        """Save ``Pi`` as the rotation matrix for ``(d, seed)``."""
        self._write_array(self._rotation_path(d, seed), Pi)

    # ------------------------------------------------------------------
    # Codebook
    # ------------------------------------------------------------------

    def _codebook_path(self, distribution: str, b: int, d: int) -> Path:
        return self._root / f"codebook_{distribution}_b{b}_d{d}.npy"

    def load_codebook(self, distribution: str, b: int, d: int) -> Any:
        """Load the codebook for ``(distribution, b, d)``.

        Raises:
            ArtifactNotFoundError: If the corresponding file does not exist.
        """
        path = self._codebook_path(distribution, b, d)
        if not path.exists():
            raise ArtifactNotFoundError(
                f"Codebook not found at {path}. "
                f"Run `python -m mlx_kv_quant precompute --head_dim {d} --bits {b}` first."
            )
        return self._read_array(path)

    def save_codebook(self, cb: Any, distribution: str, b: int, d: int) -> None:
        """Save ``cb`` as the codebook for ``(distribution, b, d)``."""
        self._write_array(self._codebook_path(distribution, b, d), cb)

    # ------------------------------------------------------------------
    # JL matrix
    # ------------------------------------------------------------------

    def _jl_path(self, d: int, m: int, seed: int) -> Path:
        return self._root / f"jl_d{d}_m{m}_seed{seed}.npy"

    def load_jl_matrix(self, d: int, m: int, seed: int) -> Any:
        """Load the JL matrix for ``(d, m, seed)``.

        Raises:
            ArtifactNotFoundError: If the corresponding file does not exist.
        """
        path = self._jl_path(d, m, seed)
        if not path.exists():
            raise ArtifactNotFoundError(
                f"JL matrix not found at {path}. "
                f"Run `python -m mlx_kv_quant precompute --head_dim {d} --jl_dim {m}` first."
            )
        return self._read_array(path)

    def save_jl_matrix(self, S: Any, d: int, m: int, seed: int) -> None:
        """Save ``S`` as the JL matrix for ``(d, m, seed)``."""
        self._write_array(self._jl_path(d, m, seed), S)

    # ------------------------------------------------------------------
    # Existence check
    # ------------------------------------------------------------------

    def exists(self, artifact_type: str, **kwargs: Any) -> bool:
        """Report whether the file for the described artifact is present.

        ``artifact_type`` is ``"rotation"``, ``"codebook"`` or ``"jl"``; the
        descriptor fields are read from ``kwargs``. Any other
        ``artifact_type`` yields ``False``.
        """
        if artifact_type == "rotation":
            return self._rotation_path(kwargs["d"], kwargs["seed"]).exists()
        if artifact_type == "codebook":
            return self._codebook_path(
                kwargs["distribution"], kwargs["b"], kwargs["d"]
            ).exists()
        if artifact_type == "jl":
            return self._jl_path(
                kwargs["d"], kwargs["m"], kwargs["seed"]
            ).exists()
        return False

    def __repr__(self) -> str:
        return f"NpyArtifactStore(root={self._root!r})"
|
|
File without changes
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Attend latency and memory benchmark across sequence lengths.
|
|
2
|
+
|
|
3
|
+
Compares four configurations:
|
|
4
|
+
baseline — no optimizations
|
|
5
|
+
vectorized — enable_vectorized_attend=True
|
|
6
|
+
fused — enable_vectorized_attend=True + enable_fused_query_dot=True
|
|
7
|
+
all_opts — fused + enable_outlier_two_stream=True
|
|
8
|
+
|
|
9
|
+
Usage::
|
|
10
|
+
|
|
11
|
+
python -m mlx_kv_quant.benchmarks.attend_benchmark
|
|
12
|
+
python -m mlx_kv_quant.benchmarks.attend_benchmark --method turboquant_mse --bits 2
|
|
13
|
+
python -m mlx_kv_quant.benchmarks.attend_benchmark --seq_lens 64 256 1024 4096
|
|
14
|
+
"""
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import time
|
|
19
|
+
from typing import List
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
_SEQ_LENS: List[int] = [128, 512, 1000, 2048]
|
|
25
|
+
_N_ATTEND_CALLS = 20 # per measurement
|
|
26
|
+
_N_OUTLIER_CHANNELS = 4
|
|
27
|
+
_N_CALIB_TOKENS = 50 # short so calibration completes inside every seq_len
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _build(
    method: str,
    d: int,
    bits: int,
    jl_dim: int,
    seed: int,
    *,
    vectorized: bool = False,
    fused: bool = False,
    outlier: bool = False,
):
    """Construct a KV cache through ``KVCacheBuilder`` for one benchmark config.

    The positional arguments are forwarded to the builder's method, head-dim,
    inlier bit-width, JL-dim and seed setters. The keyword-only flags feed the
    vectorized-attend, fused-query-dot and outlier-two-stream setters. The
    outlier channel count and calibration token count come from the module
    constants ``_N_OUTLIER_CHANNELS`` and ``_N_CALIB_TOKENS``.

    Returns:
        Whatever ``KVCacheBuilder.build()`` returns.
    """
    from mlx_kv_quant.cache.base import KVCacheBuilder

    # Each setter returns the builder to continue with, so rebinding step by
    # step is the same call sequence as one long fluent chain.
    builder = KVCacheBuilder()
    builder = builder.with_method(method)
    builder = builder.with_head_dim(d)
    builder = builder.with_bit_width(inlier=bits)
    builder = builder.with_jl_dim(jl_dim)
    builder = builder.with_seed(seed)
    builder = builder.with_vectorized_attend(vectorized)
    builder = builder.with_fused_query_dot(fused)
    builder = builder.with_outlier_two_stream(outlier)
    builder = builder.with_n_outlier_channels(_N_OUTLIER_CHANNELS)
    builder = builder.with_n_calib_tokens(_N_CALIB_TOKENS)
    return builder.build()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _fill(cache, keys, vals) -> None:
    """Append every key/value pair to ``cache`` in order.

    Args:
        cache: Object exposing ``append(key, value)``.
        keys: Sized, iterable collection of keys.
        vals: Sized, iterable collection of values, one per key.

    Raises:
        ValueError: If ``keys`` and ``vals`` differ in length. The check runs
            before anything is appended, so a mismatch never leaves the cache
            partially filled.
    """
    if len(keys) != len(vals):
        raise ValueError(
            f"keys and vals must have the same length, got {len(keys)} and {len(vals)}"
        )
    for key, val in zip(keys, vals):
        cache.append(key, val)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _measure_attend_ms(cache, q, n_calls: int) -> float:
    """Return the mean wall-clock time of ``cache.attend(q)`` in milliseconds.

    One untimed warm-up call is made first. Then ``n_calls`` calls are timed
    with ``time.perf_counter``; every result is passed to ``mx.eval`` so the
    computation is actually carried out inside the timed region.

    Args:
        cache: Object exposing ``attend(q)``.
        q: Query passed unchanged to ``cache.attend``.
        n_calls: Number of timed calls; must be at least 1.

    Returns:
        Elapsed milliseconds divided by ``n_calls``.

    Raises:
        ValueError: If ``n_calls`` is less than 1. A zero count would divide
            by zero and a negative count would report a meaningless latency.
    """
    if n_calls < 1:
        raise ValueError(f"n_calls must be >= 1, got {n_calls}")

    import mlx.core as mx

    # Warm-up
    mx.eval(cache.attend(q))

    t0 = time.perf_counter()
    for _ in range(n_calls):
        mx.eval(cache.attend(q))
    elapsed_s = time.perf_counter() - t0
    return elapsed_s * 1_000.0 / n_calls
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _correctness_check(cache_a, cache_b, q, label_a: str, label_b: str) -> None:
    """Report whether two caches produce numerically close attend outputs.

    Both caches are queried with ``q`` and the results are compared with
    ``np.testing.assert_allclose`` (``rtol=5e-3``, ``atol=5e-3``). The outcome
    is printed: ``OK`` with the maximum absolute difference, or ``FAIL`` with
    the assertion message. A mismatch is reported, never raised.

    Args:
        cache_a: First cache; must expose ``attend(q)``.
        cache_b: Second cache; must expose ``attend(q)``.
        q: Query passed unchanged to both caches.
        label_a: Name printed for ``cache_a``.
        label_b: Name printed for ``cache_b``.
    """
    import mlx.core as mx

    res_a = cache_a.attend(q)
    res_b = cache_b.attend(q)
    # Evaluate both results explicitly before handing them to NumPy. A bare
    # ``mx.eval()`` with no arguments evaluates nothing.
    mx.eval(res_a, res_b)
    out_a = np.array(res_a)
    out_b = np.array(res_b)

    try:
        np.testing.assert_allclose(out_a, out_b, rtol=5e-3, atol=5e-3)
    except AssertionError as e:
        print(f" Correctness {label_a} vs {label_b}: FAIL — {e}")
    else:
        print(f" Correctness {label_a} vs {label_b}: OK (max_diff={np.max(np.abs(out_a-out_b)):.5f})")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def run(
    method: str,
    d: int,
    bits: int,
    jl_dim: int,
    seed: int,
    seq_lens: List[int],
    n_calls: int,
    correctness: bool,
) -> None:
    """Sweep sequence lengths and print attend latency for each configuration.

    For every entry of ``seq_lens`` this draws random float16 keys, values and
    a query from a generator seeded with ``seed``, builds one cache per
    configuration, fills it, and times ``attend``. It prints a table row, the
    speedup of each configuration over ``baseline``, optionally a correctness
    comparison of the first two configurations, and each cache's
    ``memory_bytes()``.

    Args:
        method: Quantization method name passed to ``_build``.
        d: Head dimension; also the length of every random vector.
        bits: Bit width passed to ``_build``.
        jl_dim: JL dimension passed to ``_build``.
        seed: Seed for both the NumPy generator and ``_build``.
        seq_lens: Sequence lengths to benchmark.
        n_calls: Timed ``attend`` calls per measurement.
        correctness: When true, compare the first two configurations.
    """
    import mlx.core as mx

    rng = np.random.default_rng(seed)

    def _random_vector():
        # One float16 standard-normal vector of length d as an MLX array.
        return mx.array(rng.standard_normal(d).astype(np.float16))

    if method == "turboquant_mse":
        # turboquant_mse doesn't have a fused path; skip 'fused'/'all_opts' for it.
        configs = {
            "baseline": {"vectorized": False, "fused": False, "outlier": False},
            "vectorized": {"vectorized": True, "fused": False, "outlier": False},
            "all_opts": {"vectorized": True, "fused": False, "outlier": True},
        }
    else:
        configs = {
            "baseline": {"vectorized": False, "fused": False, "outlier": False},
            "vectorized": {"vectorized": True, "fused": False, "outlier": False},
            "fused": {"vectorized": True, "fused": True, "outlier": False},
            "all_opts": {"vectorized": True, "fused": True, "outlier": True},
        }

    col_w = 14
    config_names = list(configs)
    header = f"{'seq_len':>8} " + " ".join(f"{k:>{col_w}}" for k in config_names)
    print(f"\n=== attend latency (ms/call) — method={method}, d={d}, bits={bits} ===")
    print(header)
    print("-" * len(header))

    for seq_len in seq_lens:
        # Draw order (all keys, then all values, then the query) determines
        # which random numbers each array receives.
        keys = [_random_vector() for _ in range(seq_len)]
        vals = [_random_vector() for _ in range(seq_len)]
        q = _random_vector()

        caches = {}
        latencies = {}
        for name in config_names:
            cache = _build(method, d, bits, jl_dim, seed, **configs[name])
            _fill(cache, keys, vals)
            caches[name] = cache
            latencies[name] = _measure_attend_ms(cache, q, n_calls)

        cells = [f"{latencies[k]:>{col_w}.3f}" for k in config_names]
        print(f"{seq_len:>8} " + " ".join(cells))

        # Speedup summary
        base_ms = latencies["baseline"]
        for name, ms in latencies.items():
            if name == "baseline":
                continue
            print(f" {name:>12}: {base_ms/max(ms, 1e-9):.2f}× speedup vs baseline")

        # Optional correctness check
        if correctness and len(caches) >= 2:
            (name_a, cache_a), (name_b, cache_b) = list(caches.items())[:2]
            _correctness_check(cache_a, cache_b, q, name_a, name_b)

        # Memory footprint
        footprint = ", ".join(f"{n}={c.memory_bytes()}" for n, c in caches.items())
        print(f" memory (bytes): " + footprint)
        print()
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def main() -> None:
    """Parse the benchmark's command-line options and hand them to ``run``."""
    cli = argparse.ArgumentParser(description="TurboQuant attend latency sweep")
    cli.add_argument(
        "--method",
        default="turboquant_prod",
        choices=["turboquant_prod", "turboquant_mse"],
    )
    # Plain integer options, registered in the order they appear in --help.
    for flag, fallback in (
        ("--head_dim", 128),
        ("--bits", 3),
        ("--jl_dim", 128),
        ("--seed", 42),
    ):
        cli.add_argument(flag, type=int, default=fallback)
    cli.add_argument("--seq_lens", type=int, nargs="*", default=_SEQ_LENS)
    cli.add_argument("--n_calls", type=int, default=_N_ATTEND_CALLS)
    cli.add_argument(
        "--correctness",
        action="store_true",
        help="Run cross-config correctness checks at each seq_len",
    )
    opts = cli.parse_args()

    settings = {
        "method": opts.method,
        "d": opts.head_dim,
        "bits": opts.bits,
        "jl_dim": opts.jl_dim,
        "seed": opts.seed,
        "seq_lens": opts.seq_lens,
        "n_calls": opts.n_calls,
        "correctness": opts.correctness,
    }
    run(**settings)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
if __name__ == "__main__":
|
|
184
|
+
main()
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from mlx_kv_quant.cache.base import KVCacheBuilder, KVCacheConfig, KVCacheFactory
|
|
4
|
+
from mlx_kv_quant.cache.polar_cache import PolarQuantKVCache
|
|
5
|
+
from mlx_kv_quant.cache.qjl_cache import QJLKVCache
|
|
6
|
+
from mlx_kv_quant.cache.sliding_window_cache import SlidingWindowKVCache
|
|
7
|
+
from mlx_kv_quant.cache.turboquant_cache import TurboQuantKVCache
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"KVCacheBuilder",
|
|
11
|
+
"KVCacheConfig",
|
|
12
|
+
"KVCacheFactory",
|
|
13
|
+
"PolarQuantKVCache",
|
|
14
|
+
"QJLKVCache",
|
|
15
|
+
"SlidingWindowKVCache",
|
|
16
|
+
"TurboQuantKVCache",
|
|
17
|
+
]
|