mcp-plesk-dev-docs 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,190 @@
1
+ """TurboQuant KV cache helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+
7
+ import torch
8
+
9
+
10
+ # ---------------------------------------------------------------------------
11
+ # Closed-form Gaussian integration helpers (replaces scipy.integrate.quad)
12
+ # ---------------------------------------------------------------------------
13
+
14
+
15
+ def _gauss_pdf(x: float, sigma: float) -> float:
16
+ return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
17
+
18
+
19
+ def _gauss_cdf(x: float, sigma: float) -> float:
20
+ return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))
21
+
22
+
23
def _int_pdf(a: float, b: float, sigma: float) -> float:
    """Return the N(0, sigma**2) probability mass on the interval [a, b]."""
    upper = _gauss_cdf(b, sigma)
    lower = _gauss_cdf(a, sigma)
    return upper - lower
25
+
26
+
27
def _int_x_pdf(a: float, b: float, sigma: float) -> float:
    """Return the integral of ``x * pdf(x)`` over [a, b] for N(0, sigma**2).

    Evaluated in closed form as ``sigma**2 * (pdf(a) - pdf(b))``, so no
    numerical quadrature is involved.
    """
    density_gap = _gauss_pdf(a, sigma) - _gauss_pdf(b, sigma)
    return sigma * sigma * density_gap
29
+
30
+
31
class TurboQuantCompressorV2:
    """Compressed key store with direct inner-product scoring.

    Each vector is normalised, rotated by a seeded random orthogonal matrix
    ``Pi`` and snapped coordinate-wise to the nearest entry of a scalar
    codebook built with ``max(bits - 1, 1)`` bits.  The signs of a Gaussian
    projection ``S`` of the quantisation residual, together with the residual
    norm, are stored so that ``asymmetric_attention_scores`` can add a
    correction term to the query/key dot products.
    """

    def __init__(self, head_dim: int, bits: int, seed: int, device: str = "cpu"):
        """Build the rotation, the codebook and the projection matrix.

        Args:
            head_dim: size of the last dimension of the tensors to compress.
            bits: total bit budget; the codebook uses ``max(bits - 1, 1)``.
            seed: seed of the rotation; ``seed + 10000`` seeds the projection.
            device: device the matrices and the codebook are moved to.
        """
        self.head_dim = head_dim
        self.bits = bits
        self.mse_bits = max(bits - 1, 1)
        self.device = device

        # Random draws always happen on the CPU and are moved afterwards, so
        # the matrices do not depend on ``device``.
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)
        G = torch.randn(head_dim, head_dim, generator=gen)
        Q, R = torch.linalg.qr(G)
        # Scale column j of Q by sign(R[j, j]); a zero diagonal counts as +1.
        diag_sign = torch.sign(torch.diag(R))
        diag_sign[diag_sign == 0] = 1.0
        self.Pi = (Q * diag_sign.unsqueeze(0)).to(device)

        self.centroids = self._solve_codebook(head_dim, self.mse_bits).to(device)

        gen2 = torch.Generator(device="cpu")
        gen2.manual_seed(seed + 10000)
        self.S = torch.randn(head_dim, head_dim, generator=gen2).to(device)

        # Contiguous copy of the transposed rotation, kept as an attribute.
        self.PiT = self.Pi.T.contiguous()

    def _solve_codebook(self, d: int, bits: int) -> torch.Tensor:
        """Run the Lloyd-Max iteration for N(0, 1/d) and return the centroids.

        Starts from ``2**bits`` evenly spaced points on [-3.5 sigma, 3.5 sigma]
        and repeatedly replaces each centroid by the conditional mean of its
        cell (cells are bounded by centroid midpoints, the outermost by
        +/-10.5 sigma).  Stops after 200 rounds or once no centroid moves by
        1e-10 or more.

        Returns:
            float32 tensor with ``2**bits`` centroids.
        """
        n_levels = 2**bits
        sigma = 1.0 / math.sqrt(d)

        lo, hi = -3.5 * sigma, 3.5 * sigma
        centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]

        for _ in range(200):
            boundaries = [
                (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
            ]
            edges = [lo * 3] + boundaries + [hi * 3]
            new_centroids = []
            for i in range(n_levels):
                a, b = edges[i], edges[i + 1]
                num = _int_x_pdf(a, b, sigma)
                den = _int_pdf(a, b, sigma)
                # A cell with (numerically) no mass keeps its old centroid.
                new_centroids.append(num / den if den > 1e-15 else centroids[i])
            # NOTE: on convergence the previous iterate is returned; the last
            # update (smaller than 1e-10 per centroid) is discarded.
            if (
                max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels))
                < 1e-10
            ):
                break
            centroids = new_centroids

        return torch.tensor(centroids, dtype=torch.float32)

    @torch.no_grad()
    def compress(self, states: torch.Tensor) -> dict:
        """Quantise a 4-D tensor and sketch the quantisation residual.

        Args:
            states: tensor whose shape unpacks as ``(B, H, S, D)``.

        Returns:
            dict with
            ``"k_mse"`` (float16, ``(B, H, S, D)``): dequantised approximation,
            ``"qjl_signs"`` (int8, ``(B, H, S, D)``): +/-1 signs of the
            projected residual (zero maps to +1),
            ``"residual_norm"`` (float16, ``(B, H, S)``): residual L2 norm,
            ``"shape"``: the tuple ``(B, H, S, D)``.
        """
        B, H, S, D = states.shape
        flat = states.reshape(-1, D).float()

        vec_norms = torch.norm(flat, dim=-1, keepdim=True)
        flat_norm = flat / (vec_norms + 1e-8)  # epsilon guards all-zero rows

        rotated = flat_norm @ self.Pi.T
        diffs = rotated.unsqueeze(-1) - self.centroids
        # Keep the int64 result of argmin: the indices are only used for the
        # lookup below, and narrowing them to uint8 wrapped around for
        # codebooks with more than 256 levels.
        indices = diffs.abs().argmin(dim=-1)

        reconstructed_rotated = self.centroids[indices]
        k_mse = (reconstructed_rotated @ self.Pi) * vec_norms

        residual = flat - k_mse
        residual_norm = torch.norm(residual, dim=-1)

        projected = residual @ self.S.T
        signs = (projected >= 0).to(torch.int8) * 2 - 1

        return {
            "k_mse": k_mse.to(torch.float16).reshape(B, H, S, D),
            "qjl_signs": signs.reshape(B, H, S, D),
            "residual_norm": residual_norm.to(torch.float16).reshape(B, H, S),
            "shape": (B, H, S, D),
        }

    @torch.no_grad()
    def asymmetric_attention_scores(
        self, queries: torch.Tensor, compressed: dict
    ) -> torch.Tensor:
        """Score uncompressed queries against the output of ``compress``.

        The result is ``queries @ k_mse^T`` plus a correction built from the
        stored residual signs and norms, scaled by ``sqrt(pi / 2) / m`` where
        ``m`` is the number of rows of ``S``.

        Args:
            queries: tensor that can be matrix-multiplied with the transposed
                ``k_mse`` (presumably ``(B, H, Sq, D)`` -- confirm with callers).
            compressed: dict as returned by ``compress``.

        Returns:
            float32 tensor of scores.
        """
        k_mse = compressed["k_mse"].float()
        signs = compressed["qjl_signs"].float()
        r_norm = compressed["residual_norm"].float()

        term1 = torch.matmul(queries.float(), k_mse.transpose(-2, -1))
        q_projected = torch.matmul(queries.float(), self.S.T)
        qjl_ip = torch.matmul(q_projected, signs.transpose(-2, -1))

        m = self.S.shape[0]
        correction_scale = math.sqrt(math.pi / 2) / m
        # r_norm gains a singleton query axis so it broadcasts over queries.
        term2 = correction_scale * qjl_ip * r_norm.unsqueeze(-2)

        return term1 + term2
128
+
129
+
130
class TurboQuantCompressorMSE:
    """MSE-only compressor for values.

    Vectors are normalised, rotated by a seeded random orthogonal matrix
    ``Pi`` and stored as per-coordinate codebook indices plus one norm per
    vector; ``decompress`` reverses the steps.
    """

    def __init__(self, head_dim: int, bits: int, seed: int, device: str = "cpu"):
        """Build the rotation and a ``2**bits``-level codebook.

        Args:
            head_dim: size of the last dimension of the tensors to compress.
            bits: bits per coordinate of the scalar codebook.
            seed: seed of the random rotation.
            device: device the rotation and the codebook are moved to.
        """
        self.head_dim = head_dim
        self.bits = bits
        self.device = device

        # Random draws always happen on the CPU and are moved afterwards.
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)
        G = torch.randn(head_dim, head_dim, generator=gen)
        Q, R = torch.linalg.qr(G)
        # Scale column j of Q by sign(R[j, j]); a zero diagonal counts as +1.
        diag_sign = torch.sign(torch.diag(R))
        diag_sign[diag_sign == 0] = 1.0
        self.Pi = (Q * diag_sign.unsqueeze(0)).to(device)
        self.centroids = self._solve_codebook(head_dim, bits).to(device)

    def _solve_codebook(self, d, bits):
        """Run the Lloyd-Max iteration for N(0, 1/d) and return the centroids.

        Starts from ``2**bits`` evenly spaced points on [-3.5 sigma, 3.5 sigma]
        and repeatedly replaces each centroid by the conditional mean of its
        cell.  Stops after 200 rounds or once no centroid moves by 1e-10 or
        more.

        Returns:
            float32 tensor with ``2**bits`` centroids.
        """
        n_levels = 2**bits
        sigma = 1.0 / math.sqrt(d)

        lo, hi = -3.5 * sigma, 3.5 * sigma
        centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
        for _ in range(200):
            boundaries = [
                (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
            ]
            edges = [lo * 3] + boundaries + [hi * 3]
            new_c = []
            for i in range(n_levels):
                a, b = edges[i], edges[i + 1]
                num = _int_x_pdf(a, b, sigma)
                den = _int_pdf(a, b, sigma)
                # A cell with (numerically) no mass keeps its old centroid.
                new_c.append(num / den if den > 1e-15 else centroids[i])
            # NOTE: on convergence the previous iterate is returned; the last
            # update (smaller than 1e-10 per centroid) is discarded.
            if max(abs(new_c[i] - centroids[i]) for i in range(n_levels)) < 1e-10:
                break
            centroids = new_c
        return torch.tensor(centroids, dtype=torch.float32)

    @torch.no_grad()
    def compress(self, states: torch.Tensor) -> dict:
        """Quantise a 4-D tensor to codebook indices and per-vector norms.

        Args:
            states: tensor whose shape unpacks as ``(B, H, S, D)``.

        Returns:
            dict with
            ``"indices"`` (``(B*H*S, D)``): codebook index per coordinate,
            stored as uint8 when the codebook has at most 256 levels and as a
            wider integer type otherwise,
            ``"vec_norms"`` (float16, ``(B*H*S,)``): L2 norm per vector,
            ``"shape"``: the tuple ``(B, H, S, D)``.
        """
        B, H, S, D = states.shape
        flat = states.reshape(-1, D).float()
        vec_norms = torch.norm(flat, dim=-1, keepdim=True)
        flat_norm = flat / (vec_norms + 1e-8)  # epsilon guards all-zero rows
        rotated = flat_norm @ self.Pi.T
        diffs = rotated.unsqueeze(-1) - self.centroids

        # uint8 only holds 0..255; a larger codebook would wrap around and
        # make ``decompress`` read the wrong centroid.
        n_levels = self.centroids.numel()
        if n_levels <= 256:
            index_dtype = torch.uint8
        elif n_levels <= 32768:
            index_dtype = torch.int16
        else:
            index_dtype = torch.int32
        indices = diffs.abs().argmin(dim=-1).to(index_dtype)
        return {
            "indices": indices,
            "vec_norms": vec_norms.squeeze(-1).to(torch.float16),
            "shape": (B, H, S, D),
        }

    @torch.no_grad()
    def decompress(self, compressed: dict) -> torch.Tensor:
        """Rebuild a float32 ``(B, H, S, D)`` tensor from ``compress`` output.

        Looks up the centroids, undoes the rotation and rescales every vector
        by its stored norm.
        """
        B, H, S, D = compressed["shape"]
        indices = compressed["indices"].long()
        reconstructed = self.centroids[indices] @ self.Pi
        vec_norms = compressed["vec_norms"].float().unsqueeze(-1)
        return (reconstructed * vec_norms).reshape(B, H, S, D)
@@ -0,0 +1,190 @@
1
+ # ruff: noqa
2
+ """Lloyd-Max scalar quantizer for rotated unit vectors.
3
+
4
+ The coordinate distribution is approximately Beta-shaped on [-1, 1] after
5
+ random rotation. For d >= 64, a Gaussian N(0, 1/d) is a good approximation.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import math
11
+
12
+ import torch
13
+
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Pure-Python Gaussian integration helpers (replaces scipy.integrate.quad)
17
+ # ---------------------------------------------------------------------------
18
+
19
+
20
+ def _gauss_pdf(x: float, sigma: float) -> float:
21
+ """N(0, σ²) probability density at x."""
22
+ return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))
23
+
24
+
25
+ def _gauss_cdf(x: float, sigma: float) -> float:
26
+ """N(0, σ²) cumulative distribution at x."""
27
+ return 0.5 * (1.0 + math.erf(x / (sigma * math.sqrt(2.0))))
28
+
29
+
30
def _int_pdf(a: float, b: float, sigma: float) -> float:
    """Integrate the N(0, sigma^2) density over [a, b] as a CDF difference."""
    cdf_hi = _gauss_cdf(b, sigma)
    cdf_lo = _gauss_cdf(a, sigma)
    return cdf_hi - cdf_lo
33
+
34
+
35
def _int_x_pdf(a: float, b: float, sigma: float) -> float:
    """Integrate ``x * pdf(x)`` over [a, b] for N(0, sigma^2).

    Closed form: ``sigma^2 * (pdf(a) - pdf(b))``.
    """
    pdf_drop = _gauss_pdf(a, sigma) - _gauss_pdf(b, sigma)
    return sigma * sigma * pdf_drop
38
+
39
+
40
def _int_sq_pdf(a: float, b: float, sigma: float, c: float) -> float:
    """Integrate ``(x - c)^2 * pdf(x)`` over [a, b] for N(0, sigma^2).

    Expands the square and combines the closed forms of the zeroth, first
    and second partial moments of the Gaussian, so no quadrature is needed.
    """
    pdf_a = _gauss_pdf(a, sigma)
    pdf_b = _gauss_pdf(b, sigma)
    mass = _gauss_cdf(b, sigma) - _gauss_cdf(a, sigma)
    variance = sigma * sigma

    edge_term = variance * (a * pdf_a - b * pdf_b)
    cross_term = 2.0 * c * variance * (pdf_a - pdf_b)
    mass_term = (variance + c * c) * mass
    return edge_term - cross_term + mass_term
50
+
51
+
52
+ def _quad(f, a: float, b: float, n: int = 200) -> float:
53
+ """Composite Simpson's rule numerical integration over [a, b].
54
+
55
+ Used only for the ``use_exact=True`` (Beta-PDF) path; the Gaussian path
56
+ uses closed-form helpers above.
57
+ """
58
+ if n % 2 != 0:
59
+ n += 1
60
+ h = (b - a) / n
61
+ s = f(a) + f(b)
62
+ for i in range(1, n):
63
+ s += (4 if i % 2 else 2) * f(a + i * h)
64
+ return h / 3.0 * s
65
+
66
+
67
def beta_pdf(x: float, d: int) -> float:
    """PDF of a single coordinate after random rotation of a d-dim unit vector.

    The density is ``C(d) * (1 - x^2)^((d - 3) / 2)`` on (-1, 1) and zero
    elsewhere, with ``C(d) = Gamma(d/2) / (sqrt(pi) * Gamma((d-1)/2))``.

    Args:
        x: point at which to evaluate the density.
        d: vector dimension.

    Returns:
        The density at ``x``; ``0.0`` whenever ``|x| >= 1``.
    """
    if abs(x) >= 1.0:
        return 0.0
    # Work with log-gamma: math.gamma overflows once its argument passes
    # roughly 171, i.e. for d above ~343, although the ratio itself is small.
    log_coeff = (
        math.lgamma(d / 2) - math.lgamma((d - 1) / 2) - 0.5 * math.log(math.pi)
    )
    return math.exp(log_coeff) * (1 - x * x) ** ((d - 3) / 2)
73
+
74
+
75
def gaussian_approx_pdf(x: float, d: int) -> float:
    """Density of N(0, 1/d) at ``x``; the approximation is meant for d >= 64."""
    variance = 1.0 / d
    scale = 1.0 / math.sqrt(2 * math.pi * variance)
    return scale * math.exp(-x * x / (2 * variance))
79
+
80
+
81
def solve_lloyd_max(
    d: int, bits: int, use_exact: bool = False, max_iter: int = 200, tol: float = 1e-10
):
    """Run the Lloyd-Max iteration for the rotated-coordinate distribution.

    Centroids start evenly spaced on [-3.5 sigma, 3.5 sigma] with
    ``sigma = 1 / sqrt(d)``.  Every round moves each centroid to the
    conditional mean of its cell; cells are split at centroid midpoints and
    the outermost ones end at +/-10.5 sigma.

    Args:
        d: vector dimension.
        bits: number of quantization bits (``2**bits`` levels).
        use_exact: integrate the Beta PDF numerically instead of using the
            closed-form Gaussian moments.
        max_iter: maximum number of rounds.
        tol: stop once no centroid moved by this much or more in a round.

    Returns:
        ``(centroids, boundaries)`` as float32 tensors holding ``2**bits``
        centroids and the ``2**bits - 1`` midpoints between them.
    """
    n_levels = 2**bits
    sigma = 1.0 / math.sqrt(d)
    lo, hi = -3.5 * sigma, 3.5 * sigma
    centroids = [lo + (hi - lo) * (k + 0.5) / n_levels for k in range(n_levels)]

    for _ in range(max_iter):
        midpoints = [
            (left + right) / 2.0 for left, right in zip(centroids, centroids[1:])
        ]
        edges = [lo * 3, *midpoints, hi * 3]

        updated = []
        for k, (a, b) in enumerate(zip(edges, edges[1:])):
            if use_exact:
                numerator = _quad(lambda x: x * beta_pdf(x, d), a, b)
                denominator = _quad(lambda x: beta_pdf(x, d), a, b)
            else:
                numerator = _int_x_pdf(a, b, sigma)
                denominator = _int_pdf(a, b, sigma)

            # A cell with (numerically) no mass keeps its previous centroid.
            if denominator > 1e-15:
                updated.append(numerator / denominator)
            else:
                updated.append(centroids[k])

        max_shift = max(abs(new - old) for new, old in zip(updated, centroids))
        centroids = updated
        if max_shift < tol:
            break

    boundaries = [
        (left + right) / 2.0 for left, right in zip(centroids, centroids[1:])
    ]
    return (
        torch.tensor(centroids, dtype=torch.float32),
        torch.tensor(boundaries, dtype=torch.float32),
    )
137
+
138
+
139
def compute_expected_distortion(
    d: int,
    bits: int,
    centroids: torch.Tensor,
    boundaries: torch.Tensor,
    use_exact: bool = False,
) -> float:
    """Return the expected squared error per coordinate of a scalar quantizer.

    Sums, over all cells, the integral of ``(x - centroid)^2 * pdf(x)``; the
    outermost cells end at +/-10.5 sigma with ``sigma = 1 / sqrt(d)``.

    Args:
        d: vector dimension.
        bits: accepted for signature symmetry; not read by the computation.
        centroids: 1-D tensor of reconstruction levels.
        boundaries: 1-D tensor of cell boundaries between the centroids.
        use_exact: integrate against the Beta PDF numerically instead of
            using the closed-form Gaussian expression.
    """
    sigma = 1.0 / math.sqrt(d)
    outer = 3.5 * sigma * 3
    edges = [-outer, *boundaries.tolist(), outer]

    total = 0.0
    for k, centre in enumerate(centroids.tolist()):
        a, b = edges[k], edges[k + 1]
        if use_exact:
            # Bind the centroid as a default so the lambda is self-contained.
            cell = _quad(lambda x, _c=centre: (x - _c) ** 2 * beta_pdf(x, d), a, b)
        else:
            cell = _int_sq_pdf(a, b, sigma, centre)
        total += cell
    return total
163
+
164
+
165
class LloydMaxCodebook:
    """Lloyd-Max codebook solved once for a dimension and a bit-width.

    Holds the centroids, the cell boundaries and the expected per-coordinate
    distortion, and maps tensors to and from centroid indices.
    """

    def __init__(self, d: int, bits: int, use_exact: bool = False):
        """Solve the quantizer for ``d`` and ``bits`` and cache the result."""
        self.d = d
        self.bits = bits
        self.n_levels = 2**bits
        solved = solve_lloyd_max(d, bits, use_exact)
        self.centroids, self.boundaries = solved
        self.distortion = compute_expected_distortion(
            d, bits, self.centroids, self.boundaries, use_exact
        )

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Return, for every element of ``x``, the index of the nearest centroid."""
        table = self.centroids.to(x.device)
        distances = torch.abs(x.unsqueeze(-1) - table)
        return torch.argmin(distances, dim=-1)

    def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Replace every index by the centroid value it refers to."""
        table = self.centroids.to(indices.device)
        return table[indices]

    def __repr__(self):
        summary = f"LloydMaxCodebook(d={self.d}, bits={self.bits}, "
        summary += f"levels={self.n_levels}, distortion_per_coord={self.distortion:.6f})"
        return summary
@@ -0,0 +1,249 @@
1
+ """TurboQuant: two-stage vector quantization."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from typing import Optional, Tuple, cast
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ from .lloyd_max import LloydMaxCodebook
12
+
13
+
14
def generate_rotation_matrix(
    d: int, seed: Optional[int] = None, device: str = "cpu"
) -> torch.Tensor:
    """Return a random ``d x d`` orthogonal matrix built from a QR factorisation.

    A Gaussian matrix is drawn on the CPU (seeded when ``seed`` is given),
    factorised, and column j of Q is multiplied by the sign of ``R[j, j]``
    (a zero diagonal entry counts as +1).  The result is moved to ``device``.
    """
    rng = torch.Generator(device="cpu")
    if seed is not None:
        rng.manual_seed(seed)

    gaussian = torch.randn(d, d, generator=rng)
    q_factor, r_factor = torch.linalg.qr(gaussian)

    column_signs = torch.sign(torch.diag(r_factor))
    column_signs = torch.where(
        column_signs == 0, torch.ones_like(column_signs), column_signs
    )
    return (q_factor * column_signs.unsqueeze(0)).to(device)
27
+
28
+
29
def generate_qjl_matrix(
    d: int, m: Optional[int] = None, seed: Optional[int] = None, device: str = "cpu"
) -> torch.Tensor:
    """Return the random projection matrix S used by the QJL stage.

    S has shape ``(m, d)`` with i.i.d. N(0, 1) entries drawn on the CPU
    (seeded when ``seed`` is given) and is then moved to ``device``.
    ``m`` defaults to ``d``.
    """
    rows = d if m is None else m
    rng = torch.Generator(device="cpu")
    if seed is not None:
        rng.manual_seed(seed)
    return torch.randn(rows, d, generator=rng).to(device)
44
+
45
+
46
class TurboQuantMSE(nn.Module):
    """Stage 1: MSE-optimal quantizer.

    Rotates the input with a fixed random orthogonal matrix and snaps every
    rotated coordinate to the nearest Lloyd-Max centroid.  The rotation, the
    centroids and the cell boundaries are registered as buffers.
    """

    def __init__(self, d: int, bits: int, seed: int = 42, device: str = "cpu"):
        """Create the rotation for ``seed`` and the codebook for ``(d, bits)``."""
        super().__init__()
        self.d = d
        self.bits = bits
        self.device = device

        rotation = generate_rotation_matrix(d, seed=seed, device=device)
        self.register_buffer("Pi", rotation)

        self.codebook = LloydMaxCodebook(d, bits)
        self.register_buffer("centroids", self.codebook.centroids.to(device))
        self.register_buffer("boundaries", self.codebook.boundaries.to(device))

    def rotate(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the rotation: ``x @ Pi^T``."""
        rotation = cast("torch.Tensor", self.Pi)
        return x @ rotation.T

    def unrotate(self, y: torch.Tensor) -> torch.Tensor:
        """Undo the rotation: ``y @ Pi``."""
        rotation = cast("torch.Tensor", self.Pi)
        return y @ rotation

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Return the nearest-centroid index of every rotated coordinate."""
        table = cast("torch.Tensor", self.centroids)
        distances = torch.abs(self.rotate(x).unsqueeze(-1) - table)
        return torch.argmin(distances, dim=-1)

    def dequantize(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up the centroids for ``indices`` and rotate them back."""
        table = cast("torch.Tensor", self.centroids)
        return self.unrotate(table[indices])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantise ``x`` and return ``(reconstruction, indices)``."""
        codes = self.quantize(x)
        return self.dequantize(codes), codes
86
+
87
+
88
class TurboQuantProd(nn.Module):
    """Stage 1 + Stage 2: Unbiased inner product quantizer.

    Stage 1 is a ``TurboQuantMSE`` quantizer running on ``max(bits - 1, 1)``
    bits.  Stage 2 stores the signs of a Gaussian projection of the stage-1
    residual together with the residual norm; ``inner_product`` uses them to
    add a correction to the stage-1 dot product.
    """

    def __init__(
        self,
        d: int,
        bits: int,
        qjl_dim: Optional[int] = None,
        seed: int = 42,
        device: str = "cpu",
    ):
        """Create the stage-1 quantizer and the projection matrix.

        Args:
            d: vector dimension.
            bits: total bit budget; stage 1 gets ``max(bits - 1, 1)``.
            qjl_dim: number of projection rows; a falsy value means ``d``.
            seed: seed of the rotation; ``seed + 1`` seeds the projection.
            device: device the buffers are created on.
        """
        super().__init__()
        self.d = d
        self.bits = bits
        self.mse_bits = max(bits - 1, 1)
        self.qjl_dim = qjl_dim or d
        self.device = device

        self.mse = TurboQuantMSE(d, self.mse_bits, seed=seed, device=device)
        projection = generate_qjl_matrix(
            d, m=self.qjl_dim, seed=seed + 1, device=device
        )
        self.register_buffer("S", projection)

    def quantize(self, x: torch.Tensor) -> dict:
        """Compress ``x`` into stage-1 indices, residual signs and residual norm."""
        x_hat, mse_indices = self.mse(x)
        residual = x - x_hat

        projection = cast("torch.Tensor", self.S)
        raw_signs = torch.sign(residual @ projection.T)
        # An exactly-zero projection is stored as +1.
        qjl_signs = torch.where(raw_signs == 0, torch.ones_like(raw_signs), raw_signs)

        return {
            "mse_indices": mse_indices,
            "qjl_signs": qjl_signs,
            "residual_norm": torch.norm(residual, dim=-1),
        }

    def dequantize(self, compressed: dict) -> torch.Tensor:
        """Return the stage-1 reconstruction; the sign sketch is not used here."""
        return self.mse.dequantize(compressed["mse_indices"])

    def inner_product(self, y: torch.Tensor, compressed: dict) -> torch.Tensor:
        """Estimate the inner product of ``y`` with the compressed vectors.

        The stage-1 dot product is corrected by
        ``residual_norm * sqrt(pi / 2) / qjl_dim * <S y, signs>``.
        """
        reconstruction = self.mse.dequantize(compressed["mse_indices"])
        coarse = torch.sum(y * reconstruction, dim=-1)

        projection = cast("torch.Tensor", self.S)
        sketch_dot = torch.sum((y @ projection.T) * compressed["qjl_signs"], dim=-1)

        scale = math.sqrt(math.pi / 2) / self.qjl_dim
        return coarse + compressed["residual_norm"] * scale * sketch_dot

    def forward(self, x: torch.Tensor) -> dict:
        """Alias of ``quantize`` so the module can be called directly."""
        return self.quantize(x)
145
+
146
+
147
class TurboQuantKVCache:
    """KV cache wrapper that uses TurboQuant to compress keys and values.

    Keys go through a ``TurboQuantProd`` quantizer, values through a
    ``TurboQuantMSE`` quantizer seeded with ``seed + 100``.  Every ``append``
    call adds one entry to ``key_cache`` and one to ``value_cache``.
    """

    def __init__(
        self,
        d_key: int,
        d_value: int,
        bits: int = 3,
        seed: int = 42,
        device: str = "cpu",
    ):
        """Create both quantizers and start with empty caches."""
        self.d_key = d_key
        self.d_value = d_value
        self.bits = bits
        self.device = device

        self.key_quantizer = TurboQuantProd(d_key, bits, seed=seed, device=device)
        self.value_quantizer = TurboQuantMSE(
            d_value, bits, seed=seed + 100, device=device
        )

        self.key_cache = []
        self.value_cache = []

    def append(self, keys: torch.Tensor, values: torch.Tensor):
        """Compress one batch of keys and values and store it.

        Both tensors are flattened to 2-D over their last dimension before
        quantisation; the original shapes are recorded with each entry.
        """
        packed_keys = self.key_quantizer.quantize(keys.reshape(-1, self.d_key))
        packed_values = self.value_quantizer.quantize(
            values.reshape(-1, self.d_value)
        )

        key_entry = {
            field: packed_keys[field]
            for field in ("mse_indices", "qjl_signs", "residual_norm")
        }
        key_entry["shape"] = keys.shape
        self.key_cache.append(key_entry)
        self.value_cache.append({"indices": packed_values, "shape": values.shape})

    def attention_scores(self, queries: torch.Tensor) -> torch.Tensor:
        """Score ``queries`` against every cached key entry.

        Per-entry estimates are concatenated along the last dimension; an
        empty cache yields an empty tensor.
        """
        per_entry = [
            self.key_quantizer.inner_product(queries, entry)
            for entry in self.key_cache
        ]
        if not per_entry:
            return torch.tensor([])
        return torch.cat(per_entry, dim=-1)

    def get_values(self) -> torch.Tensor:
        """Dequantise all cached values and concatenate them along dim 0."""
        restored = [
            self.value_quantizer.dequantize(entry["indices"])
            for entry in self.value_cache
        ]
        if not restored:
            return torch.tensor([])
        return torch.cat(restored, dim=0)

    def memory_usage_bits(self) -> dict:
        """Estimate the storage of the cache in bits.

        Keys are charged ``mse_bits`` per index, 1 bit per sign and 16 bits
        per residual norm; values ``bits`` per index.  The fp16 baseline is 16
        bits per key and value index.
        """
        n_keys = sum(entry["mse_indices"].numel() for entry in self.key_cache)
        n_qjl = sum(entry["qjl_signs"].numel() for entry in self.key_cache)
        n_norms = sum(entry["residual_norm"].numel() for entry in self.key_cache)
        n_values = sum(entry["indices"].numel() for entry in self.value_cache)

        key_bits = n_keys * self.key_quantizer.mse_bits + n_qjl * 1 + n_norms * 16
        value_bits = n_values * self.bits
        total_bits = key_bits + value_bits
        fp16_equivalent = (n_keys + n_values) * 16

        # An empty cache reports a ratio of 0 instead of dividing by zero.
        ratio = fp16_equivalent / total_bits if total_bits > 0 else 0

        return {
            "key_bits": key_bits,
            "value_bits": value_bits,
            "total_bits": total_bits,
            "fp16_bits": fp16_equivalent,
            "compression_ratio": ratio,
        }

    def __len__(self):
        """Return the number of cached key vectors (rows of all index tensors)."""
        return sum(entry["mse_indices"].shape[0] for entry in self.key_cache)
plesk_unified/types.py ADDED
@@ -0,0 +1,27 @@
1
+ from enum import Enum
2
+ from typing import Literal, Union
3
+
4
+
5
class CategoryEnum(str, Enum):
    """Supported Plesk documentation categories.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values.
    """

    GUIDE = "guide"
    CLI = "cli"
    API = "api"
    PHP_STUBS = "php-stubs"
    JS_SDK = "js-sdk"
13
+
14
+
15
# Plain-string values of every CategoryEnum member, for fast membership tests.
VALID_CATEGORIES: frozenset[str] = frozenset(
    member.value for member in CategoryEnum
)
16
+
17
+
18
def validate_category(category: str, allow_all: bool = False) -> None:
    """Check that ``category`` names a supported documentation category.

    Args:
        category: category string to check.
        allow_all: when true, the wildcard ``"all"`` is accepted as well.

    Raises:
        ValueError: if ``category`` is neither a known category value nor an
            allowed ``"all"``.
    """
    is_wildcard = allow_all and category == "all"
    if is_wildcard or category in VALID_CATEGORIES:
        return
    raise ValueError(f"Invalid category: '{category}'")
24
+
25
+
26
# Type alias for refresh_knowledge, which takes either one specific category
# or the literal string "all".
CategoryOrAll = Union[CategoryEnum, Literal["all"]]