kvcache-simulator 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kvcache_sim/__init__.py +14 -0
- kvcache_sim/__main__.py +5 -0
- kvcache_sim/_resources.py +27 -0
- kvcache_sim/calculator.py +442 -0
- kvcache_sim/cli.py +115 -0
- kvcache_sim/cpp_backend.py +203 -0
- kvcache_sim/formatting.py +63 -0
- kvcache_sim/plan.py +78 -0
- kvcache_sim/policies.py +286 -0
- kvcache_sim/progress.py +36 -0
- kvcache_sim/resources/__init__.py +1 -0
- kvcache_sim/resources/kv-cache-lab-native-sim.cc +710 -0
- kvcache_sim/resources/models.yaml +934 -0
- kvcache_sim/simulator.py +226 -0
- kvcache_sim/trace.py +226 -0
- kvcache_simulator-0.1.0.dist-info/METADATA +129 -0
- kvcache_simulator-0.1.0.dist-info/RECORD +21 -0
- kvcache_simulator-0.1.0.dist-info/WHEEL +5 -0
- kvcache_simulator-0.1.0.dist-info/entry_points.txt +2 -0
- kvcache_simulator-0.1.0.dist-info/licenses/LICENSE.md +21 -0
- kvcache_simulator-0.1.0.dist-info/top_level.txt +1 -0
kvcache_sim/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Local KV cache hit-rate simulator aligned with KVCache.AI web tools."""
|
|
2
|
+
|
|
3
|
+
from .calculator import BYTES_PER_GIB, calculate_cache_size, load_models_data
|
|
4
|
+
from .simulator import run_sweep
|
|
5
|
+
|
|
6
|
+
__version__ = "0.1.0"
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"BYTES_PER_GIB",
|
|
10
|
+
"__version__",
|
|
11
|
+
"calculate_cache_size",
|
|
12
|
+
"load_models_data",
|
|
13
|
+
"run_sweep",
|
|
14
|
+
]
|
kvcache_sim/_resources.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from importlib import resources
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import hashlib
|
|
6
|
+
import re
|
|
7
|
+
import tempfile
|
|
8
|
+
|
|
9
|
+
RESOURCE_PACKAGE = "kvcache_sim.resources"


def package_resource_path(name: str) -> Path:
    """Return a real filesystem path for the bundled resource ``name``.

    When the package is installed as regular files, the resource's own path is
    returned. Otherwise (for example a zip import, where the resource object
    is not path-like) the bytes are extracted to a content-addressed file in
    the system temp directory, so callers that need an on-disk file can use it.

    The extracted bytes are first written to a private ``mkstemp`` file and
    then moved into place with ``os.replace``. This prevents concurrent
    callers from seeing a half-written file. It also defeats a symlink planted
    at the predictable shared name: the rename replaces the link instead of
    writing through it. If the shared name cannot be claimed, the private copy
    is returned instead.

    Args:
        name: File name inside ``RESOURCE_PACKAGE``.

    Returns:
        Path to a file whose contents equal the resource's bytes.

    Raises:
        Errors from reading the resource or writing the extracted copy
        propagate unchanged.
    """
    import os

    resource = resources.files(RESOURCE_PACKAGE).joinpath(name)
    try:
        path = Path(resource)
    except TypeError:
        # Not path-like (e.g. a zipfile.Path); fall through to extraction.
        path = None
    if path is not None and path.exists():
        return path

    payload = resource.read_bytes()
    digest = hashlib.sha256(payload).hexdigest()[:16]
    safe_name = re.sub(r"[^A-Za-z0-9_.-]+", "-", name)
    target = Path(tempfile.gettempdir()) / f"kvcache-simulator-{digest}-{safe_name}"

    # Reuse a previous extraction only if it is a regular file (never a symlink)
    # with identical bytes; an unreadable file (e.g. another user's) is a miss.
    if not target.is_symlink() and target.is_file():
        try:
            if target.read_bytes() == payload:
                return target
        except OSError:
            pass

    # Keep the resource name as the suffix so its extension (e.g. ".cc") survives.
    fd, tmp_name = tempfile.mkstemp(
        prefix=f"kvcache-simulator-{digest}-",
        suffix=f"-{safe_name}",
        dir=target.parent,
    )
    tmp_path = Path(tmp_name)
    try:
        with os.fdopen(fd, "wb") as handle:
            handle.write(payload)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise

    try:
        os.replace(tmp_path, target)
    except OSError:
        # The shared name could not be claimed (for example it belongs to another
        # user in a sticky temp directory); the private copy is equally valid.
        return tmp_path
    return target
|
|
kvcache_sim/calculator.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import math
|
|
6
|
+
import re
|
|
7
|
+
import shlex
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from ._resources import package_resource_path
|
|
11
|
+
|
|
12
|
+
# Decimal gigabyte (10^9 bytes) and binary gibibyte (2^30 bytes).
BYTES_PER_GB = 10 ** 9
BYTES_PER_GIB = 1 << 30
# Per-element storage sizes applied to the Qwen linear-attention state:
# convolution-state elements take 2 bytes, recurrent-state elements 4 bytes.
QWEN_LINEAR_CONV_BYTES_PER_ELEMENT = 2
QWEN_LINEAR_RECURRENT_BYTES_PER_ELEMENT = 4

# Fallback precision catalog, used when the models data lists no precision options.
DEFAULT_PRECISIONS = dict(
    bf16_fp16={"label": "BF16 / FP16", "bytes_per_element": 2.0},
    fp8_int8={"label": "FP8 / INT8", "bytes_per_element": 1.0},
    fp4_int4={"label": "FP4 / INT4", "bytes_per_element": 0.5},
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(frozen=True)
class CacheSizeResult:
    """Immutable cache-size summary produced by ``calculate_cache_size``.

    Describes one model at one KV precision (plus indexer precision, when the
    model has an indexer cache) for a given token count.
    """

    # Model id and display label (the label falls back to the id).
    model_id: str
    model_label: str
    # Selected KV precision id and its display label.
    precision: str
    precision_label: str
    # Indexer precision id and label; None when the model has no indexer cache.
    indexer_precision: str | None
    indexer_precision_label: str | None
    # total_bytes divided by the token count, including any fixed per-sequence
    # state spread over those tokens.
    bytes_per_token: float
    # bytes_per_token * block_size; None when block_size was None or 0.
    bytes_per_block: float | None
    # Bytes of every group whose role is not "indexer" (KV cache and, when
    # requested, linear-attention state).
    kv_bytes: float
    # Bytes of groups whose role is "indexer".
    indexer_bytes: float
    # Sum of all group bytes, and the same value expressed in GiB (2^30 bytes).
    total_bytes: float
    total_gib: float
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def default_models_path() -> Path:
    """Return a filesystem path to the model catalog bundled with the package."""
    catalog_name = "models.yaml"
    return package_resource_path(catalog_name)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _parse_scalar(value: str) -> Any:
|
|
45
|
+
value = value.strip()
|
|
46
|
+
if value == "":
|
|
47
|
+
return {}
|
|
48
|
+
if value in {"true", "True"}:
|
|
49
|
+
return True
|
|
50
|
+
if value in {"false", "False"}:
|
|
51
|
+
return False
|
|
52
|
+
if value.startswith("[") and value.endswith("]"):
|
|
53
|
+
inner = value[1:-1].strip()
|
|
54
|
+
if not inner:
|
|
55
|
+
return []
|
|
56
|
+
return [_parse_scalar(part) for part in shlex.split(inner.replace(",", " "))]
|
|
57
|
+
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
|
58
|
+
return value[1:-1]
|
|
59
|
+
if re.fullmatch(r"-?\d+", value):
|
|
60
|
+
return int(value)
|
|
61
|
+
if re.fullmatch(r"-?(\d+\.\d*|\d*\.\d+)([eE][+-]?\d+)?", value) or re.fullmatch(r"-?\d+[eE][+-]?\d+", value):
|
|
62
|
+
return float(value)
|
|
63
|
+
return value
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _split_key_value(line: str) -> tuple[str, Any]:
|
|
67
|
+
key, value = line.split(":", 1)
|
|
68
|
+
return key.strip(), _parse_scalar(value)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _load_models_data_minimal_yaml(path: Path) -> dict[str, Any]:
    """Parse the small YAML subset used by the bundled models.yaml catalog.

    This fallback keeps the CLI usable without PyYAML. It is intentionally
    narrow: top-level maps, lists of maps, and a nested scalar `fields` map.

    Expected layout (indentation counted in spaces):

    * column 0: top-level keys. ``precision_options``,
      ``indexer_precision_options`` and ``models`` become lists when their
      value is empty; any other key with an empty value becomes a map.
    * ``metadata``: scalar keys at indent 2, plus a nested
      ``serving_references`` map whose entries sit at indent 4.
    * list sections: ``- key: value`` at indent 2 starts an item, further item
      keys sit at indent 4, and ``fields`` entries sit at indent 6.

    Lines that do not fit this layout (for example nested content under other
    top-level keys, or deeper indents) are skipped silently.

    Args:
        path: YAML file to read (decoded as UTF-8).

    Returns:
        The parsed catalog as nested dicts/lists of scalars.

    Raises:
        ValueError: From ``_split_key_value`` for a line it cannot split.
    """

    lines = path.read_text(encoding="utf-8").splitlines()
    data: dict[str, Any] = {}
    current_section: str | None = None
    current_item: dict[str, Any] | None = None
    # True while reading the indent-6 entries of the current item's `fields:` map.
    in_fields = False
    # True while reading the indent-4 entries of `metadata.serving_references`.
    in_serving_references = False

    for raw in lines:
        # Skip blank lines and full-line comments.
        if not raw.strip() or raw.lstrip().startswith("#"):
            continue
        # Only spaces count as indentation; tabs are not stripped here.
        indent = len(raw) - len(raw.lstrip(" "))
        text = raw.strip()

        if indent == 0:
            # New top-level section; reset all nested-parsing state.
            key, value = _split_key_value(text)
            current_section = key
            current_item = None
            in_fields = False
            in_serving_references = False
            # `_parse_scalar` returns {} for an empty value, meaning a block follows.
            if value == {}:
                data[key] = [] if key in {"precision_options", "indexer_precision_options", "models"} else {}
            else:
                data[key] = value
            continue

        if current_section is None:
            continue

        if current_section == "metadata":
            # NOTE(review): assumes `metadata:` opened a block (a dict); an inline
            # scalar value would make the item assignment below fail.
            if indent == 2:
                key, value = _split_key_value(text)
                data["metadata"][key] = value
                in_serving_references = key == "serving_references" and value == {}
            elif indent == 4 and in_serving_references:
                key, value = _split_key_value(text)
                data["metadata"].setdefault("serving_references", {})[key] = value
            continue

        if current_section in {"precision_options", "indexer_precision_options", "models"}:
            # `- key: value` at indent 2 starts a new list item.
            if indent == 2 and text.startswith("- "):
                current_item = {}
                data[current_section].append(current_item)
                key, value = _split_key_value(text[2:])
                current_item[key] = value
                in_fields = False
                continue
            if current_item is None:
                continue
            # Further keys of the current item; an empty `fields:` opens the nested map.
            if indent == 4:
                key, value = _split_key_value(text)
                current_item[key] = value
                in_fields = key == "fields" and value == {}
                if in_fields:
                    current_item["fields"] = {}
                continue
            # Scalar entries of the item's `fields` map.
            if indent == 6 and in_fields:
                key, value = _split_key_value(text)
                current_item["fields"][key] = value

    return data
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def load_models_data(path: str | Path | None = None) -> dict[str, Any]:
    """Load the model catalog YAML.

    PyYAML's ``safe_load`` is used when PyYAML can be imported; otherwise the
    built-in minimal parser is used. Only the import is guarded, so errors
    raised while reading or parsing the file are not mistaken for a missing
    PyYAML.

    Args:
        path: Catalog path; a falsy value selects the bundled ``models.yaml``.

    Returns:
        The parsed catalog mapping (``{}`` for an empty document).

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    model_path = Path(path) if path else default_models_path()
    try:
        import yaml  # type: ignore
    except ImportError:
        # PyYAML is optional (missing or unusable); use the subset parser.
        return _load_models_data_minimal_yaml(model_path)

    with model_path.open("r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(f"Expected a mapping at the top level of {model_path}, got {type(loaded).__name__}")
    return loaded
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def precision_options(data: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Map KV precision ids to ``{"label", "bytes_per_element"}`` profiles.

    Reads ``data["precision_options"]``. Each item's size may be given as
    ``bytes_per_element`` or ``bytesPerElement``; a missing size becomes
    ``0.0`` and a missing label falls back to the id.

    When the catalog lists no options (including an explicit ``null``), a
    fresh copy of ``DEFAULT_PRECISIONS`` is returned. Callers therefore cannot
    mutate the module-level defaults through the result.
    """
    profiles = {
        item["id"]: {
            "label": item.get("label", item["id"]),
            "bytes_per_element": float(item.get("bytes_per_element", item.get("bytesPerElement", 0))),
        }
        for item in data.get("precision_options") or []
    }
    if profiles:
        return profiles
    return {precision_id: dict(profile) for precision_id, profile in DEFAULT_PRECISIONS.items()}
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def indexer_precision_options(data: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Map indexer precision ids to ``{"label", "bytes_per_element"}`` profiles.

    Uses ``indexer_precision_options`` when present and non-empty, otherwise
    ``precision_options``. If neither yields any profile, the result of
    ``precision_options(data)`` is returned.
    """
    source = data.get("indexer_precision_options") or data.get("precision_options") or []
    profiles: dict[str, dict[str, Any]] = {}
    for entry in source:
        identifier = entry["id"]
        profiles[identifier] = {
            "label": entry.get("label", identifier),
            "bytes_per_element": float(entry.get("bytes_per_element", entry.get("bytesPerElement", 0))),
        }
    if profiles:
        return profiles
    return precision_options(data)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def models_by_id(data: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Index the catalog's ``models`` list by each model's ``id``.

    A missing or ``null`` ``models`` entry (PyYAML yields ``None`` for an
    empty ``models:`` key) gives an empty mapping. A later duplicate id
    replaces an earlier one.
    """
    return {model["id"]: model for model in data.get("models") or []}
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _safe_number(value: Any, fallback: float = 0) -> float:
|
|
177
|
+
try:
|
|
178
|
+
parsed = float(value)
|
|
179
|
+
except (TypeError, ValueError):
|
|
180
|
+
return fallback
|
|
181
|
+
return parsed if math.isfinite(parsed) else fallback
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _positive_int(value: Any, fallback: int) -> int:
    """Floor ``value`` to an int and return it if positive.

    Otherwise returns ``fallback``, clamped to at least 1. Unparsable or
    non-finite input is treated as ``fallback`` before flooring.
    """
    candidate = math.floor(_safe_number(value, fallback))
    if candidate > 0:
        return candidate
    return max(1, fallback)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _field(model: dict[str, Any], name: str) -> float:
    """Return the required numeric ``model["fields"][name]`` as a finite float.

    Raises:
        ValueError: If the field is absent or not a finite number.
    """
    fields = model.get("fields") or {}
    value = _safe_number(fields[name], math.nan) if name in fields else math.nan
    if not math.isfinite(value):
        raise ValueError(f"Model {model.get('id', '')} is missing numeric field {name}")
    return value
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _optional_field(model: dict[str, Any], name: str, fallback: float) -> float:
    """Return ``model["fields"][name]`` as a finite float.

    Returns ``fallback`` when the field is absent, non-numeric, or non-finite.
    """
    fields = model.get("fields") or {}
    if name not in fields:
        return fallback
    return _safe_number(fields.get(name), fallback)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _is_deepseek_v4(model: dict[str, Any]) -> bool:
|
|
205
|
+
return model.get("formula") == "deepseek_v4_hybrid"
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _has_indexer_cache(model: dict[str, Any]) -> bool:
    """Tell whether ``model`` declares a finite numeric ``index_head_dim`` field."""
    index_head_dim = (model.get("fields") or {}).get("index_head_dim")
    return math.isfinite(_safe_number(index_head_dim, math.nan))
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _draft_layer_count(model: dict[str, Any]) -> int:
    """Return how many draft (next-n / MTP) layers add KV cache for ``model``.

    The sources are checked in this order:

    1. ``disable_draft_kv_cache: true`` forces 0.
    2. Otherwise a positive ``num_nextn_predict_layers`` wins.
    3. Otherwise, when ``use_mtp`` is true, the count is
       ``num_mtp_modules * mtp_transformer_layers``.
    4. Otherwise the count is 0.
    """
    fields = model.get("fields") or {}
    if fields.get("disable_draft_kv_cache") is True:
        return 0

    nextn_layers = int(_safe_number(fields.get("num_nextn_predict_layers"), 0))
    if nextn_layers > 0:
        return nextn_layers

    if fields.get("use_mtp") is not True:
        return 0
    modules = _safe_number(fields.get("num_mtp_modules"), 0)
    layers_per_module = _safe_number(fields.get("mtp_transformer_layers"), 0)
    return int(modules * layers_per_module)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def has_draft_kv_cache(model: dict[str, Any]) -> bool:
    """Tell whether ``model`` has draft-layer KV cache that can be included.

    For DeepSeek V4 hybrid models, draft layers show up as extra
    ``compress_ratios`` entries beyond ``num_hidden_layers``. For other
    formulas, the draft layer count must be positive.
    """
    if not _is_deepseek_v4(model):
        return _draft_layer_count(model) > 0
    fields = model.get("fields") or {}
    main_layers = int(_safe_number(fields.get("num_hidden_layers"), 0))
    ratios = fields.get("compress_ratios")
    return isinstance(ratios, list) and len(ratios) > main_layers
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _fixed_indexer_precision_id(model: dict[str, Any]) -> str | None:
|
|
235
|
+
value = (model.get("fields") or {}).get("indexer_fixed_precision_id")
|
|
236
|
+
return value if isinstance(value, str) else None
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def default_precision_id(model: dict[str, Any], options: dict[str, dict[str, Any]]) -> str:
    """Pick the default KV precision id for ``model`` from ``options``.

    Preference order:

    1. ``fp8_int8`` for DeepSeek V4 hybrid models;
    2. ``bf16_fp16``;
    3. the first option in iteration order.

    Raises:
        ValueError: If ``options`` is empty. The original implementation
            leaked a bare ``StopIteration`` from ``next()`` in this case.
    """
    if not options:
        raise ValueError("No KV precision options are available")
    if _is_deepseek_v4(model) and "fp8_int8" in options:
        return "fp8_int8"
    if "bf16_fp16" in options:
        return "bf16_fp16"
    return next(iter(options))
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def default_indexer_precision_id(model: dict[str, Any], options: dict[str, dict[str, Any]], fallback_precision: str | None) -> str:
|
|
248
|
+
fixed = _fixed_indexer_precision_id(model)
|
|
249
|
+
if fixed and fixed in options:
|
|
250
|
+
return fixed
|
|
251
|
+
if _is_deepseek_v4(model) and "fp4_int4" in options:
|
|
252
|
+
return "fp4_int4"
|
|
253
|
+
if fallback_precision and fallback_precision in options:
|
|
254
|
+
return fallback_precision
|
|
255
|
+
if "bf16_fp16" in options:
|
|
256
|
+
return "bf16_fp16"
|
|
257
|
+
return "fp4_int4" if "fp4_int4" in options else next(iter(options))
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _indexer_layer_plan(model: dict[str, Any], layers: int, draft_layers: int) -> tuple[int, int, int, int]:
    """Return ``(main, shared, draft, active)`` indexer layer counts.

    * ``main``: ``indexer_full_layers``, defaulting to ``layers``.
    * ``shared``: ``indexer_shared_layers``, defaulting to the remaining
      layers (never negative).
    * ``draft``: ``draft_indexer_layers``, defaulting to ``draft_layers``;
      it is 0 whenever ``draft_layers`` is not positive.
    * ``active``: ``main + draft``; shared layers are not counted.
    """
    main = int(_optional_field(model, "indexer_full_layers", layers))
    shared = int(_optional_field(model, "indexer_shared_layers", max(0, layers - main)))
    if draft_layers > 0:
        draft = int(_optional_field(model, "draft_indexer_layers", draft_layers))
    else:
        draft = 0
    active = main + draft
    return main, shared, draft, active
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _calculate_byte_groups(model: dict[str, Any], tokens: int, include_draft_kv_cache: bool, include_linear_attention_state: bool) -> tuple[float, list[dict[str, Any]]]:
    """Compute element counts per cache group for ``model`` at ``tokens`` tokens.

    Dispatches on ``model["formula"]``.

    Args:
        model: Catalog model entry with ``formula`` and a ``fields`` map.
        tokens: Token count the cache is sized for. Callers pass a positive
            int, and several branches divide by it.
        include_draft_kv_cache: Count draft (next-n / MTP) layers, in the
            formulas that use them.
        include_linear_attention_state: For ``qwen_linear_full_hybrid``, add
            the fixed-size linear-attention state group.

    Returns:
        ``(elements_per_token, groups)``. Each group dict has:

        * ``role``: ``"kv"``, ``"indexer"`` or ``"linear_state"``;
        * ``label``;
        * either ``elements`` (total element count, later multiplied by a
          precision's bytes per element) or ``bytes_per_sequence`` (a fixed
          byte size).

    Raises:
        ValueError: If a required field is missing, a DeepSeek V4 model has no
            active ``compress_ratios``, or the formula is unsupported.
    """
    formula = model.get("formula")
    fields = model.get("fields") or {}
    draft_layers = _draft_layer_count(model) if include_draft_kv_cache else 0

    if formula == "standard_gqa":
        # K and V per layer: 2 * kv_heads * head_dim elements per token; draft layers are added to the layer count.
        layers = int(_field(model, "num_hidden_layers")) + draft_layers
        elements_per_token = layers * 2 * _field(model, "num_key_value_heads") * _field(model, "head_dim")
        return elements_per_token, [{"role": "kv", "label": "KV cache", "elements": elements_per_token * tokens}]

    if formula == "mla":
        # Per layer: kv_lora_rank + qk_rope_head_dim elements per token (no separate V term).
        layers = int(_field(model, "num_hidden_layers")) + draft_layers
        elements_per_token = layers * (_field(model, "kv_lora_rank") + _field(model, "qk_rope_head_dim"))
        return elements_per_token, [{"role": "kv", "label": "KV cache", "elements": elements_per_token * tokens}]

    if formula == "dsa_mla":
        # MLA-style KV plus an indexer cache of index_head_dim elements per token on each active indexer layer.
        layers = int(_field(model, "num_hidden_layers"))
        active_layers = layers + draft_layers
        _, _, _, active_indexer_layers = _indexer_layer_plan(model, layers, draft_layers)
        kv_elements_per_token = active_layers * (_field(model, "kv_lora_rank") + _field(model, "qk_rope_head_dim"))
        indexer_elements_per_token = active_indexer_layers * _field(model, "index_head_dim")
        return kv_elements_per_token + indexer_elements_per_token, [
            {"role": "kv", "label": "KV cache", "elements": kv_elements_per_token * tokens},
            {"role": "indexer", "label": "Indexer cache", "elements": indexer_elements_per_token * tokens},
        ]

    if formula == "qwen_linear_full_hybrid":
        # Only full-attention layers hold per-token KV; draft layers are not counted in this branch.
        full_layers = _field(model, "full_attention_layers")
        kv_heads = _field(model, "num_key_value_heads")
        head_dim = _field(model, "head_dim")
        elements_per_token = full_layers * 2 * kv_heads * head_dim
        groups = [{"role": "kv", "label": "Full-attention KV cache", "elements": elements_per_token * tokens}]
        if include_linear_attention_state:
            # Fixed per-sequence state (independent of tokens), sized in bytes with its own per-element widths.
            linear_layers = _field(model, "linear_attention_layers")
            conv_kernel = _field(model, "linear_conv_kernel_dim")
            key_heads = _field(model, "linear_num_key_heads")
            key_dim = _field(model, "linear_key_head_dim")
            value_heads = _field(model, "linear_num_value_heads")
            value_dim = _field(model, "linear_value_head_dim")
            conv_elements = linear_layers * conv_kernel * (2 * key_heads * key_dim + value_heads * value_dim)
            recurrent_elements = linear_layers * value_heads * key_dim * value_dim
            groups.append({
                "role": "linear_state",
                "label": "Linear-attention state",
                "bytes_per_sequence": conv_elements * QWEN_LINEAR_CONV_BYTES_PER_ELEMENT + recurrent_elements * QWEN_LINEAR_RECURRENT_BYTES_PER_ELEMENT,
            })
        # The returned per-token figure excludes the linear-attention state.
        return elements_per_token, groups

    if formula == "mixed_full_sliding_gqa":
        # Global (full) layers keep every token; sliding layers keep at most sliding_window tokens.
        # Optional global_*/swa_*/sliding_* fields override the shared head counts and dims.
        full_layers = _field(model, "full_attention_layers")
        sliding_layers = _field(model, "sliding_attention_layers")
        kv_heads = _field(model, "num_key_value_heads")
        head_dim = _field(model, "head_dim")
        full_kv_heads = _optional_field(model, "num_global_key_value_heads", kv_heads)
        full_head_dim = _optional_field(model, "global_head_dim", head_dim)
        full_v_dim = _optional_field(model, "global_v_head_dim", _optional_field(model, "v_head_dim", full_head_dim))
        sliding_kv_heads = _optional_field(model, "swa_num_key_value_heads", _optional_field(model, "sliding_num_key_value_heads", kv_heads))
        sliding_head_dim = _optional_field(model, "swa_head_dim", _optional_field(model, "sliding_head_dim", head_dim))
        sliding_v_dim = _optional_field(model, "swa_v_head_dim", _optional_field(model, "sliding_v_head_dim", _optional_field(model, "v_head_dim", sliding_head_dim)))
        retained_sliding_tokens = min(tokens, int(_field(model, "sliding_window")))
        full_elements = tokens * full_layers * full_kv_heads * (full_head_dim + full_v_dim)
        sliding_elements = retained_sliding_tokens * sliding_layers * sliding_kv_heads * (sliding_head_dim + sliding_v_dim)
        # Per-token figure is the average over `tokens`, since the sliding part is capped.
        return (full_elements + sliding_elements) / tokens, [
            {"role": "kv", "label": "Full-attention KV cache", "elements": full_elements},
            {"role": "kv", "label": "Sliding-window KV cache", "elements": sliding_elements},
        ]

    if formula == "minimax_msa":
        # Standard K/V on every layer plus index_head_dim indexer elements per token on each sparse-attention layer.
        layers = _field(model, "num_hidden_layers")
        sparse_layers = _field(model, "sparse_attention_layers")
        kv_elements_per_token = layers * 2 * _field(model, "num_key_value_heads") * _field(model, "head_dim")
        indexer_elements_per_token = sparse_layers * _field(model, "index_head_dim")
        return kv_elements_per_token + indexer_elements_per_token, [
            {"role": "kv", "label": "KV cache", "elements": kv_elements_per_token * tokens},
            {"role": "indexer", "label": "Indexer cache", "elements": indexer_elements_per_token * tokens},
        ]

    if formula == "deepseek_v4_hybrid":
        head_dim = _field(model, "head_dim")
        index_dim = _field(model, "index_head_dim")
        sliding_window = _field(model, "sliding_window")
        layers = int(_field(model, "num_hidden_layers"))
        # One compress ratio per layer; entries past num_hidden_layers are draft layers, used only when requested.
        ratios = [float(ratio) for ratio in fields.get("compress_ratios", [])]
        active_ratios = ratios[:layers] + (ratios[layers:] if include_draft_kv_cache else [])
        if not active_ratios:
            raise ValueError(f"Model {model.get('id', '')} is missing compress_ratios")
        window_elements = 0.0
        compressed_elements = 0.0
        indexer_elements = 0.0
        for ratio in active_ratios:
            # Every active layer holds sliding_window * head_dim window elements.
            # NOTE(review): unlike mixed_full_sliding_gqa, the window is not capped at `tokens` — confirm intended.
            window_elements += sliding_window * head_dim
            if ratio > 0:
                # Compressed entries: one per `ratio` tokens (floored).
                compressed_elements += math.floor(tokens / ratio) * head_dim
                if ratio == 4:
                    # Only ratio-4 layers carry indexer entries, one per 4 tokens.
                    indexer_elements += math.floor(tokens / 4) * index_dim
        attention_elements = window_elements + compressed_elements
        return (attention_elements + indexer_elements) / tokens, [
            {"role": "kv", "label": "KV cache", "elements": attention_elements},
            {"role": "indexer", "label": "Indexer cache", "elements": indexer_elements},
        ]

    raise ValueError(f"Unsupported formula: {formula}")
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _bytes_for_group(group: dict[str, Any], kv_precision_bytes: float, indexer_precision_bytes: float | None) -> float:
|
|
372
|
+
if "bytes_per_sequence" in group:
|
|
373
|
+
return float(group["bytes_per_sequence"])
|
|
374
|
+
role = group["role"]
|
|
375
|
+
if role == "indexer" and indexer_precision_bytes is not None:
|
|
376
|
+
bytes_per_element = indexer_precision_bytes
|
|
377
|
+
else:
|
|
378
|
+
bytes_per_element = kv_precision_bytes
|
|
379
|
+
return float(group["elements"]) * bytes_per_element
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def _resolve_indexer_precision(
    model: dict[str, Any],
    requested: str | None,
    options: dict[str, dict[str, Any]],
    kv_precision_id: str,
) -> tuple[str | None, str | None, float | None]:
    """Return ``(id, label, bytes_per_element)`` for the model's indexer precision.

    All three are ``None`` when the model has no indexer cache. A fixed
    indexer precision on the model overrides ``requested``.

    Raises:
        ValueError: If the chosen id is not in ``options``.
    """
    if not _has_indexer_cache(model):
        return None, None, None
    chosen = _fixed_indexer_precision_id(model) or requested or default_indexer_precision_id(model, options, kv_precision_id)
    if chosen not in options:
        raise ValueError(f"Unknown indexer precision: {chosen}")
    profile = options[chosen]
    return chosen, str(profile["label"]), float(profile["bytes_per_element"])


def calculate_cache_size(
    model: dict[str, Any],
    *,
    tokens: int,
    precision: str | None = None,
    indexer_precision: str | None = None,
    block_size: int | None = None,
    include_draft_kv_cache: bool = False,
    include_linear_attention_state: bool = False,
    models_data: dict[str, Any] | None = None,
) -> CacheSizeResult:
    """Size the KV (and indexer) cache of ``model`` for ``tokens`` tokens.

    Args:
        model: Catalog model entry.
        tokens: Token count; a non-positive or invalid value falls back to
            the model's ``default_tokens`` (or 4096).
        precision: KV precision id; defaults per ``default_precision_id``.
        indexer_precision: Indexer precision id for models with an indexer
            cache; ignored when the model fixes its indexer precision.
        block_size: If truthy, ``bytes_per_block`` is reported for it.
        include_draft_kv_cache: Count draft-layer cache when the model has it.
        include_linear_attention_state: Add linear-attention state where the
            formula supports it.
        models_data: Parsed catalog; loaded from the bundled file when falsy.

    Returns:
        A ``CacheSizeResult`` with per-group byte totals.

    Raises:
        ValueError: On an unknown precision id or an unsupported model.
    """
    data = models_data or load_models_data()
    kv_options = precision_options(data)
    indexer_options = indexer_precision_options(data)

    precision_id = precision or default_precision_id(model, kv_options)
    if precision_id not in kv_options:
        raise ValueError(f"Unknown KV precision: {precision_id}")
    kv_profile = kv_options[precision_id]
    kv_precision_bytes = float(kv_profile["bytes_per_element"])

    indexer_precision_id, indexer_precision_label, indexer_precision_bytes = _resolve_indexer_precision(
        model, indexer_precision, indexer_options, precision_id
    )

    active_draft = include_draft_kv_cache and has_draft_kv_cache(model)
    token_count = _positive_int(tokens, int(model.get("default_tokens") or 4096))
    _, groups = _calculate_byte_groups(model, token_count, active_draft, include_linear_attention_state)

    # Accumulate in group order; everything that is not an indexer group counts as KV.
    bucket_bytes = {"kv": 0.0, "indexer": 0.0}
    total_bytes = 0.0
    for group in groups:
        group_bytes = _bytes_for_group(group, kv_precision_bytes, indexer_precision_bytes)
        total_bytes += group_bytes
        bucket_bytes["indexer" if group["role"] == "indexer" else "kv"] += group_bytes

    bytes_per_token = total_bytes / token_count
    return CacheSizeResult(
        model_id=str(model["id"]),
        model_label=str(model.get("label") or model["id"]),
        precision=precision_id,
        precision_label=str(kv_profile["label"]),
        indexer_precision=indexer_precision_id,
        indexer_precision_label=indexer_precision_label,
        bytes_per_token=bytes_per_token,
        bytes_per_block=bytes_per_token * block_size if block_size else None,
        kv_bytes=bucket_bytes["kv"],
        indexer_bytes=bucket_bytes["indexer"],
        total_bytes=total_bytes,
        total_gib=total_bytes / BYTES_PER_GIB,
    )
|
kvcache_sim/cli.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
from .calculator import load_models_data, models_by_id
|
|
8
|
+
from .formatting import render_json, render_table
|
|
9
|
+
from .progress import ProgressBar
|
|
10
|
+
from .simulator import DEFAULT_BUDGETS_GIB, DEFAULT_POLICIES, run_sweep
|
|
11
|
+
from .trace import parse_trace_file
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _parse_csv_numbers(value: str) -> list[float]:
|
|
15
|
+
if not value:
|
|
16
|
+
return []
|
|
17
|
+
return [float(part.strip()) for part in value.split(",") if part.strip()]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _parse_csv_strings(value: str) -> list[str]:
|
|
21
|
+
if not value:
|
|
22
|
+
return []
|
|
23
|
+
return [part.strip().lower() for part in value.split(",") if part.strip()]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the ``kvcache-simulator`` CLI.

    ``run`` and ``sweep`` accept identical options (``sweep`` is an alias).
    ``list-models`` only takes ``--models-yaml``. A subcommand is not required
    at parse time; ``main`` handles the missing-command case.
    """

    def run_argument_specs() -> list:
        # (flags, add_argument keyword arguments) for each run/sweep option, in help order.
        # Built fresh per subcommand so defaults are computed for each, as before.
        return [
            (("--trace",), dict(required=True, help="Trace path (.jsonl or .jsonl.gz), or - for stdin")),
            (("--model",), dict(required=True, help="Model id from the bundled KV Cache Size Calculator model catalog")),
            (("--kv-precision",), dict(dest="kv_precision", default=None, help="KV precision id, e.g. bf16_fp16, fp8_int8, fp4_int4")),
            # Hidden legacy alias writing to the same destination as --kv-precision.
            (("--precision",), dict(dest="kv_precision", help=argparse.SUPPRESS)),
            (("--indexer-precision",), dict(default=None, help="Indexer precision id for models with an indexer cache")),
            (("--include-draft-kv-cache",), dict(action="store_true", help="Include draft/MTP KV cache where the model config supports it")),
            (("--block-size",), dict(type=int, default=None, help="Fallback block size when trace records omit block_size; record block_size overrides this value")),
            (("--estimate-tokens",), dict(type=int, default=None, help="Override calculator token count used for bytes/token")),
            (("--budgets-gib",), dict(default=",".join(str(v) for v in DEFAULT_BUDGETS_GIB), help="Comma-separated GiB budgets")),
            (("--policies",), dict(default=",".join(DEFAULT_POLICIES), help="Comma-separated policies: fifo,lru,optimal")),
            (("--backend",), dict(choices=["cpp", "python"], default="cpp", help="Simulation backend (default: cpp)")),
            (("--jobs",), dict(type=int, default=1, help="Worker processes for the Python backend; ignored by the C++ backend")),
            (("--no-progress",), dict(action="store_true", help="Disable terminal progress output")),
            (("--format",), dict(choices=["table", "json"], default="table", help="Output format (default: table)")),
            (("--output", "-o"), dict(default="-", help="Output path, or - for stdout")),
            (("--models-yaml",), dict(default=None, help="Override models.yaml path")),
            (("--max-records",), dict(type=int, default=0, help="Stop after this many valid requests (debug/testing)")),
            (("--max-events",), dict(type=int, default=0, help="Stop after this many trace blocks (debug/testing)")),
        ]

    parser = argparse.ArgumentParser(prog="kvcache-simulator", description="Analyze KV cache hit rate for JSONL traces.")
    subparsers = parser.add_subparsers(dest="command")

    for command_name, command_help in (
        ("run", "Run hit-rate analysis over configured KV cache budgets"),
        ("sweep", "Alias for run; scans a set of KV cache memory budgets"),
    ):
        subcommand = subparsers.add_parser(command_name, help=command_help)
        for flags, options in run_argument_specs():
            subcommand.add_argument(*flags, **options)

    list_models = subparsers.add_parser("list-models", help="List supported model ids")
    list_models.add_argument("--models-yaml", default=None, help="Override models.yaml path")
    return parser
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _write_output(text: str, output: str) -> None:
|
|
63
|
+
if output == "-":
|
|
64
|
+
print(text)
|
|
65
|
+
return
|
|
66
|
+
Path(output).write_text(text + "\n", encoding="utf-8")
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def run_sweep_command(args: argparse.Namespace) -> int:
    """Execute the ``run`` / ``sweep`` subcommand.

    The steps are:

    1. load the model catalog;
    2. parse the trace (honouring ``--block-size``, ``--max-records`` and
       ``--max-events``);
    3. run the budget sweep;
    4. write the rendered table or JSON to ``args.output``.

    Returns 0 on success.

    The progress display is enabled only when stderr is a TTY and
    ``--no-progress`` is not set. It is finished on success. On any failure it
    is closed before the exception propagates, including on
    ``KeyboardInterrupt`` and ``SystemExit``, which the previous
    ``except Exception`` did not cover.
    """
    progress = ProgressBar(enabled=(not args.no_progress and sys.stderr.isatty()))
    try:
        data = load_models_data(args.models_yaml)
        progress.update(0, 4, "reading trace")
        trace = parse_trace_file(args.trace, block_size=args.block_size, max_records=args.max_records, max_events=args.max_events)
        progress.update(1, 4, "trace loaded")
        result = run_sweep(
            trace,
            model_id=args.model,
            precision=args.kv_precision,
            indexer_precision=args.indexer_precision,
            budgets_gib=_parse_csv_numbers(args.budgets_gib),
            policies=_parse_csv_strings(args.policies),
            jobs=args.jobs,
            backend=args.backend,
            progress=progress.update,
            estimate_tokens=args.estimate_tokens,
            include_draft_kv_cache=args.include_draft_kv_cache,
            models_data=data,
        )
        rendered = render_json(result) if args.format == "json" else render_table(result)
        _write_output(rendered, args.output)
        progress.finish()
        return 0
    except BaseException:
        # BaseException so Ctrl-C during a long sweep still closes the progress display.
        progress.close()
        raise
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def run_list_models(args: argparse.Namespace) -> int:
|
|
100
|
+
data = load_models_data(args.models_yaml)
|
|
101
|
+
models = models_by_id(data)
|
|
102
|
+
for model in sorted(models.values(), key=lambda item: (item.get("family", ""), item.get("label", ""))):
|
|
103
|
+
print(f"{model['id']}\t{model.get('label', model['id'])}\t{model.get('family', '')}\t{model.get('formula', '')}")
|
|
104
|
+
return 0
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse ``argv`` and dispatch to the chosen subcommand.

    Returns the subcommand's exit status. When no subcommand was given, help
    is printed to stderr and 2 is returned.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    handlers = {
        "run": run_sweep_command,
        "sweep": run_sweep_command,
        "list-models": run_list_models,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        return handler(args)
    parser.print_help(sys.stderr)
    return 2
|