mcp-plesk-dev-docs 0.4.2__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56):
  1. {mcp_plesk_dev_docs-0.4.2/mcp_plesk_dev_docs.egg-info → mcp_plesk_dev_docs-0.5.0}/PKG-INFO +4 -1
  2. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/README.md +3 -1
  3. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0/mcp_plesk_dev_docs.egg-info}/PKG-INFO +4 -1
  4. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/SOURCES.txt +0 -4
  5. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/requires.txt +1 -0
  6. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/tq_index.py +23 -11
  7. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/pyproject.toml +8 -2
  8. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_turboquant_regression.py +1 -1
  9. mcp_plesk_dev_docs-0.4.2/plesk_unified/turboquant/__init__.py +0 -21
  10. mcp_plesk_dev_docs-0.4.2/plesk_unified/turboquant/compressors.py +0 -190
  11. mcp_plesk_dev_docs-0.4.2/plesk_unified/turboquant/lloyd_max.py +0 -190
  12. mcp_plesk_dev_docs-0.4.2/plesk_unified/turboquant/turboquant.py +0 -249
  13. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/LICENSE +0 -0
  14. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/NOTICE +0 -0
  15. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/dependency_links.txt +0 -0
  16. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/entry_points.txt +0 -0
  17. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/mcp_plesk_dev_docs.egg-info/top_level.txt +0 -0
  18. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/__init__.py +0 -0
  19. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/ai_client.py +0 -0
  20. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_engines.py +0 -0
  21. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_gates.py +0 -0
  22. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_reporting.py +0 -0
  23. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_runner.py +0 -0
  24. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/benchmark_suites.py +0 -0
  25. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/chunking.py +0 -0
  26. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/error_handling.py +0 -0
  27. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/html_utils.py +0 -0
  28. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/indexing.py +0 -0
  29. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/io_utils.py +0 -0
  30. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/log_handler.py +0 -0
  31. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/model_config.py +0 -0
  32. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/platform_utils.py +0 -0
  33. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/settings.py +0 -0
  34. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/summary_cache.py +0 -0
  35. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/plesk_unified/types.py +0 -0
  36. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/setup.cfg +0 -0
  37. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_ai_client.py +0 -0
  38. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_async_tools.py +0 -0
  39. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_benchmark_engines.py +0 -0
  40. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_benchmark_gates.py +0 -0
  41. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_chunking.py +0 -0
  42. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_error_handling.py +0 -0
  43. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_html_utils.py +0 -0
  44. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_indexing.py +0 -0
  45. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_io_utils.py +0 -0
  46. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_log_handler.py +0 -0
  47. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_model_config.py +0 -0
  48. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_progress.py +0 -0
  49. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_prompts.py +0 -0
  50. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_resources.py +0 -0
  51. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_sampling.py +0 -0
  52. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_search_helpers.py +0 -0
  53. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_server.py +0 -0
  54. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_settings.py +0 -0
  55. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_sota_ph1.py +0 -0
  56. {mcp_plesk_dev_docs-0.4.2 → mcp_plesk_dev_docs-0.5.0}/tests/test_startup_path.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mcp-plesk-dev-docs
3
- Version: 0.4.2
3
+ Version: 0.5.0
4
4
  Summary: A unified MCP server that indexes and retrieves Plesk documentation using vector embeddings and semantic search with reranking
5
5
  Author-email: Gilson Siqueira <gilson@example.com>
6
6
  License-Expression: MIT
@@ -35,6 +35,7 @@ Requires-Dist: torch>=2.4.0
35
35
  Requires-Dist: markdownify>=0.14.1
36
36
  Requires-Dist: tantivy>=0.22.0
37
37
  Requires-Dist: lance-namespace==0.6.1
38
+ Requires-Dist: tq-search
38
39
  Provides-Extra: dev
39
40
  Requires-Dist: pytest>=8.0.0; extra == "dev"
40
41
  Requires-Dist: requests>=2.32.0; extra == "dev"
@@ -49,6 +50,8 @@ Dynamic: license-file
49
50
  # mcp-plesk-dev-docs
50
51
 
51
52
  [![Python 3.12+](https://img.shields.io/badge/python-3.12%2B-blue?style=flat-square)](https://www.python.org/downloads/)
53
+ [![PyPI](https://img.shields.io/pypi/v/mcp-plesk-dev-docs?style=flat-square)](https://pypi.org/project/mcp-plesk-dev-docs/)
54
+ [![MCP Registry](https://img.shields.io/badge/MCP%20Registry-listed-green?style=flat-square)](https://registry.modelcontextprotocol.io/v0.1/servers/io.github.barateza%2Fmcp-plesk-dev-docs/versions/0.4.3)
52
55
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?style=flat-square)](LICENSE)
53
56
  [![MCP Compatible](https://img.shields.io/badge/MCP-Compatible-green?style=flat-square)](https://modelcontextprotocol.io/)
54
57
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/psf/black)
@@ -1,6 +1,8 @@
1
1
  # mcp-plesk-dev-docs
2
2
 
3
3
  [![Python 3.12+](https://img.shields.io/badge/python-3.12%2B-blue?style=flat-square)](https://www.python.org/downloads/)
4
+ [![PyPI](https://img.shields.io/pypi/v/mcp-plesk-dev-docs?style=flat-square)](https://pypi.org/project/mcp-plesk-dev-docs/)
5
+ [![MCP Registry](https://img.shields.io/badge/MCP%20Registry-listed-green?style=flat-square)](https://registry.modelcontextprotocol.io/v0.1/servers/io.github.barateza%2Fmcp-plesk-dev-docs/versions/0.4.3)
4
6
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?style=flat-square)](LICENSE)
5
7
  [![MCP Compatible](https://img.shields.io/badge/MCP-Compatible-green?style=flat-square)](https://modelcontextprotocol.io/)
6
8
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/psf/black)
@@ -170,4 +172,4 @@ Portions of this repository were developed under contract for Plesk Internationa
170
172
 
171
173
  *Built to make Plesk extension development faster.*
172
174
 
173
- <!-- mcp-name: io.github.barateza/mcp-plesk-dev-docs -->
175
+ <!-- mcp-name: io.github.barateza/mcp-plesk-dev-docs -->
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mcp-plesk-dev-docs
3
- Version: 0.4.2
3
+ Version: 0.5.0
4
4
  Summary: A unified MCP server that indexes and retrieves Plesk documentation using vector embeddings and semantic search with reranking
5
5
  Author-email: Gilson Siqueira <gilson@example.com>
6
6
  License-Expression: MIT
@@ -35,6 +35,7 @@ Requires-Dist: torch>=2.4.0
35
35
  Requires-Dist: markdownify>=0.14.1
36
36
  Requires-Dist: tantivy>=0.22.0
37
37
  Requires-Dist: lance-namespace==0.6.1
38
+ Requires-Dist: tq-search
38
39
  Provides-Extra: dev
39
40
  Requires-Dist: pytest>=8.0.0; extra == "dev"
40
41
  Requires-Dist: requests>=2.32.0; extra == "dev"
@@ -49,6 +50,8 @@ Dynamic: license-file
49
50
  # mcp-plesk-dev-docs
50
51
 
51
52
  [![Python 3.12+](https://img.shields.io/badge/python-3.12%2B-blue?style=flat-square)](https://www.python.org/downloads/)
53
+ [![PyPI](https://img.shields.io/pypi/v/mcp-plesk-dev-docs?style=flat-square)](https://pypi.org/project/mcp-plesk-dev-docs/)
54
+ [![MCP Registry](https://img.shields.io/badge/MCP%20Registry-listed-green?style=flat-square)](https://registry.modelcontextprotocol.io/v0.1/servers/io.github.barateza%2Fmcp-plesk-dev-docs/versions/0.4.3)
52
55
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg?style=flat-square)](LICENSE)
53
56
  [![MCP Compatible](https://img.shields.io/badge/MCP-Compatible-green?style=flat-square)](https://modelcontextprotocol.io/)
54
57
  [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/psf/black)
@@ -27,10 +27,6 @@ plesk_unified/settings.py
27
27
  plesk_unified/summary_cache.py
28
28
  plesk_unified/tq_index.py
29
29
  plesk_unified/types.py
30
- plesk_unified/turboquant/__init__.py
31
- plesk_unified/turboquant/compressors.py
32
- plesk_unified/turboquant/lloyd_max.py
33
- plesk_unified/turboquant/turboquant.py
34
30
  tests/test_ai_client.py
35
31
  tests/test_async_tools.py
36
32
  tests/test_benchmark_engines.py
@@ -11,6 +11,7 @@ torch>=2.4.0
11
11
  markdownify>=0.14.1
12
12
  tantivy>=0.22.0
13
13
  lance-namespace==0.6.1
14
+ tq-search
14
15
 
15
16
  [dev]
16
17
  pytest>=8.0.0
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
  import numpy as np
6
6
  import torch
7
7
 
8
- from plesk_unified.turboquant import TurboQuantProd
8
+ from tq_search import TurboQuantProd
9
9
 
10
10
 
11
11
  class TurboQuantIndex:
@@ -52,6 +52,13 @@ class TurboQuantIndex:
52
52
  if self.compressed_db is None:
53
53
  return []
54
54
 
55
+ # 1. Lazily move the compressed database to the target device once
56
+ first_val = next(iter(self.compressed_db.values()))
57
+ if str(first_val.device) != self.device:
58
+ self.compressed_db = {
59
+ k: v.to(self.device) for k, v in self.compressed_db.items()
60
+ }
61
+
55
62
  selected_indices: list[int]
56
63
  if category:
57
64
  selected_indices = self._category_to_indices.get(category, [])
@@ -60,25 +67,30 @@ class TurboQuantIndex:
60
67
  else:
61
68
  selected_indices = list(range(len(self._meta)))
62
69
 
63
- # 1. L2-Normalize the query
70
+ # 2. L2-Normalize the query
64
71
  norm = np.linalg.norm(query_vec)
65
72
  query_normalized = query_vec / max(norm, 1e-12)
66
73
 
67
- # 2. Prepare query as a batched tensor (1, dim)
74
+ # 3. Prepare query as a batched tensor (1, dim) directly on target device
68
75
  q = torch.from_numpy(query_normalized).to(self.device).unsqueeze(0)
69
76
 
70
- # 3. Slice candidates and move them to the target device.
71
- selected_tensor = torch.as_tensor(selected_indices, dtype=torch.long)
72
- db_on_device = {
73
- k: v.index_select(0, selected_tensor).to(self.device)
74
- for k, v in self.compressed_db.items()
75
- }
77
+ # 4. Slice candidates only if we are actually filtering a subset
78
+ if category and len(selected_indices) < len(self._meta):
79
+ selected_tensor = torch.as_tensor(
80
+ selected_indices, dtype=torch.long, device=self.device
81
+ )
82
+ db_on_device = {
83
+ k: v.index_select(0, selected_tensor)
84
+ for k, v in self.compressed_db.items()
85
+ }
86
+ else:
87
+ db_on_device = self.compressed_db
76
88
 
77
- # 4. Perform a SINGLE batched inner product calculation
89
+ # 5. Perform a SINGLE batched inner product calculation
78
90
  with torch.no_grad():
79
91
  scores = self.quantizer.inner_product(q, db_on_device).squeeze(0)
80
92
 
81
- # 5. Sort and return
93
+ # 6. Sort and return
82
94
  scores_np = scores.cpu().numpy()
83
95
  idx = np.argsort(-scores_np)[:top_k]
84
96
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "mcp-plesk-dev-docs"
7
- version = "0.4.2"
7
+ version = "0.5.0"
8
8
  description = "A unified MCP server that indexes and retrieves Plesk documentation using vector embeddings and semantic search with reranking"
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -38,8 +38,11 @@ dependencies = [
38
38
  "markdownify>=0.14.1",
39
39
  "tantivy>=0.22.0",
40
40
  "lance-namespace==0.6.1",
41
+ "tq-search",
41
42
  ]
42
43
 
44
+
45
+
43
46
  [project.urls]
44
47
  Homepage = "https://github.com/barateza/mcp-plesk-dev-docs"
45
48
  Documentation = "https://github.com/barateza/mcp-plesk-dev-docs#readme"
@@ -47,7 +50,8 @@ Repository = "https://github.com/barateza/mcp-plesk-dev-docs.git"
47
50
  "Bug Tracker" = "https://github.com/barateza/mcp-plesk-dev-docs/issues"
48
51
 
49
52
  [tool.setuptools]
50
- packages = ["plesk_unified", "plesk_unified.turboquant"]
53
+ packages = ["plesk_unified"]
54
+
51
55
 
52
56
  [project.scripts]
53
57
  # Console script to run the MCP server
@@ -78,7 +82,9 @@ url = "https://download.pytorch.org/whl/cu124"
78
82
  explicit = true
79
83
 
80
84
  [tool.uv.sources]
85
+ tq-search = { path = "/Users/gilsonsiqueira/tq-search", editable = true }
81
86
  torch = [
87
+
82
88
  { index = "pytorch-cu124", marker = "sys_platform == 'win32'" },
83
89
  ]
84
90
  torchvision = [
@@ -4,7 +4,7 @@ import numpy as np
4
4
  import torch
5
5
 
6
6
  from plesk_unified import tq_index
7
- from plesk_unified.turboquant import LloydMaxCodebook, TurboQuantMSE, TurboQuantProd
7
+ from tq_search import LloydMaxCodebook, TurboQuantMSE, TurboQuantProd
8
8
 
9
9
 
10
10
  def test_turboquant_package_exports():
@@ -1,21 +0,0 @@
1
- """TurboQuant helpers used by the unified retrieval path.
2
-
3
- Base implementation: https://github.com/tonbistudio/turboquant-pytorch
4
- """
5
-
6
- from __future__ import annotations
7
-
8
- from .compressors import TurboQuantCompressorMSE, TurboQuantCompressorV2
9
- from .lloyd_max import LloydMaxCodebook, compute_expected_distortion, solve_lloyd_max
10
- from .turboquant import TurboQuantKVCache, TurboQuantMSE, TurboQuantProd
11
-
12
- __all__ = [
13
- "TurboQuantCompressorMSE",
14
- "TurboQuantCompressorV2",
15
- "TurboQuantKVCache",
16
- "TurboQuantMSE",
17
- "TurboQuantProd",
18
- "LloydMaxCodebook",
19
- "compute_expected_distortion",
20
- "solve_lloyd_max",
21
- ]
@@ -1,190 +0,0 @@
1
- """TurboQuant KV cache helpers."""
2
-
3
- from __future__ import annotations
4
-
5
- import math
6
-
7
- import torch
8
-
9
-
10
- # ---------------------------------------------------------------------------
11
- # Closed-form Gaussian integration helpers (replaces scipy.integrate.quad)
12
- # ---------------------------------------------------------------------------
13
-
14
-
15
- def _gauss_pdf(x: float, sigma: float) -> float:
16
- return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
17
-
18
-
19
- def _gauss_cdf(x: float, sigma: float) -> float:
20
- return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))
21
-
22
-
23
- def _int_pdf(a: float, b: float, sigma: float) -> float:
24
- return _gauss_cdf(b, sigma) - _gauss_cdf(a, sigma)
25
-
26
-
27
- def _int_x_pdf(a: float, b: float, sigma: float) -> float:
28
- return sigma * sigma * (_gauss_pdf(a, sigma) - _gauss_pdf(b, sigma))
29
-
30
-
31
- class TurboQuantCompressorV2:
32
- """Compressed key store with direct inner-product scoring."""
33
-
34
- def __init__(self, head_dim: int, bits: int, seed: int, device: str = "cpu"):
35
- self.head_dim = head_dim
36
- self.bits = bits
37
- self.mse_bits = max(bits - 1, 1)
38
- self.device = device
39
-
40
- gen = torch.Generator(device="cpu")
41
- gen.manual_seed(seed)
42
- G = torch.randn(head_dim, head_dim, generator=gen)
43
- Q, R = torch.linalg.qr(G)
44
- diag_sign = torch.sign(torch.diag(R))
45
- diag_sign[diag_sign == 0] = 1.0
46
- self.Pi = (Q * diag_sign.unsqueeze(0)).to(device)
47
-
48
- self.centroids = self._solve_codebook(head_dim, self.mse_bits).to(device)
49
-
50
- gen2 = torch.Generator(device="cpu")
51
- gen2.manual_seed(seed + 10000)
52
- self.S = torch.randn(head_dim, head_dim, generator=gen2).to(device)
53
-
54
- self.PiT = self.Pi.T.contiguous()
55
-
56
- def _solve_codebook(self, d: int, bits: int) -> torch.Tensor:
57
- n_levels = 2**bits
58
- sigma = 1.0 / math.sqrt(d)
59
-
60
- lo, hi = -3.5 * sigma, 3.5 * sigma
61
- centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
62
-
63
- for _ in range(200):
64
- boundaries = [
65
- (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
66
- ]
67
- edges = [lo * 3] + boundaries + [hi * 3]
68
- new_centroids = []
69
- for i in range(n_levels):
70
- a, b = edges[i], edges[i + 1]
71
- num = _int_x_pdf(a, b, sigma)
72
- den = _int_pdf(a, b, sigma)
73
- new_centroids.append(num / den if den > 1e-15 else centroids[i])
74
- if (
75
- max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels))
76
- < 1e-10
77
- ):
78
- break
79
- centroids = new_centroids
80
-
81
- return torch.tensor(centroids, dtype=torch.float32)
82
-
83
- @torch.no_grad()
84
- def compress(self, states: torch.Tensor) -> dict:
85
- B, H, S, D = states.shape
86
- flat = states.reshape(-1, D).float()
87
-
88
- vec_norms = torch.norm(flat, dim=-1, keepdim=True)
89
- flat_norm = flat / (vec_norms + 1e-8)
90
-
91
- rotated = flat_norm @ self.Pi.T
92
- diffs = rotated.unsqueeze(-1) - self.centroids
93
- indices = diffs.abs().argmin(dim=-1).to(torch.uint8)
94
-
95
- reconstructed_rotated = self.centroids[indices.long()]
96
- k_mse = (reconstructed_rotated @ self.Pi) * vec_norms
97
-
98
- residual = flat - k_mse
99
- residual_norm = torch.norm(residual, dim=-1)
100
-
101
- projected = residual @ self.S.T
102
- signs = (projected >= 0).to(torch.int8) * 2 - 1
103
-
104
- return {
105
- "k_mse": k_mse.to(torch.float16).reshape(B, H, S, D),
106
- "qjl_signs": signs.reshape(B, H, S, D),
107
- "residual_norm": residual_norm.to(torch.float16).reshape(B, H, S),
108
- "shape": (B, H, S, D),
109
- }
110
-
111
- @torch.no_grad()
112
- def asymmetric_attention_scores(
113
- self, queries: torch.Tensor, compressed: dict
114
- ) -> torch.Tensor:
115
- k_mse = compressed["k_mse"].float()
116
- signs = compressed["qjl_signs"].float()
117
- r_norm = compressed["residual_norm"].float()
118
-
119
- term1 = torch.matmul(queries.float(), k_mse.transpose(-2, -1))
120
- q_projected = torch.matmul(queries.float(), self.S.T)
121
- qjl_ip = torch.matmul(q_projected, signs.transpose(-2, -1))
122
-
123
- m = self.S.shape[0]
124
- correction_scale = math.sqrt(math.pi / 2) / m
125
- term2 = correction_scale * qjl_ip * r_norm.unsqueeze(-2)
126
-
127
- return term1 + term2
128
-
129
-
130
- class TurboQuantCompressorMSE:
131
- """MSE-only compressor for values."""
132
-
133
- def __init__(self, head_dim: int, bits: int, seed: int, device: str = "cpu"):
134
- self.head_dim = head_dim
135
- self.bits = bits
136
- self.device = device
137
-
138
- gen = torch.Generator(device="cpu")
139
- gen.manual_seed(seed)
140
- G = torch.randn(head_dim, head_dim, generator=gen)
141
- Q, R = torch.linalg.qr(G)
142
- diag_sign = torch.sign(torch.diag(R))
143
- diag_sign[diag_sign == 0] = 1.0
144
- self.Pi = (Q * diag_sign.unsqueeze(0)).to(device)
145
- self.centroids = self._solve_codebook(head_dim, bits).to(device)
146
-
147
- def _solve_codebook(self, d, bits):
148
- n_levels = 2**bits
149
- sigma = 1.0 / math.sqrt(d)
150
-
151
- lo, hi = -3.5 * sigma, 3.5 * sigma
152
- centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
153
- for _ in range(200):
154
- boundaries = [
155
- (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
156
- ]
157
- edges = [lo * 3] + boundaries + [hi * 3]
158
- new_c = []
159
- for i in range(n_levels):
160
- a, b = edges[i], edges[i + 1]
161
- num = _int_x_pdf(a, b, sigma)
162
- den = _int_pdf(a, b, sigma)
163
- new_c.append(num / den if den > 1e-15 else centroids[i])
164
- if max(abs(new_c[i] - centroids[i]) for i in range(n_levels)) < 1e-10:
165
- break
166
- centroids = new_c
167
- return torch.tensor(centroids, dtype=torch.float32)
168
-
169
- @torch.no_grad()
170
- def compress(self, states: torch.Tensor) -> dict:
171
- B, H, S, D = states.shape
172
- flat = states.reshape(-1, D).float()
173
- vec_norms = torch.norm(flat, dim=-1, keepdim=True)
174
- flat_norm = flat / (vec_norms + 1e-8)
175
- rotated = flat_norm @ self.Pi.T
176
- diffs = rotated.unsqueeze(-1) - self.centroids
177
- indices = diffs.abs().argmin(dim=-1).to(torch.uint8)
178
- return {
179
- "indices": indices,
180
- "vec_norms": vec_norms.squeeze(-1).to(torch.float16),
181
- "shape": (B, H, S, D),
182
- }
183
-
184
- @torch.no_grad()
185
- def decompress(self, compressed: dict) -> torch.Tensor:
186
- B, H, S, D = compressed["shape"]
187
- indices = compressed["indices"].long()
188
- reconstructed = self.centroids[indices] @ self.Pi
189
- vec_norms = compressed["vec_norms"].float().unsqueeze(-1)
190
- return (reconstructed * vec_norms).reshape(B, H, S, D)
@@ -1,190 +0,0 @@
1
- # ruff: noqa
2
- """Lloyd-Max scalar quantizer for rotated unit vectors.
3
-
4
- The coordinate distribution is approximately Beta-shaped on [-1, 1] after
5
- random rotation. For d >= 64, a Gaussian N(0, 1/d) is a good approximation.
6
- """
7
-
8
- from __future__ import annotations
9
-
10
- import math
11
-
12
- import torch
13
-
14
-
15
- # ---------------------------------------------------------------------------
16
- # Pure-Python Gaussian integration helpers (replaces scipy.integrate.quad)
17
- # ---------------------------------------------------------------------------
18
-
19
-
20
- def _gauss_pdf(x: float, sigma: float) -> float:
21
- """N(0, σ²) probability density at x."""
22
- return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
23
-
24
-
25
- def _gauss_cdf(x: float, sigma: float) -> float:
26
- """N(0, σ²) cumulative distribution at x."""
27
- return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))
28
-
29
-
30
- def _int_pdf(a: float, b: float, sigma: float) -> float:
31
- """∫[a,b] N(0,σ²)(x) dx — closed form via erf."""
32
- return _gauss_cdf(b, sigma) - _gauss_cdf(a, sigma)
33
-
34
-
35
- def _int_x_pdf(a: float, b: float, sigma: float) -> float:
36
- """∫[a,b] x·N(0,σ²)(x) dx = σ²·[f(a) − f(b)]."""
37
- return sigma * sigma * (_gauss_pdf(a, sigma) - _gauss_pdf(b, sigma))
38
-
39
-
40
- def _int_sq_pdf(a: float, b: float, sigma: float, c: float) -> float:
41
- """∫[a,b] (x−c)²·N(0,σ²)(x) dx — closed form."""
42
- fa, fb = _gauss_pdf(a, sigma), _gauss_pdf(b, sigma)
43
- cdf_diff = _gauss_cdf(b, sigma) - _gauss_cdf(a, sigma)
44
- sig2 = sigma * sigma
45
- return (
46
- sig2 * (a * fa - b * fb)
47
- - 2.0 * c * sig2 * (fa - fb)
48
- + (sig2 + c * c) * cdf_diff
49
- )
50
-
51
-
52
- def _quad(f, a: float, b: float, n: int = 200) -> float:
53
- """Composite Simpson's rule numerical integration over [a, b].
54
-
55
- Used only for the ``use_exact=True`` (Beta-PDF) path; the Gaussian path
56
- uses closed-form helpers above.
57
- """
58
- if n % 2 != 0:
59
- n += 1
60
- h = (b - a) / n
61
- s = f(a) + f(b)
62
- for i in range(1, n):
63
- s += (4 if i % 2 else 2) * f(a + i * h)
64
- return h / 3.0 * s
65
-
66
-
67
- def beta_pdf(x: float, d: int) -> float:
68
- """PDF of a single coordinate after random rotation of a d-dim unit vector."""
69
- if abs(x) >= 1.0:
70
- return 0.0
71
- coeff = math.gamma(d / 2) / (math.sqrt(math.pi) * math.gamma((d - 1) / 2))
72
- return coeff * (1 - x * x) ** ((d - 3) / 2)
73
-
74
-
75
- def gaussian_approx_pdf(x: float, d: int) -> float:
76
- """Gaussian approximation N(0, 1/d) -- accurate for d >= 64."""
77
- sigma2 = 1.0 / d
78
- return (1.0 / math.sqrt(2 * math.pi * sigma2)) * math.exp(-x * x / (2 * sigma2))
79
-
80
-
81
- def solve_lloyd_max(
82
- d: int, bits: int, use_exact: bool = False, max_iter: int = 200, tol: float = 1e-10
83
- ):
84
- """
85
- Solve Lloyd-Max optimal quantizer for the coordinate distribution.
86
-
87
- Args:
88
- d: vector dimension
89
- bits: number of quantization bits
90
- use_exact: if True, use exact Beta PDF; if False, use Gaussian approx
91
- max_iter: maximum Lloyd-Max iterations
92
- tol: convergence tolerance
93
-
94
- Returns:
95
- centroids: sorted tensor of 2^bits optimal centroids
96
- boundaries: sorted tensor of 2^bits - 1 boundaries between centroids
97
- """
98
- n_levels = 2**bits
99
- sigma = 1.0 / math.sqrt(d)
100
-
101
- lo, hi = -3.5 * sigma, 3.5 * sigma
102
- centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
103
-
104
- for _ in range(max_iter):
105
- boundaries = [
106
- (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
107
- ]
108
- edges = [lo * 3] + boundaries + [hi * 3]
109
- new_centroids = []
110
- for i in range(n_levels):
111
- a, b = edges[i], edges[i + 1]
112
-
113
- if use_exact:
114
- numerator = _quad(lambda x: x * beta_pdf(x, d), a, b)
115
- denominator = _quad(lambda x: beta_pdf(x, d), a, b)
116
- else:
117
- numerator = _int_x_pdf(a, b, sigma)
118
- denominator = _int_pdf(a, b, sigma)
119
-
120
- if denominator > 1e-15:
121
- new_centroids.append(numerator / denominator)
122
- else:
123
- new_centroids.append(centroids[i])
124
-
125
- max_shift = max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels))
126
- centroids = new_centroids
127
-
128
- if max_shift < tol:
129
- break
130
-
131
- boundaries = [(centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)]
132
-
133
- return (
134
- torch.tensor(centroids, dtype=torch.float32),
135
- torch.tensor(boundaries, dtype=torch.float32),
136
- )
137
-
138
-
139
- def compute_expected_distortion(
140
- d: int,
141
- bits: int,
142
- centroids: torch.Tensor,
143
- boundaries: torch.Tensor,
144
- use_exact: bool = False,
145
- ) -> float:
146
- """Compute the expected MSE distortion per coordinate for the given quantizer."""
147
- sigma = 1.0 / math.sqrt(d)
148
- n_levels = len(centroids)
149
-
150
- edges = [-3.5 * sigma * 3] + boundaries.tolist() + [3.5 * sigma * 3]
151
- total_distortion = 0.0
152
-
153
- for i in range(n_levels):
154
- a, b = edges[i], edges[i + 1]
155
- c = centroids[i].item()
156
- if use_exact:
157
- dist = _quad(lambda x, _c=c: (x - _c) ** 2 * beta_pdf(x, d), a, b)
158
- else:
159
- dist = _int_sq_pdf(a, b, sigma, c)
160
- total_distortion += dist
161
-
162
- return total_distortion
163
-
164
-
165
- class LloydMaxCodebook:
166
- """Precomputed Lloyd-Max codebook for a given dimension and bit-width."""
167
-
168
- def __init__(self, d: int, bits: int, use_exact: bool = False):
169
- self.d = d
170
- self.bits = bits
171
- self.n_levels = 2**bits
172
- self.centroids, self.boundaries = solve_lloyd_max(d, bits, use_exact)
173
- self.distortion = compute_expected_distortion(
174
- d, bits, self.centroids, self.boundaries, use_exact
175
- )
176
-
177
- def quantize(self, x: torch.Tensor) -> torch.Tensor:
178
- """Quantize values to nearest centroid indices."""
179
- diffs = x.unsqueeze(-1) - self.centroids.to(x.device)
180
- return diffs.abs().argmin(dim=-1)
181
-
182
- def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
183
- """Map indices back to centroid values."""
184
- return self.centroids.to(indices.device)[indices]
185
-
186
- def __repr__(self):
187
- return (
188
- f"LloydMaxCodebook(d={self.d}, bits={self.bits}, "
189
- f"levels={self.n_levels}, distortion_per_coord={self.distortion:.6f})"
190
- )
@@ -1,249 +0,0 @@
1
- """TurboQuant: two-stage vector quantization."""
2
-
3
- from __future__ import annotations
4
-
5
- import math
6
- from typing import Optional, Tuple, cast
7
-
8
- import torch
9
- from torch import nn
10
-
11
- from .lloyd_max import LloydMaxCodebook
12
-
13
-
14
- def generate_rotation_matrix(
15
- d: int, seed: Optional[int] = None, device: str = "cpu"
16
- ) -> torch.Tensor:
17
- """Generate a random orthogonal rotation matrix via QR decomposition."""
18
- gen = torch.Generator(device="cpu")
19
- if seed is not None:
20
- gen.manual_seed(seed)
21
- G = torch.randn(d, d, generator=gen)
22
- Q, R = torch.linalg.qr(G)
23
- diag_sign = torch.sign(torch.diag(R))
24
- diag_sign[diag_sign == 0] = 1.0
25
- Q = Q * diag_sign.unsqueeze(0)
26
- return Q.to(device)
27
-
28
-
29
- def generate_qjl_matrix(
30
- d: int, m: Optional[int] = None, seed: Optional[int] = None, device: str = "cpu"
31
- ) -> torch.Tensor:
32
- """
33
- Generate the random projection matrix S for QJL.
34
- S has i.i.d. N(0,1) entries, shape (m, d).
35
- Default m = d (same dimensionality).
36
- """
37
- if m is None:
38
- m = d
39
- gen = torch.Generator(device="cpu")
40
- if seed is not None:
41
- gen.manual_seed(seed)
42
- S = torch.randn(m, d, generator=gen)
43
- return S.to(device)
44
-
45
-
46
- class TurboQuantMSE(nn.Module):
47
- """Stage 1: MSE-optimal quantizer."""
48
-
49
- def __init__(self, d: int, bits: int, seed: int = 42, device: str = "cpu"):
50
- super().__init__()
51
- self.d = d
52
- self.bits = bits
53
- self.device = device
54
-
55
- self.register_buffer(
56
- "Pi", generate_rotation_matrix(d, seed=seed, device=device)
57
- )
58
- self.codebook = LloydMaxCodebook(d, bits)
59
- self.register_buffer("centroids", self.codebook.centroids.to(device))
60
- self.register_buffer("boundaries", self.codebook.boundaries.to(device))
61
-
62
- def rotate(self, x: torch.Tensor) -> torch.Tensor:
63
- Pi = cast("torch.Tensor", self.Pi)
64
- return x @ Pi.T
65
-
66
- def unrotate(self, y: torch.Tensor) -> torch.Tensor:
67
- Pi = cast("torch.Tensor", self.Pi)
68
- return y @ Pi
69
-
70
- def quantize(self, x: torch.Tensor) -> torch.Tensor:
71
- centroids = cast("torch.Tensor", self.centroids)
72
- y = self.rotate(x)
73
- diffs = y.unsqueeze(-1) - centroids
74
- indices = diffs.abs().argmin(dim=-1)
75
- return indices
76
-
77
- def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
78
- centroids = cast("torch.Tensor", self.centroids)
79
- y_hat = centroids[indices]
80
- return self.unrotate(y_hat)
81
-
82
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
83
- indices = self.quantize(x)
84
- x_hat = self.dequantize(indices)
85
- return x_hat, indices
86
-
87
-
88
- class TurboQuantProd(nn.Module):
89
- """Stage 1 + Stage 2: Unbiased inner product quantizer."""
90
-
91
- def __init__(
92
- self,
93
- d: int,
94
- bits: int,
95
- qjl_dim: Optional[int] = None,
96
- seed: int = 42,
97
- device: str = "cpu",
98
- ):
99
- super().__init__()
100
- self.d = d
101
- self.bits = bits
102
- self.mse_bits = max(bits - 1, 1)
103
- self.qjl_dim = qjl_dim or d
104
- self.device = device
105
-
106
- self.mse = TurboQuantMSE(d, self.mse_bits, seed=seed, device=device)
107
- self.register_buffer(
108
- "S", generate_qjl_matrix(d, m=self.qjl_dim, seed=seed + 1, device=device)
109
- )
110
-
111
- def quantize(self, x: torch.Tensor) -> dict:
112
- x_hat, mse_indices = self.mse(x)
113
- residual = x - x_hat
114
- residual_norm = torch.norm(residual, dim=-1, keepdim=True)
115
- S = cast("torch.Tensor", self.S)
116
- projected = residual @ S.T
117
- qjl_signs = torch.sign(projected)
118
- qjl_signs[qjl_signs == 0] = 1.0
119
-
120
- return {
121
- "mse_indices": mse_indices,
122
- "qjl_signs": qjl_signs,
123
- "residual_norm": residual_norm.squeeze(-1),
124
- }
125
-
126
- def dequantize(self, compressed: dict) -> torch.Tensor:
127
- return self.mse.dequantize(compressed["mse_indices"])
128
-
129
- def inner_product(self, y: torch.Tensor, compressed: dict) -> torch.Tensor:
130
- x_mse = self.mse.dequantize(compressed["mse_indices"])
131
- term1 = (y * x_mse).sum(dim=-1)
132
-
133
- S = cast("torch.Tensor", self.S)
134
- y_projected = y @ S.T
135
- qjl_ip = (y_projected * compressed["qjl_signs"]).sum(dim=-1)
136
-
137
- m = self.qjl_dim
138
- correction_scale = math.sqrt(math.pi / 2) / m
139
- term2 = compressed["residual_norm"] * correction_scale * qjl_ip
140
-
141
- return term1 + term2
142
-
143
- def forward(self, x: torch.Tensor) -> dict:
144
- return self.quantize(x)
145
-
146
-
147
- class TurboQuantKVCache:
148
- """KV cache wrapper that uses TurboQuant to compress keys and values."""
149
-
150
- def __init__(
151
- self,
152
- d_key: int,
153
- d_value: int,
154
- bits: int = 3,
155
- seed: int = 42,
156
- device: str = "cpu",
157
- ):
158
- self.d_key = d_key
159
- self.d_value = d_value
160
- self.bits = bits
161
- self.device = device
162
-
163
- self.key_quantizer = TurboQuantProd(d_key, bits, seed=seed, device=device)
164
- self.value_quantizer = TurboQuantMSE(
165
- d_value, bits, seed=seed + 100, device=device
166
- )
167
-
168
- self.key_cache = []
169
- self.value_cache = []
170
-
171
- def append(self, keys: torch.Tensor, values: torch.Tensor):
172
- orig_shape = keys.shape
173
- flat_keys = keys.reshape(-1, self.d_key)
174
- flat_values = values.reshape(-1, self.d_value)
175
-
176
- compressed_keys = self.key_quantizer.quantize(flat_keys)
177
- value_indices = self.value_quantizer.quantize(flat_values)
178
-
179
- self.key_cache.append(
180
- {
181
- "mse_indices": compressed_keys["mse_indices"],
182
- "qjl_signs": compressed_keys["qjl_signs"],
183
- "residual_norm": compressed_keys["residual_norm"],
184
- "shape": orig_shape,
185
- }
186
- )
187
- self.value_cache.append(
188
- {
189
- "indices": value_indices,
190
- "shape": values.shape,
191
- }
192
- )
193
-
194
- def attention_scores(self, queries: torch.Tensor) -> torch.Tensor:
195
- scores = []
196
- for cached in self.key_cache:
197
- s = self.key_quantizer.inner_product(queries, cached)
198
- scores.append(s)
199
- return torch.cat(scores, dim=-1) if scores else torch.tensor([])
200
-
201
- def get_values(self) -> torch.Tensor:
202
- values = []
203
- for cached in self.value_cache:
204
- v = self.value_quantizer.dequantize(cached["indices"])
205
- values.append(v)
206
- return torch.cat(values, dim=0) if values else torch.tensor([])
207
-
208
- def memory_usage_bits(self) -> dict:
209
- n_keys = (
210
- sum(c["mse_indices"].numel() for c in self.key_cache)
211
- if self.key_cache
212
- else 0
213
- )
214
- n_qjl = (
215
- sum(c["qjl_signs"].numel() for c in self.key_cache) if self.key_cache else 0
216
- )
217
- n_norms = (
218
- sum(c["residual_norm"].numel() for c in self.key_cache)
219
- if self.key_cache
220
- else 0
221
- )
222
- n_values = (
223
- sum(c["indices"].numel() for c in self.value_cache)
224
- if self.value_cache
225
- else 0
226
- )
227
-
228
- key_bits = n_keys * self.key_quantizer.mse_bits + n_qjl * 1 + n_norms * 16
229
- value_bits = n_values * self.bits
230
- fp16_equivalent = (n_keys + n_values) * 16
231
-
232
- return {
233
- "key_bits": key_bits,
234
- "value_bits": value_bits,
235
- "total_bits": key_bits + value_bits,
236
- "fp16_bits": fp16_equivalent,
237
- "compression_ratio": (
238
- fp16_equivalent / (key_bits + value_bits)
239
- if (key_bits + value_bits) > 0
240
- else 0
241
- ),
242
- }
243
-
244
- def __len__(self):
245
- return (
246
- sum(c["mse_indices"].shape[0] for c in self.key_cache)
247
- if self.key_cache
248
- else 0
249
- )