mcp-plesk-dev-docs 0.4.2__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mcp_plesk_dev_docs-0.4.2/mcp_plesk_dev_docs.egg-info → mcp_plesk_dev_docs-0.5.0}/PKG-INFO +4 -1
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/README.md +3 -1
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0/mcp_plesk_dev_docs.egg-info}/PKG-INFO +4 -1
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/SOURCES.txt +0 -4
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/requires.txt +1 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/tq_index.py +23 -11
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/pyproject.toml +8 -2
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_turboquant_regression.py +1 -1
- mcp_plesk_dev_docs-0.4.2/plesk_unified/turboquant/__init__.py +0 -21
- mcp_plesk_dev_docs-0.4.2/plesk_unified/turboquant/compressors.py +0 -190
- mcp_plesk_dev_docs-0.4.2/plesk_unified/turboquant/lloyd_max.py +0 -190
- mcp_plesk_dev_docs-0.4.2/plesk_unified/turboquant/turboquant.py +0 -249
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/LICENSE +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/NOTICE +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/dependency_links.txt +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/entry_points.txt +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/top_level.txt +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/__init__.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/ai_client.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_engines.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_gates.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_reporting.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_runner.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_suites.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/chunking.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/error_handling.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/html_utils.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/indexing.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/io_utils.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/log_handler.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/model_config.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/platform_utils.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/settings.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/summary_cache.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/types.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/setup.cfg +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_ai_client.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_async_tools.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_benchmark_engines.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_benchmark_gates.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_chunking.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_error_handling.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_html_utils.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_indexing.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_io_utils.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_log_handler.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_model_config.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_progress.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_prompts.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_resources.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_sampling.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_search_helpers.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_server.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_settings.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_sota_ph1.py +0 -0
- {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_startup_path.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mcp-plesk-dev-docs
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: A unified MCP server that indexes and retrieves Plesk documentation using vector embeddings and semantic search with reranking
|
|
5
5
|
Author-email: Gilson Siqueira <gilson@example.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -35,6 +35,7 @@ Requires-Dist: torch>=2.4.0
|
|
|
35
35
|
Requires-Dist: markdownify>=0.14.1
|
|
36
36
|
Requires-Dist: tantivy>=0.22.0
|
|
37
37
|
Requires-Dist: lance-namespace==0.6.1
|
|
38
|
+
Requires-Dist: tq-search
|
|
38
39
|
Provides-Extra: dev
|
|
39
40
|
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
|
40
41
|
Requires-Dist: requests>=2.32.0; extra == "dev"
|
|
@@ -49,6 +50,8 @@ Dynamic: license-file
|
|
|
49
50
|
# mcp-plesk-dev-docs
|
|
50
51
|
|
|
51
52
|
[](https://www.python.org/downloads/)
|
|
53
|
+
[](https://pypi.org/project/mcp-plesk-dev-docs/)
|
|
54
|
+
[](https://registry.modelcontextprotocol.io/v0.1/servers/io.github.barateza%2Fmcp-plesk-dev-docs/versions/0.4.3)
|
|
52
55
|
[](LICENSE)
|
|
53
56
|
[](https://modelcontextprotocol.io/)
|
|
54
57
|
[](https://github.com/psf/black)
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# mcp-plesk-dev-docs
|
|
2
2
|
|
|
3
3
|
[](https://www.python.org/downloads/)
|
|
4
|
+
[](https://pypi.org/project/mcp-plesk-dev-docs/)
|
|
5
|
+
[](https://registry.modelcontextprotocol.io/v0.1/servers/io.github.barateza%2Fmcp-plesk-dev-docs/versions/0.4.3)
|
|
4
6
|
[](LICENSE)
|
|
5
7
|
[](https://modelcontextprotocol.io/)
|
|
6
8
|
[](https://github.com/psf/black)
|
|
@@ -170,4 +172,4 @@ Portions of this repository were developed under contract for Plesk Internationa
|
|
|
170
172
|
|
|
171
173
|
*Built to make Plesk extension development faster.*
|
|
172
174
|
|
|
173
|
-
<!-- mcp-name: io.github.barateza/mcp-plesk-dev-docs -->
|
|
175
|
+
<!-- mcp-name: io.github.barateza/mcp-plesk-dev-docs -->
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mcp-plesk-dev-docs
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: A unified MCP server that indexes and retrieves Plesk documentation using vector embeddings and semantic search with reranking
|
|
5
5
|
Author-email: Gilson Siqueira <gilson@example.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -35,6 +35,7 @@ Requires-Dist: torch>=2.4.0
|
|
|
35
35
|
Requires-Dist: markdownify>=0.14.1
|
|
36
36
|
Requires-Dist: tantivy>=0.22.0
|
|
37
37
|
Requires-Dist: lance-namespace==0.6.1
|
|
38
|
+
Requires-Dist: tq-search
|
|
38
39
|
Provides-Extra: dev
|
|
39
40
|
Requires-Dist: pytest>=8.0.0; extra == "dev"
|
|
40
41
|
Requires-Dist: requests>=2.32.0; extra == "dev"
|
|
@@ -49,6 +50,8 @@ Dynamic: license-file
|
|
|
49
50
|
# mcp-plesk-dev-docs
|
|
50
51
|
|
|
51
52
|
[](https://www.python.org/downloads/)
|
|
53
|
+
[](https://pypi.org/project/mcp-plesk-dev-docs/)
|
|
54
|
+
[](https://registry.modelcontextprotocol.io/v0.1/servers/io.github.barateza%2Fmcp-plesk-dev-docs/versions/0.4.3)
|
|
52
55
|
[](LICENSE)
|
|
53
56
|
[](https://modelcontextprotocol.io/)
|
|
54
57
|
[](https://github.com/psf/black)
|
{mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/SOURCES.txt
RENAMED
|
@@ -27,10 +27,6 @@ plesk_unified/settings.py
|
|
|
27
27
|
plesk_unified/summary_cache.py
|
|
28
28
|
plesk_unified/tq_index.py
|
|
29
29
|
plesk_unified/types.py
|
|
30
|
-
plesk_unified/turboquant/__init__.py
|
|
31
|
-
plesk_unified/turboquant/compressors.py
|
|
32
|
-
plesk_unified/turboquant/lloyd_max.py
|
|
33
|
-
plesk_unified/turboquant/turboquant.py
|
|
34
30
|
tests/test_ai_client.py
|
|
35
31
|
tests/test_async_tools.py
|
|
36
32
|
tests/test_benchmark_engines.py
|
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import torch
|
|
7
7
|
|
|
8
|
-
from
|
|
8
|
+
from tq_search import TurboQuantProd
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class TurboQuantIndex:
|
|
@@ -52,6 +52,13 @@ class TurboQuantIndex:
|
|
|
52
52
|
if self.compressed_db is None:
|
|
53
53
|
return []
|
|
54
54
|
|
|
55
|
+
# 1. Lazily move the compressed database to the target device once
|
|
56
|
+
first_val = next(iter(self.compressed_db.values()))
|
|
57
|
+
if str(first_val.device) != self.device:
|
|
58
|
+
self.compressed_db = {
|
|
59
|
+
k: v.to(self.device) for k, v in self.compressed_db.items()
|
|
60
|
+
}
|
|
61
|
+
|
|
55
62
|
selected_indices: list[int]
|
|
56
63
|
if category:
|
|
57
64
|
selected_indices = self._category_to_indices.get(category, [])
|
|
@@ -60,25 +67,30 @@ class TurboQuantIndex:
|
|
|
60
67
|
else:
|
|
61
68
|
selected_indices = list(range(len(self._meta)))
|
|
62
69
|
|
|
63
|
-
#
|
|
70
|
+
# 2. L2-Normalize the query
|
|
64
71
|
norm = np.linalg.norm(query_vec)
|
|
65
72
|
query_normalized = query_vec / max(norm, 1e-12)
|
|
66
73
|
|
|
67
|
-
#
|
|
74
|
+
# 3. Prepare query as a batched tensor (1, dim) directly on target device
|
|
68
75
|
q = torch.from_numpy(query_normalized).to(self.device).unsqueeze(0)
|
|
69
76
|
|
|
70
|
-
#
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
77
|
+
# 4. Slice candidates only if we are actually filtering a subset
|
|
78
|
+
if category and len(selected_indices) < len(self._meta):
|
|
79
|
+
selected_tensor = torch.as_tensor(
|
|
80
|
+
selected_indices, dtype=torch.long, device=self.device
|
|
81
|
+
)
|
|
82
|
+
db_on_device = {
|
|
83
|
+
k: v.index_select(0, selected_tensor)
|
|
84
|
+
for k, v in self.compressed_db.items()
|
|
85
|
+
}
|
|
86
|
+
else:
|
|
87
|
+
db_on_device = self.compressed_db
|
|
76
88
|
|
|
77
|
-
#
|
|
89
|
+
# 5. Perform a SINGLE batched inner product calculation
|
|
78
90
|
with torch.no_grad():
|
|
79
91
|
scores = self.quantizer.inner_product(q, db_on_device).squeeze(0)
|
|
80
92
|
|
|
81
|
-
#
|
|
93
|
+
# 6. Sort and return
|
|
82
94
|
scores_np = scores.cpu().numpy()
|
|
83
95
|
idx = np.argsort(-scores_np)[:top_k]
|
|
84
96
|
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "mcp-plesk-dev-docs"
|
|
7
|
-
version = "0.
|
|
7
|
+
version = "0.5.0"
|
|
8
8
|
description = "A unified MCP server that indexes and retrieves Plesk documentation using vector embeddings and semantic search with reranking"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = "MIT"
|
|
@@ -38,8 +38,11 @@ dependencies = [
|
|
|
38
38
|
"markdownify>=0.14.1",
|
|
39
39
|
"tantivy>=0.22.0",
|
|
40
40
|
"lance-namespace==0.6.1",
|
|
41
|
+
"tq-search",
|
|
41
42
|
]
|
|
42
43
|
|
|
44
|
+
|
|
45
|
+
|
|
43
46
|
[project.urls]
|
|
44
47
|
Homepage = "https://github.com/barateza/mcp-plesk-dev-docs"
|
|
45
48
|
Documentation = "https://github.com/barateza/mcp-plesk-dev-docs#readme"
|
|
@@ -47,7 +50,8 @@ Repository = "https://github.com/barateza/mcp-plesk-dev-docs.git"
|
|
|
47
50
|
"Bug Tracker" = "https://github.com/barateza/mcp-plesk-dev-docs/issues"
|
|
48
51
|
|
|
49
52
|
[tool.setuptools]
|
|
50
|
-
packages = ["plesk_unified"
|
|
53
|
+
packages = ["plesk_unified"]
|
|
54
|
+
|
|
51
55
|
|
|
52
56
|
[project.scripts]
|
|
53
57
|
# Console script to run the MCP server
|
|
@@ -78,7 +82,9 @@ url = "https://download.pytorch.org/whl/cu124"
|
|
|
78
82
|
explicit = true
|
|
79
83
|
|
|
80
84
|
[tool.uv.sources]
|
|
85
|
+
tq-search = { path = "/Users/gilsonsiqueira/tq-search", editable = true }
|
|
81
86
|
torch = [
|
|
87
|
+
|
|
82
88
|
{ index = "pytorch-cu124", marker = "sys_platform == 'win32'" },
|
|
83
89
|
]
|
|
84
90
|
torchvision = [
|
|
@@ -4,7 +4,7 @@ import numpy as np
|
|
|
4
4
|
import torch
|
|
5
5
|
|
|
6
6
|
from plesk_unified import tq_index
|
|
7
|
-
from
|
|
7
|
+
from tq_search import LloydMaxCodebook, TurboQuantMSE, TurboQuantProd
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def test_turboquant_package_exports():
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
"""TurboQuant helpers used by the unified retrieval path.
|
|
2
|
-
|
|
3
|
-
Base implementation: https://github.com/tonbistudio/turboquant-pytorch
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
|
-
from __future__ import annotations
|
|
7
|
-
|
|
8
|
-
from .compressors import TurboQuantCompressorMSE, TurboQuantCompressorV2
|
|
9
|
-
from .lloyd_max import LloydMaxCodebook, compute_expected_distortion, solve_lloyd_max
|
|
10
|
-
from .turboquant import TurboQuantKVCache, TurboQuantMSE, TurboQuantProd
|
|
11
|
-
|
|
12
|
-
__all__ = [
|
|
13
|
-
"TurboQuantCompressorMSE",
|
|
14
|
-
"TurboQuantCompressorV2",
|
|
15
|
-
"TurboQuantKVCache",
|
|
16
|
-
"TurboQuantMSE",
|
|
17
|
-
"TurboQuantProd",
|
|
18
|
-
"LloydMaxCodebook",
|
|
19
|
-
"compute_expected_distortion",
|
|
20
|
-
"solve_lloyd_max",
|
|
21
|
-
]
|
|
@@ -1,190 +0,0 @@
|
|
|
1
|
-
"""TurboQuant KV cache helpers."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import math
|
|
6
|
-
|
|
7
|
-
import torch
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
# ---------------------------------------------------------------------------
|
|
11
|
-
# Closed-form Gaussian integration helpers (replaces scipy.integrate.quad)
|
|
12
|
-
# ---------------------------------------------------------------------------
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def _gauss_pdf(x: float, sigma: float) -> float:
|
|
16
|
-
return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def _gauss_cdf(x: float, sigma: float) -> float:
|
|
20
|
-
return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def _int_pdf(a: float, b: float, sigma: float) -> float:
|
|
24
|
-
return _gauss_cdf(b, sigma) - _gauss_cdf(a, sigma)
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def _int_x_pdf(a: float, b: float, sigma: float) -> float:
|
|
28
|
-
return sigma * sigma * (_gauss_pdf(a, sigma) - _gauss_pdf(b, sigma))
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
class TurboQuantCompressorV2:
|
|
32
|
-
"""Compressed key store with direct inner-product scoring."""
|
|
33
|
-
|
|
34
|
-
def __init__(self, head_dim: int, bits: int, seed: int, device: str = "cpu"):
|
|
35
|
-
self.head_dim = head_dim
|
|
36
|
-
self.bits = bits
|
|
37
|
-
self.mse_bits = max(bits - 1, 1)
|
|
38
|
-
self.device = device
|
|
39
|
-
|
|
40
|
-
gen = torch.Generator(device="cpu")
|
|
41
|
-
gen.manual_seed(seed)
|
|
42
|
-
G = torch.randn(head_dim, head_dim, generator=gen)
|
|
43
|
-
Q, R = torch.linalg.qr(G)
|
|
44
|
-
diag_sign = torch.sign(torch.diag(R))
|
|
45
|
-
diag_sign[diag_sign == 0] = 1.0
|
|
46
|
-
self.Pi = (Q * diag_sign.unsqueeze(0)).to(device)
|
|
47
|
-
|
|
48
|
-
self.centroids = self._solve_codebook(head_dim, self.mse_bits).to(device)
|
|
49
|
-
|
|
50
|
-
gen2 = torch.Generator(device="cpu")
|
|
51
|
-
gen2.manual_seed(seed + 10000)
|
|
52
|
-
self.S = torch.randn(head_dim, head_dim, generator=gen2).to(device)
|
|
53
|
-
|
|
54
|
-
self.PiT = self.Pi.T.contiguous()
|
|
55
|
-
|
|
56
|
-
def _solve_codebook(self, d: int, bits: int) -> torch.Tensor:
|
|
57
|
-
n_levels = 2**bits
|
|
58
|
-
sigma = 1.0 / math.sqrt(d)
|
|
59
|
-
|
|
60
|
-
lo, hi = -3.5 * sigma, 3.5 * sigma
|
|
61
|
-
centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
|
|
62
|
-
|
|
63
|
-
for _ in range(200):
|
|
64
|
-
boundaries = [
|
|
65
|
-
(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
|
|
66
|
-
]
|
|
67
|
-
edges = [lo * 3] + boundaries + [hi * 3]
|
|
68
|
-
new_centroids = []
|
|
69
|
-
for i in range(n_levels):
|
|
70
|
-
a, b = edges[i], edges[i + 1]
|
|
71
|
-
num = _int_x_pdf(a, b, sigma)
|
|
72
|
-
den = _int_pdf(a, b, sigma)
|
|
73
|
-
new_centroids.append(num / den if den > 1e-15 else centroids[i])
|
|
74
|
-
if (
|
|
75
|
-
max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels))
|
|
76
|
-
< 1e-10
|
|
77
|
-
):
|
|
78
|
-
break
|
|
79
|
-
centroids = new_centroids
|
|
80
|
-
|
|
81
|
-
return torch.tensor(centroids, dtype=torch.float32)
|
|
82
|
-
|
|
83
|
-
@torch.no_grad()
|
|
84
|
-
def compress(self, states: torch.Tensor) -> dict:
|
|
85
|
-
B, H, S, D = states.shape
|
|
86
|
-
flat = states.reshape(-1, D).float()
|
|
87
|
-
|
|
88
|
-
vec_norms = torch.norm(flat, dim=-1, keepdim=True)
|
|
89
|
-
flat_norm = flat / (vec_norms + 1e-8)
|
|
90
|
-
|
|
91
|
-
rotated = flat_norm @ self.Pi.T
|
|
92
|
-
diffs = rotated.unsqueeze(-1) - self.centroids
|
|
93
|
-
indices = diffs.abs().argmin(dim=-1).to(torch.uint8)
|
|
94
|
-
|
|
95
|
-
reconstructed_rotated = self.centroids[indices.long()]
|
|
96
|
-
k_mse = (reconstructed_rotated @ self.Pi) * vec_norms
|
|
97
|
-
|
|
98
|
-
residual = flat - k_mse
|
|
99
|
-
residual_norm = torch.norm(residual, dim=-1)
|
|
100
|
-
|
|
101
|
-
projected = residual @ self.S.T
|
|
102
|
-
signs = (projected >= 0).to(torch.int8) * 2 - 1
|
|
103
|
-
|
|
104
|
-
return {
|
|
105
|
-
"k_mse": k_mse.to(torch.float16).reshape(B, H, S, D),
|
|
106
|
-
"qjl_signs": signs.reshape(B, H, S, D),
|
|
107
|
-
"residual_norm": residual_norm.to(torch.float16).reshape(B, H, S),
|
|
108
|
-
"shape": (B, H, S, D),
|
|
109
|
-
}
|
|
110
|
-
|
|
111
|
-
@torch.no_grad()
|
|
112
|
-
def asymmetric_attention_scores(
|
|
113
|
-
self, queries: torch.Tensor, compressed: dict
|
|
114
|
-
) -> torch.Tensor:
|
|
115
|
-
k_mse = compressed["k_mse"].float()
|
|
116
|
-
signs = compressed["qjl_signs"].float()
|
|
117
|
-
r_norm = compressed["residual_norm"].float()
|
|
118
|
-
|
|
119
|
-
term1 = torch.matmul(queries.float(), k_mse.transpose(-2, -1))
|
|
120
|
-
q_projected = torch.matmul(queries.float(), self.S.T)
|
|
121
|
-
qjl_ip = torch.matmul(q_projected, signs.transpose(-2, -1))
|
|
122
|
-
|
|
123
|
-
m = self.S.shape[0]
|
|
124
|
-
correction_scale = math.sqrt(math.pi / 2) / m
|
|
125
|
-
term2 = correction_scale * qjl_ip * r_norm.unsqueeze(-2)
|
|
126
|
-
|
|
127
|
-
return term1 + term2
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
class TurboQuantCompressorMSE:
|
|
131
|
-
"""MSE-only compressor for values."""
|
|
132
|
-
|
|
133
|
-
def __init__(self, head_dim: int, bits: int, seed: int, device: str = "cpu"):
|
|
134
|
-
self.head_dim = head_dim
|
|
135
|
-
self.bits = bits
|
|
136
|
-
self.device = device
|
|
137
|
-
|
|
138
|
-
gen = torch.Generator(device="cpu")
|
|
139
|
-
gen.manual_seed(seed)
|
|
140
|
-
G = torch.randn(head_dim, head_dim, generator=gen)
|
|
141
|
-
Q, R = torch.linalg.qr(G)
|
|
142
|
-
diag_sign = torch.sign(torch.diag(R))
|
|
143
|
-
diag_sign[diag_sign == 0] = 1.0
|
|
144
|
-
self.Pi = (Q * diag_sign.unsqueeze(0)).to(device)
|
|
145
|
-
self.centroids = self._solve_codebook(head_dim, bits).to(device)
|
|
146
|
-
|
|
147
|
-
def _solve_codebook(self, d, bits):
|
|
148
|
-
n_levels = 2**bits
|
|
149
|
-
sigma = 1.0 / math.sqrt(d)
|
|
150
|
-
|
|
151
|
-
lo, hi = -3.5 * sigma, 3.5 * sigma
|
|
152
|
-
centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
|
|
153
|
-
for _ in range(200):
|
|
154
|
-
boundaries = [
|
|
155
|
-
(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
|
|
156
|
-
]
|
|
157
|
-
edges = [lo * 3] + boundaries + [hi * 3]
|
|
158
|
-
new_c = []
|
|
159
|
-
for i in range(n_levels):
|
|
160
|
-
a, b = edges[i], edges[i + 1]
|
|
161
|
-
num = _int_x_pdf(a, b, sigma)
|
|
162
|
-
den = _int_pdf(a, b, sigma)
|
|
163
|
-
new_c.append(num / den if den > 1e-15 else centroids[i])
|
|
164
|
-
if max(abs(new_c[i] - centroids[i]) for i in range(n_levels)) < 1e-10:
|
|
165
|
-
break
|
|
166
|
-
centroids = new_c
|
|
167
|
-
return torch.tensor(centroids, dtype=torch.float32)
|
|
168
|
-
|
|
169
|
-
@torch.no_grad()
|
|
170
|
-
def compress(self, states: torch.Tensor) -> dict:
|
|
171
|
-
B, H, S, D = states.shape
|
|
172
|
-
flat = states.reshape(-1, D).float()
|
|
173
|
-
vec_norms = torch.norm(flat, dim=-1, keepdim=True)
|
|
174
|
-
flat_norm = flat / (vec_norms + 1e-8)
|
|
175
|
-
rotated = flat_norm @ self.Pi.T
|
|
176
|
-
diffs = rotated.unsqueeze(-1) - self.centroids
|
|
177
|
-
indices = diffs.abs().argmin(dim=-1).to(torch.uint8)
|
|
178
|
-
return {
|
|
179
|
-
"indices": indices,
|
|
180
|
-
"vec_norms": vec_norms.squeeze(-1).to(torch.float16),
|
|
181
|
-
"shape": (B, H, S, D),
|
|
182
|
-
}
|
|
183
|
-
|
|
184
|
-
@torch.no_grad()
|
|
185
|
-
def decompress(self, compressed: dict) -> torch.Tensor:
|
|
186
|
-
B, H, S, D = compressed["shape"]
|
|
187
|
-
indices = compressed["indices"].long()
|
|
188
|
-
reconstructed = self.centroids[indices] @ self.Pi
|
|
189
|
-
vec_norms = compressed["vec_norms"].float().unsqueeze(-1)
|
|
190
|
-
return (reconstructed * vec_norms).reshape(B, H, S, D)
|
|
@@ -1,190 +0,0 @@
|
|
|
1
|
-
# ruff: noqa
|
|
2
|
-
"""Lloyd-Max scalar quantizer for rotated unit vectors.
|
|
3
|
-
|
|
4
|
-
The coordinate distribution is approximately Beta-shaped on [-1, 1] after
|
|
5
|
-
random rotation. For d >= 64, a Gaussian N(0, 1/d) is a good approximation.
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
from __future__ import annotations
|
|
9
|
-
|
|
10
|
-
import math
|
|
11
|
-
|
|
12
|
-
import torch
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
# ---------------------------------------------------------------------------
|
|
16
|
-
# Pure-Python Gaussian integration helpers (replaces scipy.integrate.quad)
|
|
17
|
-
# ---------------------------------------------------------------------------
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def _gauss_pdf(x: float, sigma: float) -> float:
|
|
21
|
-
"""N(0, σ²) probability density at x."""
|
|
22
|
-
return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def _gauss_cdf(x: float, sigma: float) -> float:
|
|
26
|
-
"""N(0, σ²) cumulative distribution at x."""
|
|
27
|
-
return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def _int_pdf(a: float, b: float, sigma: float) -> float:
|
|
31
|
-
"""∫[a,b] N(0,σ²)(x) dx — closed form via erf."""
|
|
32
|
-
return _gauss_cdf(b, sigma) - _gauss_cdf(a, sigma)
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def _int_x_pdf(a: float, b: float, sigma: float) -> float:
|
|
36
|
-
"""∫[a,b] x·N(0,σ²)(x) dx = σ²·[f(a) − f(b)]."""
|
|
37
|
-
return sigma * sigma * (_gauss_pdf(a, sigma) - _gauss_pdf(b, sigma))
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def _int_sq_pdf(a: float, b: float, sigma: float, c: float) -> float:
|
|
41
|
-
"""∫[a,b] (x−c)²·N(0,σ²)(x) dx — closed form."""
|
|
42
|
-
fa, fb = _gauss_pdf(a, sigma), _gauss_pdf(b, sigma)
|
|
43
|
-
cdf_diff = _gauss_cdf(b, sigma) - _gauss_cdf(a, sigma)
|
|
44
|
-
sig2 = sigma * sigma
|
|
45
|
-
return (
|
|
46
|
-
sig2 * (a * fa - b * fb)
|
|
47
|
-
- 2.0 * c * sig2 * (fa - fb)
|
|
48
|
-
+ (sig2 + c * c) * cdf_diff
|
|
49
|
-
)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def _quad(f, a: float, b: float, n: int = 200) -> float:
|
|
53
|
-
"""Composite Simpson's rule numerical integration over [a, b].
|
|
54
|
-
|
|
55
|
-
Used only for the ``use_exact=True`` (Beta-PDF) path; the Gaussian path
|
|
56
|
-
uses closed-form helpers above.
|
|
57
|
-
"""
|
|
58
|
-
if n % 2 != 0:
|
|
59
|
-
n += 1
|
|
60
|
-
h = (b - a) / n
|
|
61
|
-
s = f(a) + f(b)
|
|
62
|
-
for i in range(1, n):
|
|
63
|
-
s += (4 if i % 2 else 2) * f(a + i * h)
|
|
64
|
-
return h / 3.0 * s
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
def beta_pdf(x: float, d: int) -> float:
|
|
68
|
-
"""PDF of a single coordinate after random rotation of a d-dim unit vector."""
|
|
69
|
-
if abs(x) >= 1.0:
|
|
70
|
-
return 0.0
|
|
71
|
-
coeff = math.gamma(d / 2) / (math.sqrt(math.pi) * math.gamma((d - 1) / 2))
|
|
72
|
-
return coeff * (1 - x * x) ** ((d - 3) / 2)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def gaussian_approx_pdf(x: float, d: int) -> float:
|
|
76
|
-
"""Gaussian approximation N(0, 1/d) -- accurate for d >= 64."""
|
|
77
|
-
sigma2 = 1.0 / d
|
|
78
|
-
return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp(-x * x / (2 * sigma2))
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def solve_lloyd_max(
|
|
82
|
-
d: int, bits: int, use_exact: bool = False, max_iter: int = 200, tol: float = 1e-10
|
|
83
|
-
):
|
|
84
|
-
"""
|
|
85
|
-
Solve Lloyd-Max optimal quantizer for the coordinate distribution.
|
|
86
|
-
|
|
87
|
-
Args:
|
|
88
|
-
d: vector dimension
|
|
89
|
-
bits: number of quantization bits
|
|
90
|
-
use_exact: if True, use exact Beta PDF; if False, use Gaussian approx
|
|
91
|
-
max_iter: maximum Lloyd-Max iterations
|
|
92
|
-
tol: convergence tolerance
|
|
93
|
-
|
|
94
|
-
Returns:
|
|
95
|
-
centroids: sorted tensor of 2^bits optimal centroids
|
|
96
|
-
boundaries: sorted tensor of 2^bits - 1 boundaries between centroids
|
|
97
|
-
"""
|
|
98
|
-
n_levels = 2**bits
|
|
99
|
-
sigma = 1.0 / math.sqrt(d)
|
|
100
|
-
|
|
101
|
-
lo, hi = -3.5 * sigma, 3.5 * sigma
|
|
102
|
-
centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
|
|
103
|
-
|
|
104
|
-
for _ in range(max_iter):
|
|
105
|
-
boundaries = [
|
|
106
|
-
(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
|
|
107
|
-
]
|
|
108
|
-
edges = [lo * 3] + boundaries + [hi * 3]
|
|
109
|
-
new_centroids = []
|
|
110
|
-
for i in range(n_levels):
|
|
111
|
-
a, b = edges[i], edges[i + 1]
|
|
112
|
-
|
|
113
|
-
if use_exact:
|
|
114
|
-
numerator = _quad(lambda x: x * beta_pdf(x, d), a, b)
|
|
115
|
-
denominator = _quad(lambda x: beta_pdf(x, d), a, b)
|
|
116
|
-
else:
|
|
117
|
-
numerator = _int_x_pdf(a, b, sigma)
|
|
118
|
-
denominator = _int_pdf(a, b, sigma)
|
|
119
|
-
|
|
120
|
-
if denominator > 1e-15:
|
|
121
|
-
new_centroids.append(numerator / denominator)
|
|
122
|
-
else:
|
|
123
|
-
new_centroids.append(centroids[i])
|
|
124
|
-
|
|
125
|
-
max_shift = max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels))
|
|
126
|
-
centroids = new_centroids
|
|
127
|
-
|
|
128
|
-
if max_shift < tol:
|
|
129
|
-
break
|
|
130
|
-
|
|
131
|
-
boundaries = [(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)]
|
|
132
|
-
|
|
133
|
-
return (
|
|
134
|
-
torch.tensor(centroids, dtype=torch.float32),
|
|
135
|
-
torch.tensor(boundaries, dtype=torch.float32),
|
|
136
|
-
)
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
def compute_expected_distortion(
|
|
140
|
-
d: int,
|
|
141
|
-
bits: int,
|
|
142
|
-
centroids: torch.Tensor,
|
|
143
|
-
boundaries: torch.Tensor,
|
|
144
|
-
use_exact: bool = False,
|
|
145
|
-
) -> float:
|
|
146
|
-
"""Compute the expected MSE distortion per coordinate for the given quantizer."""
|
|
147
|
-
sigma = 1.0 / math.sqrt(d)
|
|
148
|
-
n_levels = len(centroids)
|
|
149
|
-
|
|
150
|
-
edges = [-3.5 * sigma * 3] + boundaries.tolist() + [3.5 * sigma * 3]
|
|
151
|
-
total_distortion = 0.0
|
|
152
|
-
|
|
153
|
-
for i in range(n_levels):
|
|
154
|
-
a, b = edges[i], edges[i + 1]
|
|
155
|
-
c = centroids[i].item()
|
|
156
|
-
if use_exact:
|
|
157
|
-
dist = _quad(lambda x, _c=c: (x - _c) ** 2 * beta_pdf(x, d), a, b)
|
|
158
|
-
else:
|
|
159
|
-
dist = _int_sq_pdf(a, b, sigma, c)
|
|
160
|
-
total_distortion += dist
|
|
161
|
-
|
|
162
|
-
return total_distortion
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
class LloydMaxCodebook:
|
|
166
|
-
"""Precomputed Lloyd-Max codebook for a given dimension and bit-width."""
|
|
167
|
-
|
|
168
|
-
def __init__(self, d: int, bits: int, use_exact: bool = False):
|
|
169
|
-
self.d = d
|
|
170
|
-
self.bits = bits
|
|
171
|
-
self.n_levels = 2**bits
|
|
172
|
-
self.centroids, self.boundaries = solve_lloyd_max(d, bits, use_exact)
|
|
173
|
-
self.distortion = compute_expected_distortion(
|
|
174
|
-
d, bits, self.centroids, self.boundaries, use_exact
|
|
175
|
-
)
|
|
176
|
-
|
|
177
|
-
def quantize(self, x: torch.Tensor) -> torch.Tensor:
|
|
178
|
-
"""Quantize values to nearest centroid indices."""
|
|
179
|
-
diffs = x.unsqueeze(-1) - self.centroids.to(x.device)
|
|
180
|
-
return diffs.abs().argmin(dim=-1)
|
|
181
|
-
|
|
182
|
-
def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
|
|
183
|
-
"""Map indices back to centroid values."""
|
|
184
|
-
return self.centroids.to(indices.device)[indices]
|
|
185
|
-
|
|
186
|
-
def __repr__(self):
|
|
187
|
-
return (
|
|
188
|
-
f"LloydMaxCodebook(d={self.d}, bits={self.bits}, "
|
|
189
|
-
f"levels={self.n_levels}, distortion_per_coord={self.distortion:.6f})"
|
|
190
|
-
)
|
|
@@ -1,249 +0,0 @@
|
|
|
1
|
-
"""TurboQuant: two-stage vector quantization."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import math
|
|
6
|
-
from typing import Optional, Tuple, cast
|
|
7
|
-
|
|
8
|
-
import torch
|
|
9
|
-
from torch import nn
|
|
10
|
-
|
|
11
|
-
from .lloyd_max import LloydMaxCodebook
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def generate_rotation_matrix(
|
|
15
|
-
d: int, seed: Optional[int] = None, device: str = "cpu"
|
|
16
|
-
) -> torch.Tensor:
|
|
17
|
-
"""Generate a random orthogonal rotation matrix via QR decomposition."""
|
|
18
|
-
gen = torch.Generator(device="cpu")
|
|
19
|
-
if seed is not None:
|
|
20
|
-
gen.manual_seed(seed)
|
|
21
|
-
G = torch.randn(d, d, generator=gen)
|
|
22
|
-
Q, R = torch.linalg.qr(G)
|
|
23
|
-
diag_sign = torch.sign(torch.diag(R))
|
|
24
|
-
diag_sign[diag_sign == 0] = 1.0
|
|
25
|
-
Q = Q * diag_sign.unsqueeze(0)
|
|
26
|
-
return Q.to(device)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def generate_qjl_matrix(
|
|
30
|
-
d: int, m: Optional[int] = None, seed: Optional[int] = None, device: str = "cpu"
|
|
31
|
-
) -> torch.Tensor:
|
|
32
|
-
"""
|
|
33
|
-
Generate the random projection matrix S for QJL.
|
|
34
|
-
S has i.i.d. N(0,1) entries, shape (m, d).
|
|
35
|
-
Default m = d (same dimensionality).
|
|
36
|
-
"""
|
|
37
|
-
if m is None:
|
|
38
|
-
m = d
|
|
39
|
-
gen = torch.Generator(device="cpu")
|
|
40
|
-
if seed is not None:
|
|
41
|
-
gen.manual_seed(seed)
|
|
42
|
-
S = torch.randn(m, d, generator=gen)
|
|
43
|
-
return S.to(device)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
class TurboQuantMSE(nn.Module):
    """Stage 1: MSE-optimal quantizer.

    Inputs are rotated by a fixed random orthogonal matrix ``Pi``. Each
    rotated coordinate is then replaced by the index of its nearest centroid
    from a ``LloydMaxCodebook(d, bits)``. Dequantization looks the centroids
    back up and rotates them back to the original basis.
    """

    def __init__(self, d: int, bits: int, seed: int = 42, device: str = "cpu"):
        super().__init__()
        self.d = d
        self.bits = bits
        self.device = device

        rotation = generate_rotation_matrix(d, seed=seed, device=device)
        self.register_buffer("Pi", rotation)
        self.codebook = LloydMaxCodebook(d, bits)
        # Mirror the codebook tensors as buffers so Module.to()/state_dict
        # carry them. ``boundaries`` is not read by any method of this class.
        self.register_buffer("centroids", self.codebook.centroids.to(device))
        self.register_buffer("boundaries", self.codebook.boundaries.to(device))

    def rotate(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` into the rotated basis, i.e. ``x @ Pi.T``."""
        rotation = cast("torch.Tensor", self.Pi)
        return torch.matmul(x, rotation.T)

    def unrotate(self, y: torch.Tensor) -> torch.Tensor:
        """Undo :meth:`rotate` (``y @ Pi``); valid because ``Pi`` is orthogonal."""
        rotation = cast("torch.Tensor", self.Pi)
        return torch.matmul(y, rotation)

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-coordinate indices of the nearest centroid after rotation.

        On ties, ``argmin`` returns the lowest index.
        """
        codebook = cast("torch.Tensor", self.centroids)
        rotated = self.rotate(x)
        # Compare every coordinate with every centroid (a trailing centroid
        # axis; assumes ``centroids`` is 1-D — TODO confirm in LloydMaxCodebook).
        distance = (rotated[..., None] - codebook).abs()
        return distance.argmin(dim=-1)

    def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up centroid values for ``indices`` and rotate back."""
        lookup = cast("torch.Tensor", self.centroids)
        return self.unrotate(lookup[indices])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize and reconstruct ``x``; returns ``(x_hat, indices)``."""
        codes = self.quantize(x)
        return self.dequantize(codes), codes
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
class TurboQuantProd(nn.Module):
    """Stage 1 + Stage 2: Unbiased inner product quantizer.

    Stage 1 is a :class:`TurboQuantMSE` quantizer with ``bits - 1`` bits
    (at least 1). Stage 2 keeps, for the residual ``r = x - x_mse``, its norm
    and the signs of ``S @ r`` for a Gaussian matrix ``S`` with ``qjl_dim``
    rows. Together these allow an inner-product estimate with a QJL
    correction term.
    """

    def __init__(
        self,
        d: int,
        bits: int,
        qjl_dim: Optional[int] = None,
        seed: int = 42,
        device: str = "cpu",
    ):
        super().__init__()
        self.d = d
        self.bits = bits
        # One bit of the budget is spent on the QJL signs; never go below 1.
        self.mse_bits = max(bits - 1, 1)
        # Any falsy qjl_dim (None or 0) falls back to d projections.
        self.qjl_dim = qjl_dim or d
        self.device = device

        self.mse = TurboQuantMSE(d, self.mse_bits, seed=seed, device=device)
        # seed + 1 so the projection is drawn from a different stream than Pi.
        projection = generate_qjl_matrix(
            d, m=self.qjl_dim, seed=seed + 1, device=device
        )
        self.register_buffer("S", projection)

    def quantize(self, x: torch.Tensor) -> dict:
        """Compress ``x``.

        Returns:
            Dict with ``"mse_indices"`` (stage-1 codes), ``"qjl_signs"``
            (±1 signs of the projected residual) and ``"residual_norm"``
            (L2 norm of the residual over the last dimension).
        """
        x_hat, mse_indices = self.mse(x)
        residual = x - x_hat
        projection = cast("torch.Tensor", self.S)
        signs = torch.sign(torch.matmul(residual, projection.T))
        # torch.sign maps exact zeros to 0; force them to +1 so entries are ±1.
        signs[signs == 0] = 1.0
        return {
            "mse_indices": mse_indices,
            "qjl_signs": signs,
            "residual_norm": torch.norm(residual, dim=-1),
        }

    def dequantize(self, compressed: dict) -> torch.Tensor:
        """Reconstruct from the stage-1 codes only (QJL signs are ignored)."""
        return self.mse.dequantize(compressed["mse_indices"])

    def inner_product(self, y: torch.Tensor, compressed: dict) -> torch.Tensor:
        """Estimate ``<y, x>`` from the compressed form of ``x``.

        The result is ``<y, x_mse>`` plus a QJL term
        ``||r|| * sqrt(pi/2) / m * <S y, sign(S r)>``, which estimates
        ``<y, r>``. Shapes of ``y`` and the stored tensors must broadcast.
        """
        base = (y * self.mse.dequantize(compressed["mse_indices"])).sum(dim=-1)

        projection = cast("torch.Tensor", self.S)
        sketch_dot = (
            torch.matmul(y, projection.T) * compressed["qjl_signs"]
        ).sum(dim=-1)

        scale = math.sqrt(math.pi / 2) / self.qjl_dim
        correction = compressed["residual_norm"] * scale * sketch_dot
        return base + correction

    def forward(self, x: torch.Tensor) -> dict:
        """Alias for :meth:`quantize`."""
        return self.quantize(x)
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
class TurboQuantKVCache:
    """KV cache wrapper that uses TurboQuant to compress keys and values.

    Keys are compressed with ``TurboQuantProd`` (MSE codes plus a 1-bit QJL
    sketch of the residual), so attention scores can be estimated directly
    from the compressed form. Values are compressed with ``TurboQuantMSE``
    and reconstructed on demand. Every call to :meth:`append` adds one entry
    to ``key_cache`` and one to ``value_cache``. Within an entry, vectors are
    stored flattened to ``(-1, d)`` rows.
    """

    def __init__(
        self,
        d_key: int,
        d_value: int,
        bits: int = 3,
        seed: int = 42,
        device: str = "cpu",
    ):
        self.d_key = d_key
        self.d_value = d_value
        self.bits = bits
        self.device = device

        self.key_quantizer = TurboQuantProd(d_key, bits, seed=seed, device=device)
        # Offset seed, presumably so the value quantizer's rotation differs
        # from the key quantizer's (which consumes seed and seed + 1).
        self.value_quantizer = TurboQuantMSE(
            d_value, bits, seed=seed + 100, device=device
        )

        # One dict per append() call: compressed key parts + original shape.
        self.key_cache = []
        # One dict per append() call: value centroid indices + original shape.
        self.value_cache = []

    def append(self, keys: torch.Tensor, values: torch.Tensor):
        """Quantize one chunk of keys and values and store it.

        Args:
            keys: Tensor whose last dimension is ``d_key``. It is flattened
                to ``(-1, d_key)`` before quantization.
            values: Tensor whose last dimension is ``d_value``. It is
                flattened to ``(-1, d_value)``.

        Raises:
            ValueError: If the last dimension of ``keys`` or ``values`` does
                not match the configured size. Without this check,
                ``reshape(-1, d)`` would silently re-chunk any tensor whose
                total size happens to be a multiple of ``d``.
        """
        for label, tensor, expected in (
            ("keys", keys, self.d_key),
            ("values", values, self.d_value),
        ):
            if tensor.dim() == 0 or tensor.shape[-1] != expected:
                raise ValueError(
                    f"{label} must have last dimension {expected}, "
                    f"got shape {tuple(tensor.shape)}"
                )

        orig_shape = keys.shape
        flat_keys = keys.reshape(-1, self.d_key)
        flat_values = values.reshape(-1, self.d_value)

        compressed_keys = self.key_quantizer.quantize(flat_keys)
        value_indices = self.value_quantizer.quantize(flat_values)

        self.key_cache.append(
            {
                "mse_indices": compressed_keys["mse_indices"],
                "qjl_signs": compressed_keys["qjl_signs"],
                "residual_norm": compressed_keys["residual_norm"],
                "shape": orig_shape,
            }
        )
        self.value_cache.append(
            {
                "indices": value_indices,
                "shape": values.shape,
            }
        )

    def attention_scores(self, queries: torch.Tensor) -> torch.Tensor:
        """Estimate query-key inner products against every cached chunk.

        Per-chunk estimates from ``key_quantizer.inner_product`` are
        concatenated along the last dimension. ``queries`` must broadcast
        against the flattened ``(n, d_key)`` key rows. Returns an empty
        (CPU) tensor when nothing is cached.
        """
        scores = []
        for cached in self.key_cache:
            s = self.key_quantizer.inner_product(queries, cached)
            scores.append(s)
        return torch.cat(scores, dim=-1) if scores else torch.tensor([])

    def get_values(self) -> torch.Tensor:
        """Dequantize all cached values as flattened ``(n, d_value)`` rows.

        Chunks are concatenated along dim 0. Returns an empty (CPU) tensor
        when nothing is cached.
        """
        values = []
        for cached in self.value_cache:
            v = self.value_quantizer.dequantize(cached["indices"])
            values.append(v)
        return torch.cat(values, dim=0) if values else torch.tensor([])

    def memory_usage_bits(self) -> dict:
        """Report the theoretical packed size of the cache, in bits.

        Bits are counted as follows:
        - each key coordinate costs ``key_quantizer.mse_bits``;
        - each QJL sign costs 1 bit;
        - each residual norm costs 16 bits;
        - each value coordinate costs ``bits``.

        The baseline is 16 bits per key and value coordinate. The tensors
        actually held (int64 indices from ``argmin``, floating-point signs)
        are larger than this packed accounting.
        """
        # sum() over an empty generator is 0, so no emptiness guards are needed.
        n_keys = sum(c["mse_indices"].numel() for c in self.key_cache)
        n_qjl = sum(c["qjl_signs"].numel() for c in self.key_cache)
        n_norms = sum(c["residual_norm"].numel() for c in self.key_cache)
        n_values = sum(c["indices"].numel() for c in self.value_cache)

        key_bits = n_keys * self.key_quantizer.mse_bits + n_qjl * 1 + n_norms * 16
        value_bits = n_values * self.bits
        total_bits = key_bits + value_bits
        fp16_equivalent = (n_keys + n_values) * 16

        return {
            "key_bits": key_bits,
            "value_bits": value_bits,
            "total_bits": total_bits,
            "fp16_bits": fp16_equivalent,
            "compression_ratio": (
                fp16_equivalent / total_bits if total_bits > 0 else 0
            ),
        }

    def __len__(self):
        """Number of cached key vectors (flattened rows across all chunks)."""
        return sum(c["mse_indices"].shape[0] for c in self.key_cache)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/entry_points.txt
RENAMED
|
File without changes
|
{mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|