mcp-plesk-dev-docs 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_plesk_dev_docs-0.4.2.dist-info/METADATA +221 -0
- mcp_plesk_dev_docs-0.4.2.dist-info/RECORD +30 -0
- mcp_plesk_dev_docs-0.4.2.dist-info/WHEEL +5 -0
- mcp_plesk_dev_docs-0.4.2.dist-info/entry_points.txt +2 -0
- mcp_plesk_dev_docs-0.4.2.dist-info/licenses/LICENSE +21 -0
- mcp_plesk_dev_docs-0.4.2.dist-info/licenses/NOTICE +0 -0
- mcp_plesk_dev_docs-0.4.2.dist-info/top_level.txt +1 -0
- plesk_unified/__init__.py +3 -0
- plesk_unified/ai_client.py +257 -0
- plesk_unified/benchmark_engines.py +330 -0
- plesk_unified/benchmark_gates.py +254 -0
- plesk_unified/benchmark_reporting.py +107 -0
- plesk_unified/benchmark_runner.py +433 -0
- plesk_unified/benchmark_suites.py +30 -0
- plesk_unified/chunking.py +360 -0
- plesk_unified/error_handling.py +112 -0
- plesk_unified/html_utils.py +217 -0
- plesk_unified/indexing.py +53 -0
- plesk_unified/io_utils.py +287 -0
- plesk_unified/log_handler.py +209 -0
- plesk_unified/model_config.py +218 -0
- plesk_unified/platform_utils.py +214 -0
- plesk_unified/settings.py +93 -0
- plesk_unified/summary_cache.py +55 -0
- plesk_unified/tq_index.py +85 -0
- plesk_unified/turboquant/__init__.py +21 -0
- plesk_unified/turboquant/compressors.py +190 -0
- plesk_unified/turboquant/lloyd_max.py +190 -0
- plesk_unified/turboquant/turboquant.py +249 -0
- plesk_unified/types.py +27 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""TurboQuant KV cache helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# ---------------------------------------------------------------------------
|
|
11
|
+
# Closed-form Gaussian integration helpers (replaces scipy.integrate.quad)
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _gauss_pdf(x: float, sigma: float) -> float:
|
|
16
|
+
return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _gauss_cdf(x: float, sigma: float) -> float:
|
|
20
|
+
return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _int_pdf(a: float, b: float, sigma: float) -> float:
    """Return the probability mass of N(0, sigma^2) on the interval [a, b]."""
    upper = _gauss_cdf(b, sigma)
    lower = _gauss_cdf(a, sigma)
    return upper - lower
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _int_x_pdf(a: float, b: float, sigma: float) -> float:
    """Return the integral of ``x * pdf(x)`` over [a, b] for N(0, sigma^2).

    Uses the closed form ``sigma^2 * (pdf(a) - pdf(b))``.
    """
    variance = sigma * sigma
    pdf_low = _gauss_pdf(a, sigma)
    pdf_high = _gauss_pdf(b, sigma)
    return variance * (pdf_low - pdf_high)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TurboQuantCompressorV2:
    """Compressed key store with direct inner-product scoring.

    Each vector is normalised, rotated by a seeded orthogonal matrix ``Pi``
    and snapped coordinate-wise to a scalar codebook of ``2**mse_bits``
    centroids.  The sign pattern of the reconstruction residual under a
    seeded Gaussian matrix ``S`` and the residual norm are kept alongside so
    that :meth:`asymmetric_attention_scores` can add a correction term to the
    plain ``query @ key`` product.
    """

    def __init__(self, head_dim: int, bits: int, seed: int, device: str = "cpu"):
        """Build the rotation, the codebook and the sign-projection matrix.

        Args:
            head_dim: Size of the last axis of the tensors to compress.
            bits: Bit budget; the scalar codebook uses ``max(bits - 1, 1)`` bits.
            seed: Seed for ``Pi``; ``S`` is drawn with ``seed + 10000``.
            device: Device the matrices and the codebook are moved to.
        """
        self.head_dim = head_dim
        self.bits = bits
        self.mse_bits = max(bits - 1, 1)
        self.device = device

        # Random numbers are always drawn on CPU so a seed gives the same
        # matrices regardless of the target device.
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)
        G = torch.randn(head_dim, head_dim, generator=gen)
        Q, R = torch.linalg.qr(G)
        # Flip the columns of Q by the sign of diag(R); zeros count as +1.
        diag_sign = torch.sign(torch.diag(R))
        diag_sign[diag_sign == 0] = 1.0
        self.Pi = (Q * diag_sign.unsqueeze(0)).to(device)

        self.centroids = self._solve_codebook(head_dim, self.mse_bits).to(device)

        gen2 = torch.Generator(device="cpu")
        gen2.manual_seed(seed + 10000)
        self.S = torch.randn(head_dim, head_dim, generator=gen2).to(device)

        self.PiT = self.Pi.T.contiguous()

    def _solve_codebook(self, d: int, bits: int) -> torch.Tensor:
        """Return ``2**bits`` Lloyd-iterated centroids for N(0, 1/d).

        Starts from a uniform grid on +/-3.5 sigma and runs at most 200
        iterations, stopping early once no centroid moves by 1e-10 or more.
        """
        n_levels = 2**bits
        sigma = 1.0 / math.sqrt(d)

        lo, hi = -3.5 * sigma, 3.5 * sigma
        centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]

        for _ in range(200):
            boundaries = [
                (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
            ]
            # The outermost cells are extended to three times the initial range.
            edges = [lo * 3] + boundaries + [hi * 3]
            new_centroids = []
            for i in range(n_levels):
                a, b = edges[i], edges[i + 1]
                num = _int_x_pdf(a, b, sigma)
                den = _int_pdf(a, b, sigma)
                # Keep the old centroid when the cell carries (almost) no mass.
                new_centroids.append(num / den if den > 1e-15 else centroids[i])
            if (
                max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels))
                < 1e-10
            ):
                break
            centroids = new_centroids

        return torch.tensor(centroids, dtype=torch.float32)

    @torch.no_grad()
    def compress(self, states: torch.Tensor) -> dict:
        """Compress a ``(B, H, S, D)`` tensor.

        Returns:
            A dict with ``k_mse`` (float16 reconstruction, ``(B, H, S, D)``),
            ``qjl_signs`` (int8 values in {-1, +1}, ``(B, H, S, D)``),
            ``residual_norm`` (float16, ``(B, H, S)``) and ``shape``.
        """
        B, H, S, D = states.shape
        flat = states.reshape(-1, D).float()

        vec_norms = torch.norm(flat, dim=-1, keepdim=True)
        flat_norm = flat / (vec_norms + 1e-8)  # epsilon guards all-zero rows

        rotated = flat_norm @ self.Pi.T
        diffs = rotated.unsqueeze(-1) - self.centroids
        # argmin already yields int64 indices.  The former round trip through
        # uint8 wrapped around for codebooks with more than 256 levels and
        # selected the wrong centroids.
        indices = diffs.abs().argmin(dim=-1)

        reconstructed_rotated = self.centroids[indices]
        k_mse = (reconstructed_rotated @ self.Pi) * vec_norms

        residual = flat - k_mse
        residual_norm = torch.norm(residual, dim=-1)

        projected = residual @ self.S.T
        # Map {False, True} to {-1, +1}; a projection of exactly zero becomes +1.
        signs = (projected >= 0).to(torch.int8) * 2 - 1

        return {
            "k_mse": k_mse.to(torch.float16).reshape(B, H, S, D),
            "qjl_signs": signs.reshape(B, H, S, D),
            "residual_norm": residual_norm.to(torch.float16).reshape(B, H, S),
            "shape": (B, H, S, D),
        }

    @torch.no_grad()
    def asymmetric_attention_scores(
        self, queries: torch.Tensor, compressed: dict
    ) -> torch.Tensor:
        """Score full-precision queries against compressed keys.

        Args:
            queries: Tensor whose last axis has size ``D`` and whose leading
                axes are matmul-compatible with the stored keys.
            compressed: Dict produced by :meth:`compress`.

        Returns:
            ``queries @ k_mse^T`` plus the sign-sketch correction
            ``sqrt(pi / 2) / m * (queries @ S^T @ signs^T) * residual_norm``,
            where ``m`` is the number of rows of ``S``.
        """
        k_mse = compressed["k_mse"].float()
        signs = compressed["qjl_signs"].float()
        r_norm = compressed["residual_norm"].float()

        q = queries.float()  # converted once, used by both terms
        term1 = torch.matmul(q, k_mse.transpose(-2, -1))
        q_projected = torch.matmul(q, self.S.T)
        qjl_ip = torch.matmul(q_projected, signs.transpose(-2, -1))

        m = self.S.shape[0]
        correction_scale = math.sqrt(math.pi / 2) / m
        # residual_norm is per key, so it is broadcast over the query axis.
        term2 = correction_scale * qjl_ip * r_norm.unsqueeze(-2)

        return term1 + term2
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class TurboQuantCompressorMSE:
    """MSE-only compressor for values.

    Vectors are normalised, rotated by a seeded orthogonal matrix ``Pi`` and
    stored as per-coordinate codebook indices plus one float16 norm per
    vector.  :meth:`decompress` reverses the process.
    """

    def __init__(self, head_dim: int, bits: int, seed: int, device: str = "cpu"):
        """Build the rotation matrix and a ``2**bits``-level codebook.

        Args:
            head_dim: Size of the last axis of the tensors to compress.
            bits: Bits per coordinate of the scalar codebook.
            seed: Seed for the rotation matrix.
            device: Device the rotation and the codebook are moved to.
        """
        self.head_dim = head_dim
        self.bits = bits
        self.device = device

        # Always sample on CPU so a seed gives the same matrix on any device.
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)
        G = torch.randn(head_dim, head_dim, generator=gen)
        Q, R = torch.linalg.qr(G)
        # Flip the columns of Q by the sign of diag(R); zeros count as +1.
        diag_sign = torch.sign(torch.diag(R))
        diag_sign[diag_sign == 0] = 1.0
        self.Pi = (Q * diag_sign.unsqueeze(0)).to(device)
        self.centroids = self._solve_codebook(head_dim, bits).to(device)

    def _solve_codebook(self, d: int, bits: int) -> torch.Tensor:
        """Return ``2**bits`` Lloyd-iterated centroids for N(0, 1/d).

        Starts from a uniform grid on +/-3.5 sigma and runs at most 200
        iterations, stopping early once no centroid moves by 1e-10 or more.
        """
        n_levels = 2**bits
        sigma = 1.0 / math.sqrt(d)

        lo, hi = -3.5 * sigma, 3.5 * sigma
        centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
        for _ in range(200):
            boundaries = [
                (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
            ]
            # The outermost cells are extended to three times the initial range.
            edges = [lo * 3] + boundaries + [hi * 3]
            new_c = []
            for i in range(n_levels):
                a, b = edges[i], edges[i + 1]
                num = _int_x_pdf(a, b, sigma)
                den = _int_pdf(a, b, sigma)
                # Keep the old centroid when the cell carries (almost) no mass.
                new_c.append(num / den if den > 1e-15 else centroids[i])
            if max(abs(new_c[i] - centroids[i]) for i in range(n_levels)) < 1e-10:
                break
            centroids = new_c
        return torch.tensor(centroids, dtype=torch.float32)

    @torch.no_grad()
    def compress(self, states: torch.Tensor) -> dict:
        """Compress a ``(B, H, S, D)`` tensor.

        Returns:
            A dict with ``indices`` (flattened to ``(B*H*S, D)``),
            ``vec_norms`` (float16, ``(B*H*S,)``) and the original ``shape``.
        """
        B, H, S, D = states.shape
        flat = states.reshape(-1, D).float()
        vec_norms = torch.norm(flat, dim=-1, keepdim=True)
        flat_norm = flat / (vec_norms + 1e-8)  # epsilon guards all-zero rows
        rotated = flat_norm @ self.Pi.T
        diffs = rotated.unsqueeze(-1) - self.centroids

        # Pick the narrowest integer type that can hold every codebook index.
        # uint8 is kept for codebooks of up to 256 levels; forcing it for
        # larger ones wrapped the indices and corrupted decompression.
        n_levels = self.centroids.numel()
        if n_levels <= 256:
            index_dtype = torch.uint8
        elif n_levels <= 32768:
            index_dtype = torch.int16
        else:
            index_dtype = torch.int32
        indices = diffs.abs().argmin(dim=-1).to(index_dtype)
        return {
            "indices": indices,
            "vec_norms": vec_norms.squeeze(-1).to(torch.float16),
            "shape": (B, H, S, D),
        }

    @torch.no_grad()
    def decompress(self, compressed: dict) -> torch.Tensor:
        """Rebuild a float32 ``(B, H, S, D)`` tensor from :meth:`compress` output."""
        B, H, S, D = compressed["shape"]
        indices = compressed["indices"].long()
        # Look up the centroids, undo the rotation, then restore the norms.
        reconstructed = self.centroids[indices] @ self.Pi
        vec_norms = compressed["vec_norms"].float().unsqueeze(-1)
        return (reconstructed * vec_norms).reshape(B, H, S, D)
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
# ruff: noqa
|
|
2
|
+
"""Lloyd-Max scalar quantizer for rotated unit vectors.
|
|
3
|
+
|
|
4
|
+
The coordinate distribution is approximately Beta-shaped on [-1, 1] after
|
|
5
|
+
random rotation. For d >= 64, a Gaussian N(0, 1/d) is a good approximation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import math
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# ---------------------------------------------------------------------------
|
|
16
|
+
# Pure-Python Gaussian integration helpers (replaces scipy.integrate.quad)
|
|
17
|
+
# ---------------------------------------------------------------------------
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _gauss_pdf(x: float, sigma: float) -> float:
|
|
21
|
+
"""N(0, σ²) probability density at x."""
|
|
22
|
+
return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _gauss_cdf(x: float, sigma: float) -> float:
|
|
26
|
+
"""N(0, σ²) cumulative distribution at x."""
|
|
27
|
+
return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _int_pdf(a: float, b: float, sigma: float) -> float:
    """Closed-form ∫[a,b] N(0,σ²)(x) dx, i.e. CDF(b) − CDF(a)."""
    cdf_at_b = _gauss_cdf(b, sigma)
    cdf_at_a = _gauss_cdf(a, sigma)
    return cdf_at_b - cdf_at_a
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _int_x_pdf(a: float, b: float, sigma: float) -> float:
    """Closed-form ∫[a,b] x·N(0,σ²)(x) dx = σ²·[f(a) − f(b)], f being the PDF."""
    sig2 = sigma * sigma
    density_gap = _gauss_pdf(a, sigma) - _gauss_pdf(b, sigma)
    return sig2 * density_gap
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _int_sq_pdf(a: float, b: float, sigma: float, c: float) -> float:
    """Closed-form ∫[a,b] (x−c)²·N(0,σ²)(x) dx.

    Expanding the square gives three pieces: the boundary part of ∫x²f,
    the cross term −2c∫xf, and (σ² + c²) times the mass of [a, b].
    """
    sig2 = sigma * sigma
    pdf_a = _gauss_pdf(a, sigma)
    pdf_b = _gauss_pdf(b, sigma)
    mass = _gauss_cdf(b, sigma) - _gauss_cdf(a, sigma)

    boundary_term = sig2 * (a * pdf_a - b * pdf_b)
    cross_term = 2.0 * c * sig2 * (pdf_a - pdf_b)
    mass_term = (sig2 + c * c) * mass
    return boundary_term - cross_term + mass_term
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _quad(f, a: float, b: float, n: int = 200) -> float:
|
|
53
|
+
"""Composite Simpson's rule numerical integration over [a, b].
|
|
54
|
+
|
|
55
|
+
Used only for the ``use_exact=True`` (Beta-PDF) path; the Gaussian path
|
|
56
|
+
uses closed-form helpers above.
|
|
57
|
+
"""
|
|
58
|
+
if n % 2 != 0:
|
|
59
|
+
n += 1
|
|
60
|
+
h = (b - a) / n
|
|
61
|
+
s = f(a) + f(b)
|
|
62
|
+
for i in range(1, n):
|
|
63
|
+
s += (4 if i % 2 else 2) * f(a + i * h)
|
|
64
|
+
return h / 3.0 * s
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def beta_pdf(x: float, d: int) -> float:
    """PDF of a single coordinate after random rotation of a d-dim unit vector.

    Evaluates ``Γ(d/2) / (√π · Γ((d−1)/2)) · (1 − x²)^((d−3)/2)`` and returns
    0.0 for ``|x| >= 1``.

    The Gamma ratio is computed through ``math.lgamma``.  Calling
    ``math.gamma(d / 2)`` directly raises ``OverflowError`` for large ``d``
    (e.g. 512) although the ratio itself is small.
    """
    if abs(x) >= 1.0:
        return 0.0
    log_ratio = math.lgamma(d / 2) - math.lgamma((d - 1) / 2)
    coeff = math.exp(log_ratio) / math.sqrt(math.pi)
    return coeff * (1 - x * x) ** ((d - 3) / 2)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def gaussian_approx_pdf(x: float, d: int) -> float:
    """Density of N(0, 1/d) at ``x``; the Gaussian stand-in used for d >= 64."""
    variance = 1.0 / d
    scale = 1.0 / math.sqrt(2 * math.pi * variance)
    return scale * math.exp(-x * x / (2 * variance))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def solve_lloyd_max(
    d: int, bits: int, use_exact: bool = False, max_iter: int = 200, tol: float = 1e-10
):
    """
    Run Lloyd-Max iterations for the per-coordinate distribution.

    Centroids start on a uniform grid over +/-3.5 sigma (sigma = 1/sqrt(d))
    and are repeatedly replaced by the conditional mean of their cell until
    the largest move drops below ``tol`` or ``max_iter`` is reached.

    Args:
        d: vector dimension
        bits: number of quantization bits
        use_exact: if True, integrate the Beta PDF numerically; if False,
            use the closed-form Gaussian integrals
        max_iter: maximum Lloyd-Max iterations
        tol: convergence tolerance on the largest centroid shift

    Returns:
        centroids: float32 tensor of 2^bits centroids
        boundaries: float32 tensor of the 2^bits - 1 midpoints between them
    """
    level_count = 2**bits
    sigma = 1.0 / math.sqrt(d)

    lo, hi = -3.5 * sigma, 3.5 * sigma
    span = hi - lo
    centroids = [lo + span * (k + 0.5) / level_count for k in range(level_count)]

    def _midpoints(points):
        # Decision boundaries sit halfway between neighbouring centroids.
        return [(left + right) / 2.0 for left, right in zip(points, points[1:])]

    def _cell_moments(a, b):
        # Returns (integral of x * pdf, integral of pdf) over [a, b].
        if use_exact:
            first = _quad(lambda x: x * beta_pdf(x, d), a, b)
            mass = _quad(lambda x: beta_pdf(x, d), a, b)
            return first, mass
        return _int_x_pdf(a, b, sigma), _int_pdf(a, b, sigma)

    for _ in range(max_iter):
        # The outermost cells are extended to three times the initial range.
        edges = [lo * 3, *_midpoints(centroids), hi * 3]
        updated = []
        for k, previous in enumerate(centroids):
            first, mass = _cell_moments(edges[k], edges[k + 1])
            # A cell with (almost) no mass keeps its previous centroid.
            updated.append(first / mass if mass > 1e-15 else previous)

        largest_move = max(abs(new - old) for new, old in zip(updated, centroids))
        centroids = updated

        if largest_move < tol:
            break

    return (
        torch.tensor(centroids, dtype=torch.float32),
        torch.tensor(_midpoints(centroids), dtype=torch.float32),
    )
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def compute_expected_distortion(
    d: int,
    bits: int,
    centroids: torch.Tensor,
    boundaries: torch.Tensor,
    use_exact: bool = False,
) -> float:
    """Return the expected per-coordinate squared error of a quantizer.

    Sums, over every cell, the integral of ``(x - centroid)^2 * pdf(x)``.
    The outermost cells end at +/-10.5 sigma with sigma = 1/sqrt(d).
    ``bits`` is accepted but not used by the computation.
    """
    sigma = 1.0 / math.sqrt(d)
    outer = 3.5 * sigma * 3
    edges = [-outer, *boundaries.tolist(), outer]

    total_distortion = 0.0
    for i in range(len(centroids)):
        lower, upper = edges[i], edges[i + 1]
        c = centroids[i].item()
        if use_exact:
            # The centroid is bound as a default so the lambda captures its value.
            cell = _quad(lambda x, _c=c: (x - _c) ** 2 * beta_pdf(x, d), lower, upper)
        else:
            cell = _int_sq_pdf(lower, upper, sigma, c)
        total_distortion += cell

    return total_distortion
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class LloydMaxCodebook:
    """Lloyd-Max codebook solved once for a dimension and bit-width."""

    def __init__(self, d: int, bits: int, use_exact: bool = False):
        """Solve the quantizer and record its expected distortion."""
        self.d = d
        self.bits = bits
        self.n_levels = 2**bits
        centroids, boundaries = solve_lloyd_max(d, bits, use_exact)
        self.centroids = centroids
        self.boundaries = boundaries
        self.distortion = compute_expected_distortion(
            d, bits, centroids, boundaries, use_exact
        )

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Return, for every element of ``x``, the index of the nearest centroid."""
        codebook = self.centroids.to(x.device)
        distance = torch.abs(x.unsqueeze(-1) - codebook)
        return torch.argmin(distance, dim=-1)

    def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up the centroid value for every index."""
        codebook = self.centroids.to(indices.device)
        return codebook[indices]

    def __repr__(self):
        summary = (
            f"LloydMaxCodebook(d={self.d}, bits={self.bits}, "
            f"levels={self.n_levels}, distortion_per_coord={self.distortion:.6f})"
        )
        return summary
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""TurboQuant: two-stage vector quantization."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from typing import Optional, Tuple, cast
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn
|
|
10
|
+
|
|
11
|
+
from .lloyd_max import LloydMaxCodebook
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def generate_rotation_matrix(
    d: int, seed: Optional[int] = None, device: str = "cpu"
) -> torch.Tensor:
    """Build a ``d x d`` orthogonal matrix from the QR factors of a Gaussian.

    The Gaussian is always sampled on CPU; the result is then moved to
    ``device``.  Columns of Q are flipped by the sign of R's diagonal.
    """
    # NOTE(review): with seed=None the generator is used in its
    # construction-time state -- confirm whether that gives a fresh matrix
    # on every call.
    rng = torch.Generator(device="cpu")
    if seed is not None:
        rng.manual_seed(seed)
    gaussian = torch.randn(d, d, generator=rng)
    q_factor, r_factor = torch.linalg.qr(gaussian)
    signs = torch.sign(torch.diag(r_factor))
    signs[signs == 0] = 1.0  # a zero diagonal entry counts as positive
    rotation = q_factor * signs.unsqueeze(0)
    return rotation.to(device)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def generate_qjl_matrix(
    d: int, m: Optional[int] = None, seed: Optional[int] = None, device: str = "cpu"
) -> torch.Tensor:
    """
    Sample the random projection matrix S used for QJL.

    S has shape (m, d) with i.i.d. N(0,1) entries; ``m`` falls back to ``d``
    when it is None.  Sampling happens on CPU before the move to ``device``.
    """
    rows = d if m is None else m
    rng = torch.Generator(device="cpu")
    if seed is not None:
        rng.manual_seed(seed)
    projection = torch.randn(rows, d, generator=rng)
    return projection.to(device)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class TurboQuantMSE(nn.Module):
    """Stage 1: MSE-optimal quantizer.

    Rotates the input with a seeded orthogonal matrix and replaces every
    rotated coordinate by the index of its nearest Lloyd-Max centroid.
    """

    def __init__(self, d: int, bits: int, seed: int = 42, device: str = "cpu"):
        super().__init__()
        self.d, self.bits, self.device = d, bits, device

        rotation = generate_rotation_matrix(d, seed=seed, device=device)
        self.register_buffer("Pi", rotation)
        self.codebook = LloydMaxCodebook(d, bits)
        # Mirror the codebook tensors as buffers on the requested device.
        for buffer_name in ("centroids", "boundaries"):
            tensor = getattr(self.codebook, buffer_name)
            self.register_buffer(buffer_name, tensor.to(device))

    def rotate(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the rotation: ``x @ Pi^T``."""
        rotation = cast("torch.Tensor", self.Pi)
        return x @ rotation.T

    def unrotate(self, y: torch.Tensor) -> torch.Tensor:
        """Undo :meth:`rotate`: ``y @ Pi``."""
        rotation = cast("torch.Tensor", self.Pi)
        return y @ rotation

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Return the nearest-centroid index of every rotated coordinate."""
        codebook = cast("torch.Tensor", self.centroids)
        rotated = self.rotate(x)
        distance = torch.abs(rotated.unsqueeze(-1) - codebook)
        return torch.argmin(distance, dim=-1)

    def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Map indices back to centroids and rotate them to the input basis."""
        codebook = cast("torch.Tensor", self.centroids)
        return self.unrotate(codebook[indices])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, indices)`` for ``x``."""
        indices = self.quantize(x)
        return self.dequantize(indices), indices
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class TurboQuantProd(nn.Module):
    """Stage 1 + Stage 2: Unbiased inner product quantizer.

    Stage 1 is a :class:`TurboQuantMSE` with ``max(bits - 1, 1)`` bits.
    Stage 2 stores the signs of the stage-1 residual projected through a
    Gaussian matrix ``S`` and the residual norm.
    """

    def __init__(
        self,
        d: int,
        bits: int,
        qjl_dim: Optional[int] = None,
        seed: int = 42,
        device: str = "cpu",
    ):
        super().__init__()
        self.d = d
        self.bits = bits
        self.mse_bits = max(bits - 1, 1)
        # A missing (or zero) projection size falls back to the input size.
        self.qjl_dim = qjl_dim if qjl_dim else d
        self.device = device

        self.mse = TurboQuantMSE(d, self.mse_bits, seed=seed, device=device)
        # S uses seed + 1 so that it differs from the stage-1 rotation.
        projection = generate_qjl_matrix(
            d, m=self.qjl_dim, seed=seed + 1, device=device
        )
        self.register_buffer("S", projection)

    def quantize(self, x: torch.Tensor) -> dict:
        """Return stage-1 indices, residual sign sketch and residual norm."""
        reconstruction, mse_indices = self.mse(x)
        residual = x - reconstruction
        residual_norm = torch.norm(residual, dim=-1, keepdim=True)

        projection = cast("torch.Tensor", self.S)
        sketch = torch.sign(residual @ projection.T)
        sketch[sketch == 0] = 1.0  # a zero projection counts as positive

        return {
            "mse_indices": mse_indices,
            "qjl_signs": sketch,
            "residual_norm": residual_norm.squeeze(-1),
        }

    def dequantize(self, compressed: dict) -> torch.Tensor:
        """Return the stage-1 reconstruction only."""
        return self.mse.dequantize(compressed["mse_indices"])

    def inner_product(self, y: torch.Tensor, compressed: dict) -> torch.Tensor:
        """Estimate ``<y, x>`` from the compressed form of ``x``.

        The result is the element-wise product of ``y`` with the stage-1
        reconstruction summed over the last axis, plus
        ``residual_norm * sqrt(pi / 2) / qjl_dim * sum((y @ S^T) * signs)``.
        """
        base_vectors = self.mse.dequantize(compressed["mse_indices"])
        base_term = (y * base_vectors).sum(dim=-1)

        projection = cast("torch.Tensor", self.S)
        sketch_dot = ((y @ projection.T) * compressed["qjl_signs"]).sum(dim=-1)

        correction_scale = math.sqrt(math.pi / 2) / self.qjl_dim
        residual_term = compressed["residual_norm"] * correction_scale * sketch_dot

        return base_term + residual_term

    def forward(self, x: torch.Tensor) -> dict:
        """Alias for :meth:`quantize`."""
        return self.quantize(x)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class TurboQuantKVCache:
    """KV cache wrapper that uses TurboQuant to compress keys and values.

    Keys go through :class:`TurboQuantProd` and values through
    :class:`TurboQuantMSE`.  Every :meth:`append` call stores one compressed
    chunk; queries and reads walk over all stored chunks.
    """

    def __init__(
        self,
        d_key: int,
        d_value: int,
        bits: int = 3,
        seed: int = 42,
        device: str = "cpu",
    ):
        """Create the key/value quantizers and empty chunk lists.

        Args:
            d_key: Size of the last axis of key tensors.
            d_value: Size of the last axis of value tensors.
            bits: Bit budget handed to both quantizers.
            seed: Seed for the key quantizer; values use ``seed + 100``.
            device: Device handed to both quantizers and used for the empty
                results of an empty cache.
        """
        self.d_key = d_key
        self.d_value = d_value
        self.bits = bits
        self.device = device

        self.key_quantizer = TurboQuantProd(d_key, bits, seed=seed, device=device)
        self.value_quantizer = TurboQuantMSE(
            d_value, bits, seed=seed + 100, device=device
        )

        self.key_cache = []
        self.value_cache = []

    def append(self, keys: torch.Tensor, values: torch.Tensor):
        """Compress ``keys`` and ``values`` and store them as one chunk.

        Both tensors are flattened to 2-D (``(-1, d_key)`` / ``(-1, d_value)``)
        before quantization; the original shapes are kept under ``"shape"``.
        """
        orig_shape = keys.shape
        flat_keys = keys.reshape(-1, self.d_key)
        flat_values = values.reshape(-1, self.d_value)

        compressed_keys = self.key_quantizer.quantize(flat_keys)
        value_indices = self.value_quantizer.quantize(flat_values)

        self.key_cache.append(
            {
                "mse_indices": compressed_keys["mse_indices"],
                "qjl_signs": compressed_keys["qjl_signs"],
                "residual_norm": compressed_keys["residual_norm"],
                "shape": orig_shape,
            }
        )
        self.value_cache.append(
            {
                "indices": value_indices,
                "shape": values.shape,
            }
        )

    def attention_scores(self, queries: torch.Tensor) -> torch.Tensor:
        """Return estimated query/key inner products over all chunks.

        Per-chunk results of ``key_quantizer.inner_product`` are concatenated
        along the last axis.  An empty cache yields an empty tensor on
        ``self.device``.
        """
        # NOTE(review): inner_product multiplies queries with the cached keys
        # element-wise, so queries presumably must broadcast against each
        # chunk's (n_vectors, d_key) layout -- confirm against callers.
        if not self.key_cache:
            return torch.empty(0, device=self.device)
        scores = [
            self.key_quantizer.inner_product(queries, cached)
            for cached in self.key_cache
        ]
        return torch.cat(scores, dim=-1)

    def get_values(self) -> torch.Tensor:
        """Return all cached values, dequantized and concatenated on axis 0.

        An empty cache yields an empty tensor on ``self.device``.
        """
        if not self.value_cache:
            return torch.empty(0, device=self.device)
        values = [
            self.value_quantizer.dequantize(cached["indices"])
            for cached in self.value_cache
        ]
        return torch.cat(values, dim=0)

    def memory_usage_bits(self) -> dict:
        """Return nominal bit counts for the cache and its fp16 equivalent.

        Keys cost ``mse_bits`` per index, 1 bit per sign and 16 bits per
        residual norm; values cost ``bits`` per index.  ``compression_ratio``
        is ``fp16_bits / total_bits`` and 0 for an empty cache.
        """
        # sum() over an empty cache is already 0, so no emptiness guard is needed.
        n_keys = sum(c["mse_indices"].numel() for c in self.key_cache)
        n_qjl = sum(c["qjl_signs"].numel() for c in self.key_cache)
        n_norms = sum(c["residual_norm"].numel() for c in self.key_cache)
        n_values = sum(c["indices"].numel() for c in self.value_cache)

        key_bits = n_keys * self.key_quantizer.mse_bits + n_qjl * 1 + n_norms * 16
        value_bits = n_values * self.bits
        total_bits = key_bits + value_bits
        fp16_equivalent = (n_keys + n_values) * 16

        return {
            "key_bits": key_bits,
            "value_bits": value_bits,
            "total_bits": total_bits,
            "fp16_bits": fp16_equivalent,
            "compression_ratio": (
                fp16_equivalent / total_bits if total_bits > 0 else 0
            ),
        }

    def __len__(self):
        """Return the number of cached key vectors across all chunks."""
        return sum(c["mse_indices"].shape[0] for c in self.key_cache)
|
plesk_unified/types.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Literal, Union
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class CategoryEnum(str, Enum):
    """Documentation categories understood by the package.

    Members are also ``str`` instances, so they compare equal to their
    string value.
    """

    # Member values are the category identifiers as plain strings.
    GUIDE = "guide"
    CLI = "cli"
    API = "api"
    PHP_STUBS = "php-stubs"
    JS_SDK = "js-sdk"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Immutable set of every category value, used for membership checks.
VALID_CATEGORIES: frozenset[str] = frozenset(
    member.value for member in CategoryEnum
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def validate_category(category: str, allow_all: bool = False) -> None:
    """Raise ``ValueError`` unless ``category`` is a known category.

    Args:
        category: Category string to check against ``VALID_CATEGORIES``.
        allow_all: When true, the literal ``"all"`` is accepted as well.

    Raises:
        ValueError: If ``category`` is not accepted.
    """
    is_wildcard = allow_all and category == "all"
    if is_wildcard or category in VALID_CATEGORIES:
        return
    raise ValueError(f"Invalid category: '{category}'")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Type alias used by refresh_knowledge: either one specific category or the
# literal string "all".
CategoryOrAll = Union[CategoryEnum, Literal["all"]]
|