kvcache-simulator 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ """Local KV cache hit-rate simulator aligned with KVCache.AI web tools."""
2
+
3
+ from .calculator import BYTES_PER_GIB, calculate_cache_size, load_models_data
4
+ from .simulator import run_sweep
5
+
6
+ __version__ = "0.1.0"
7
+
8
+ __all__ = [
9
+ "BYTES_PER_GIB",
10
+ "__version__",
11
+ "calculate_cache_size",
12
+ "load_models_data",
13
+ "run_sweep",
14
+ ]
@@ -0,0 +1,5 @@
1
+ from .cli import main
2
+
3
+
4
+ if __name__ == "__main__":
5
+ raise SystemExit(main())
@@ -0,0 +1,27 @@
1
+ from __future__ import annotations
2
+
3
+ from importlib import resources
4
+ from pathlib import Path
5
+ import hashlib
6
+ import re
7
+ import tempfile
8
+
9
RESOURCE_PACKAGE = "kvcache_sim.resources"


def package_resource_path(name: str, *, package: str = RESOURCE_PACKAGE) -> Path:
    """Return a real filesystem path for the packaged resource *name*.

    When the resource already lives on disk (a regular installed package),
    that path is returned directly. Otherwise (for example when the package
    is imported from a zip archive) the bytes are copied once into the system
    temp directory under a content-addressed name, and that path is returned.

    The temp copy is written to a private staging file in the same directory
    and then renamed over the target, so concurrent processes never observe
    a partially written file.

    Args:
        name: Resource file name relative to *package*.
        package: Importable package that holds the resource; defaults to the
            bundled ``kvcache_sim.resources``.

    Returns:
        A ``Path`` that can be opened with ordinary file APIs.
    """
    resource = resources.files(package).joinpath(name)
    try:
        path = Path(resource)
        if path.exists():
            return path
    except TypeError:
        # Traversables that are not path-like (e.g. resources inside a zip
        # archive) cannot be turned into a Path; fall back to a temp copy.
        pass

    payload = resource.read_bytes()
    digest = hashlib.sha256(payload).hexdigest()[:16]
    safe_name = re.sub(r"[^A-Za-z0-9_.-]+", "-", name)
    target = Path(tempfile.gettempdir()) / f"kvcache-simulator-{digest}-{safe_name}"
    try:
        if target.is_file() and target.read_bytes() == payload:
            return target
    except OSError:
        pass  # unreadable or vanished copy: rewrite it below

    staging_handle = tempfile.NamedTemporaryFile(
        dir=target.parent, prefix=f".{target.name}.", suffix=".tmp", delete=False
    )
    staging = Path(staging_handle.name)
    try:
        with staging_handle:
            staging_handle.write(payload)
        # Path.replace uses os.replace: an atomic rename that overwrites target.
        staging.replace(target)
    except BaseException:
        staging.unlink(missing_ok=True)
        raise
    return target
@@ -0,0 +1,442 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ import math
6
+ import re
7
+ import shlex
8
+ from typing import Any
9
+
10
+ from ._resources import package_resource_path
11
+
12
# Decimal gigabyte (10**9 bytes).
BYTES_PER_GB = 1_000_000_000
# Binary gibibyte (2**30 bytes); calculate_cache_size reports total_gib in these units.
BYTES_PER_GIB = 1024 ** 3
# Fixed bytes per element for the Qwen linear-attention convolution state and
# recurrent state ("qwen_linear_full_hybrid" formula). These groups are sized
# directly in bytes, so the selected KV precision does not apply to them.
QWEN_LINEAR_CONV_BYTES_PER_ELEMENT = 2
QWEN_LINEAR_RECURRENT_BYTES_PER_ELEMENT = 4

# Fallback precision catalog used by precision_options() when the models data
# defines no "precision_options" entries. Keyed by precision id.
DEFAULT_PRECISIONS = {
    "bf16_fp16": {"label": "BF16 / FP16", "bytes_per_element": 2.0},
    "fp8_int8": {"label": "FP8 / INT8", "bytes_per_element": 1.0},
    "fp4_int4": {"label": "FP4 / INT4", "bytes_per_element": 0.5},
}
22
+
23
+
24
@dataclass(frozen=True)
class CacheSizeResult:
    """Immutable summary of one KV cache size calculation.

    Produced by ``calculate_cache_size``. Byte counts are floats because a
    precision may use a fractional number of bytes per element (e.g. 0.5).
    """

    # Catalog id and display label of the model.
    model_id: str
    model_label: str
    # KV precision id and its display label.
    precision: str
    precision_label: str
    # Indexer precision id / label; None for models without an indexer cache.
    indexer_precision: str | None
    indexer_precision_label: str | None
    # total_bytes divided by the token count used for the calculation.
    bytes_per_token: float
    # bytes_per_token * block_size, or None when block_size was not given (or 0).
    bytes_per_block: float | None
    # Bytes of every group whose role is not "indexer" (KV cache and, when
    # included, linear-attention state).
    kv_bytes: float
    # Bytes of the "indexer" groups.
    indexer_bytes: float
    # kv_bytes + indexer_bytes, and the same total expressed in GiB.
    total_bytes: float
    total_gib: float
38
+
39
+
40
def default_models_path() -> Path:
    """Return a filesystem path to the bundled ``models.yaml`` catalog."""
    catalog_name = "models.yaml"
    return package_resource_path(catalog_name)
42
+
43
+
44
+ def _parse_scalar(value: str) -> Any:
45
+ value = value.strip()
46
+ if value == "":
47
+ return {}
48
+ if value in {"true", "True"}:
49
+ return True
50
+ if value in {"false", "False"}:
51
+ return False
52
+ if value.startswith("[") and value.endswith("]"):
53
+ inner = value[1:-1].strip()
54
+ if not inner:
55
+ return []
56
+ return [_parse_scalar(part) for part in shlex.split(inner.replace(",", " "))]
57
+ if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
58
+ return value[1:-1]
59
+ if re.fullmatch(r"-?\d+", value):
60
+ return int(value)
61
+ if re.fullmatch(r"-?(\d+\.\d*|\d*\.\d+)([eE][+-]?\d+)?", value) or re.fullmatch(r"-?\d+[eE][+-]?\d+", value):
62
+ return float(value)
63
+ return value
64
+
65
+
66
+ def _split_key_value(line: str) -> tuple[str, Any]:
67
+ key, value = line.split(":", 1)
68
+ return key.strip(), _parse_scalar(value)
69
+
70
+
71
def _load_models_data_minimal_yaml(path: Path) -> dict[str, Any]:
    """Parse the small YAML subset used by data/kv_cache_calculator/models.yaml.

    This fallback keeps the CLI usable without PyYAML. It is intentionally
    narrow: top-level maps, lists of maps, and a nested scalar `fields` map.

    Indentation must be spaces:
    - list items at column 2;
    - item keys (and ``metadata`` keys) at column 4 / column 2;
    - ``fields`` entries and ``serving_references`` entries at column 6 /
      column 4.

    Other indented lines are ignored. Blank lines, ``#`` comment lines and
    YAML document markers (``---`` / ``...``) are skipped, and a leading
    UTF-8 byte-order mark is tolerated.

    Raises:
        ValueError: if a line that must be a ``key: value`` pair is malformed.
            The message is prefixed with the file path and 1-based line number.
    """
    list_sections = {"precision_options", "indexer_precision_options", "models"}

    def key_value(text: str, line_number: int) -> tuple[str, Any]:
        # Attach the file location to any parse error for easier debugging.
        try:
            return _split_key_value(text)
        except ValueError as exc:
            raise ValueError(f"{path}:{line_number}: {exc}") from exc

    # utf-8-sig strips a BOM that would otherwise become part of the first key.
    lines = path.read_text(encoding="utf-8-sig").splitlines()
    data: dict[str, Any] = {}
    current_section: str | None = None
    current_item: dict[str, Any] | None = None
    in_fields = False
    in_serving_references = False

    for line_number, raw in enumerate(lines, start=1):
        stripped = raw.strip()
        if not stripped or stripped.startswith("#") or stripped in {"---", "..."}:
            continue
        indent = len(raw) - len(raw.lstrip(" "))

        if indent == 0:
            key, value = key_value(stripped, line_number)
            current_section = key
            current_item = None
            in_fields = False
            in_serving_references = False
            if value == {}:
                # A bare "key:" opens a nested block: a list for the known
                # list sections, a mapping otherwise.
                data[key] = [] if key in list_sections else {}
            else:
                data[key] = value
            continue

        if current_section is None:
            continue

        if current_section == "metadata":
            if indent == 2:
                key, value = key_value(stripped, line_number)
                data["metadata"][key] = value
                in_serving_references = key == "serving_references" and value == {}
            elif indent == 4 and in_serving_references:
                key, value = key_value(stripped, line_number)
                data["metadata"].setdefault("serving_references", {})[key] = value
            continue

        if current_section in list_sections:
            if indent == 2 and stripped.startswith("- "):
                # "- key: value" starts a new list entry.
                current_item = {}
                data[current_section].append(current_item)
                key, value = key_value(stripped[2:], line_number)
                current_item[key] = value
                in_fields = False
                continue
            if current_item is None:
                continue
            if indent == 4:
                key, value = key_value(stripped, line_number)
                in_fields = key == "fields" and value == {}
                # A bare "fields:" starts the nested scalar map filled at indent 6.
                current_item[key] = {} if in_fields else value
                continue
            if indent == 6 and in_fields:
                key, value = key_value(stripped, line_number)
                current_item["fields"][key] = value

    return data
138
+
139
+
140
def load_models_data(path: str | Path | None = None) -> dict[str, Any]:
    """Load the model catalog from *path*, or from the bundled models.yaml.

    Uses PyYAML's ``safe_load`` when PyYAML is installed and otherwise falls
    back to the minimal built-in parser. With either parser an empty
    document yields ``{}``.

    Raises:
        ValueError: if the document's top level is not a mapping.
        OSError: if the file cannot be read.
    """
    model_path = Path(path) if path else default_models_path()
    try:
        import yaml  # type: ignore
    except ModuleNotFoundError:
        return _load_models_data_minimal_yaml(model_path)

    with model_path.open("r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    if loaded is None:
        # safe_load returns None for an empty document; keep the dict contract.
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(f"{model_path}: expected a mapping at the top level, got {type(loaded).__name__}")
    return loaded
149
+
150
+
151
def precision_options(data: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Map each KV precision id to its display label and bytes per element.

    Entries come from ``data["precision_options"]``. For each entry:
    - ``bytes_per_element`` (or camelCase ``bytesPerElement``) defaults to 0;
    - ``label`` defaults to the id.

    A missing, null or empty list falls back to a fresh copy of
    ``DEFAULT_PRECISIONS``, so callers cannot mutate the module defaults.
    """
    options = {
        item["id"]: {
            "label": item.get("label", item["id"]),
            "bytes_per_element": float(item.get("bytes_per_element", item.get("bytesPerElement", 0))),
        }
        # "or []" also covers a null value (PyYAML parses a bare key as None).
        for item in data.get("precision_options") or []
    }
    if options:
        return options
    return {precision_id: dict(profile) for precision_id, profile in DEFAULT_PRECISIONS.items()}
159
+
160
+
161
def indexer_precision_options(data: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Map each indexer precision id to its display label and bytes per element.

    Entries come from ``indexer_precision_options``, or from
    ``precision_options`` when the former is missing or empty. If neither
    yields entries, the result of ``precision_options(data)`` is returned.
    """
    entries = data.get("indexer_precision_options") or data.get("precision_options") or []
    catalog: dict[str, dict[str, Any]] = {}
    for entry in entries:
        option_id = entry["id"]
        raw_bytes = entry.get("bytes_per_element", entry.get("bytesPerElement", 0))
        catalog[option_id] = {
            "label": entry.get("label", option_id),
            "bytes_per_element": float(raw_bytes),
        }
    return catalog if catalog else precision_options(data)
170
+
171
+
172
def models_by_id(data: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Index the catalog's ``models`` list by each model's ``id``.

    A missing or null ``models`` key (PyYAML parses a bare ``models:`` as
    ``None``) yields an empty mapping. When ids repeat, the last entry wins.
    """
    return {model["id"]: model for model in data.get("models") or []}
174
+
175
+
176
+ def _safe_number(value: Any, fallback: float = 0) -> float:
177
+ try:
178
+ parsed = float(value)
179
+ except (TypeError, ValueError):
180
+ return fallback
181
+ return parsed if math.isfinite(parsed) else fallback
182
+
183
+
184
def _positive_int(value: Any, fallback: int) -> int:
    """Floor *value* to an int; if that is not positive, use ``max(1, fallback)``.

    *fallback* is also used when *value* is not a finite number.
    """
    floored = math.floor(_safe_number(value, fallback))
    if floored <= 0:
        return max(1, fallback)
    return floored
187
+
188
+
189
def _field(model: dict[str, Any], name: str) -> float:
    """Return the required numeric field *name* from ``model["fields"]``.

    Raises:
        ValueError: if the field is absent or is not a finite number.
    """
    fields = model.get("fields") or {}
    value = _safe_number(fields[name], math.nan) if name in fields else math.nan
    if math.isfinite(value):
        return value
    raise ValueError(f"Model {model.get('id', '')} is missing numeric field {name}")
197
+
198
+
199
def _optional_field(model: dict[str, Any], name: str, fallback: float) -> float:
    """Return numeric field *name* from ``model["fields"]``, else *fallback*.

    *fallback* is used both when the field is absent and when its value is
    not a finite number.
    """
    fields = model.get("fields") or {}
    if name not in fields:
        return fallback
    return _safe_number(fields.get(name), fallback)
202
+
203
+
204
+ def _is_deepseek_v4(model: dict[str, Any]) -> bool:
205
+ return model.get("formula") == "deepseek_v4_hybrid"
206
+
207
+
208
def _has_indexer_cache(model: dict[str, Any]) -> bool:
    """Return True when the model defines a finite numeric ``index_head_dim``."""
    index_head_dim = (model.get("fields") or {}).get("index_head_dim")
    return math.isfinite(_safe_number(index_head_dim, math.nan))
211
+
212
+
213
def _draft_layer_count(model: dict[str, Any]) -> int:
    """Count the draft / MTP layers configured for *model*.

    Precedence:
    1. ``disable_draft_kv_cache: true`` gives 0.
    2. A positive ``num_nextn_predict_layers`` is used as-is.
    3. Otherwise, when ``use_mtp`` is true, the count is
       ``num_mtp_modules * mtp_transformer_layers``.
    4. Anything else gives 0.
    """
    fields = model.get("fields") or {}
    if fields.get("disable_draft_kv_cache") is True:
        return 0
    nextn_layers = int(_safe_number(fields.get("num_nextn_predict_layers"), 0))
    if nextn_layers > 0:
        return nextn_layers
    if fields.get("use_mtp") is not True:
        return 0
    modules = _safe_number(fields.get("num_mtp_modules"), 0)
    layers_per_module = _safe_number(fields.get("mtp_transformer_layers"), 0)
    return int(modules * layers_per_module)
223
+
224
+
225
def has_draft_kv_cache(model: dict[str, Any]) -> bool:
    """Return True when *model* has draft/MTP layers that can add KV cache.

    For ``deepseek_v4_hybrid`` models this means ``compress_ratios`` is a
    list with more entries than ``num_hidden_layers``; the extra entries
    describe the draft layers. For every other formula it means a positive
    draft layer count.
    """
    if not _is_deepseek_v4(model):
        return _draft_layer_count(model) > 0
    fields = model.get("fields") or {}
    hidden_layers = int(_safe_number(fields.get("num_hidden_layers"), 0))
    ratios = fields.get("compress_ratios")
    if not isinstance(ratios, list):
        return False
    return len(ratios) > hidden_layers
232
+
233
+
234
+ def _fixed_indexer_precision_id(model: dict[str, Any]) -> str | None:
235
+ value = (model.get("fields") or {}).get("indexer_fixed_precision_id")
236
+ return value if isinstance(value, str) else None
237
+
238
+
239
def default_precision_id(model: dict[str, Any], options: dict[str, dict[str, Any]]) -> str:
    """Choose the default KV precision id for *model* from *options*.

    Selection order:
    1. DeepSeek V4 hybrid models get ``fp8_int8`` when it is available.
    2. Otherwise ``bf16_fp16`` is used when available.
    3. Otherwise the first id in *options* is used.

    Raises:
        ValueError: if *options* is empty. Without this check the bare
            ``next(iter(...))`` would leak StopIteration.
    """
    if not options:
        raise ValueError("No KV precision options are available")
    if _is_deepseek_v4(model) and "fp8_int8" in options:
        return "fp8_int8"
    if "bf16_fp16" in options:
        return "bf16_fp16"
    return next(iter(options))
245
+
246
+
247
def default_indexer_precision_id(model: dict[str, Any], options: dict[str, dict[str, Any]], fallback_precision: str | None) -> str:
    """Choose the default indexer precision id for *model* from *options*.

    Preference order (each candidate only if present in *options*):
    1. The model's pinned ``indexer_fixed_precision_id``.
    2. ``fp4_int4`` for DeepSeek V4 hybrid models.
    3. *fallback_precision* (``calculate_cache_size`` passes the KV precision).
    4. ``bf16_fp16``.
    5. ``fp4_int4``.
    6. The first id in *options*.

    Raises:
        ValueError: if *options* is empty. Without this check the bare
            ``next(iter(...))`` would leak StopIteration.
    """
    if not options:
        raise ValueError("No indexer precision options are available")
    fixed = _fixed_indexer_precision_id(model)
    if fixed and fixed in options:
        return fixed
    if _is_deepseek_v4(model) and "fp4_int4" in options:
        return "fp4_int4"
    if fallback_precision and fallback_precision in options:
        return fallback_precision
    if "bf16_fp16" in options:
        return "bf16_fp16"
    return "fp4_int4" if "fp4_int4" in options else next(iter(options))
258
+
259
+
260
def _indexer_layer_plan(model: dict[str, Any], layers: int, draft_layers: int) -> tuple[int, int, int, int]:
    """Return ``(main, shared, draft, active)`` indexer layer counts.

    - ``main`` comes from ``indexer_full_layers`` and defaults to *layers*.
    - ``shared`` comes from ``indexer_shared_layers`` and defaults to the
      non-negative remainder ``layers - main``.
    - ``draft`` comes from ``draft_indexer_layers`` (default *draft_layers*),
      but only when *draft_layers* is positive; otherwise it is 0.
    - ``active`` is ``main + draft``.
    """
    main_layers = int(_optional_field(model, "indexer_full_layers", layers))
    shared_layers = int(_optional_field(model, "indexer_shared_layers", max(0, layers - main_layers)))
    draft_indexer_layers = 0
    if draft_layers > 0:
        draft_indexer_layers = int(_optional_field(model, "draft_indexer_layers", draft_layers))
    active_layers = main_layers + draft_indexer_layers
    return main_layers, shared_layers, draft_indexer_layers, active_layers
265
+
266
+
267
def _calculate_byte_groups(model: dict[str, Any], tokens: int, include_draft_kv_cache: bool, include_linear_attention_state: bool) -> tuple[float, list[dict[str, Any]]]:
    """Compute cache element counts for *model* holding *tokens* tokens.

    Dispatches on ``model["formula"]``. Returns ``(elements_per_token, groups)``.

    Each group is a dict with:
    - ``role``: "kv", "indexer" or "linear_state";
    - ``label``: a display label;
    - either ``elements``, a count later multiplied by a per-element
      precision size, or ``bytes_per_sequence``, an absolute byte count that
      bypasses precision.

    ``elements_per_token`` excludes ``bytes_per_sequence`` groups. For the
    mixed sliding-window and DeepSeek V4 formulas it is an average
    (total elements / tokens).

    Args:
        model: Catalog entry with ``formula`` and a ``fields`` mapping.
        tokens: Sequence length. Assumed positive (calculate_cache_size
            normalises it); some formulas divide by it.
        include_draft_kv_cache: Add draft/MTP layers, or, for DeepSeek V4,
            the ``compress_ratios`` entries beyond ``num_hidden_layers``.
        include_linear_attention_state: For the Qwen hybrid formula, add the
            per-sequence linear-attention state group.

    Raises:
        ValueError: for an unsupported formula, a missing required numeric
            field, or a DeepSeek V4 model with no active ``compress_ratios``.
    """
    formula = model.get("formula")
    fields = model.get("fields") or {}
    # Draft/MTP layer count; 0 unless the caller asked for draft KV cache.
    draft_layers = _draft_layer_count(model) if include_draft_kv_cache else 0

    if formula == "standard_gqa":
        layers = int(_field(model, "num_hidden_layers")) + draft_layers
        # Factor 2: separate K and V entries per KV head (the mixed formula
        # below spells this out as head_dim + v_head_dim).
        elements_per_token = layers * 2 * _field(model, "num_key_value_heads") * _field(model, "head_dim")
        return elements_per_token, [{"role": "kv", "label": "KV cache", "elements": elements_per_token * tokens}]

    if formula == "mla":
        layers = int(_field(model, "num_hidden_layers")) + draft_layers
        # Per layer: kv_lora_rank + qk_rope_head_dim elements per token, with
        # no per-head K/V multiplier.
        elements_per_token = layers * (_field(model, "kv_lora_rank") + _field(model, "qk_rope_head_dim"))
        return elements_per_token, [{"role": "kv", "label": "KV cache", "elements": elements_per_token * tokens}]

    if formula == "dsa_mla":
        layers = int(_field(model, "num_hidden_layers"))
        active_layers = layers + draft_layers
        # Only the active (main + draft) indexer layer count is used here.
        _, _, _, active_indexer_layers = _indexer_layer_plan(model, layers, draft_layers)
        # KV sized like "mla"; the indexer adds index_head_dim elements per
        # token for each active indexer layer.
        kv_elements_per_token = active_layers * (_field(model, "kv_lora_rank") + _field(model, "qk_rope_head_dim"))
        indexer_elements_per_token = active_indexer_layers * _field(model, "index_head_dim")
        return kv_elements_per_token + indexer_elements_per_token, [
            {"role": "kv", "label": "KV cache", "elements": kv_elements_per_token * tokens},
            {"role": "indexer", "label": "Indexer cache", "elements": indexer_elements_per_token * tokens},
        ]

    if formula == "qwen_linear_full_hybrid":
        # Only full-attention layers keep per-token KV; draft layers are not
        # counted by this formula.
        full_layers = _field(model, "full_attention_layers")
        kv_heads = _field(model, "num_key_value_heads")
        head_dim = _field(model, "head_dim")
        elements_per_token = full_layers * 2 * kv_heads * head_dim
        groups = [{"role": "kv", "label": "Full-attention KV cache", "elements": elements_per_token * tokens}]
        if include_linear_attention_state:
            # Linear-attention layers keep a per-sequence state that does not
            # depend on tokens. It is sized directly in bytes with fixed
            # widths, so the selected KV precision does not apply to it.
            linear_layers = _field(model, "linear_attention_layers")
            conv_kernel = _field(model, "linear_conv_kernel_dim")
            key_heads = _field(model, "linear_num_key_heads")
            key_dim = _field(model, "linear_key_head_dim")
            value_heads = _field(model, "linear_num_value_heads")
            value_dim = _field(model, "linear_value_head_dim")
            conv_elements = linear_layers * conv_kernel * (2 * key_heads * key_dim + value_heads * value_dim)
            recurrent_elements = linear_layers * value_heads * key_dim * value_dim
            groups.append({
                "role": "linear_state",
                "label": "Linear-attention state",
                "bytes_per_sequence": conv_elements * QWEN_LINEAR_CONV_BYTES_PER_ELEMENT + recurrent_elements * QWEN_LINEAR_RECURRENT_BYTES_PER_ELEMENT,
            })
        return elements_per_token, groups

    if formula == "mixed_full_sliding_gqa":
        # Full-attention layers store every token; sliding-window layers keep
        # at most sliding_window tokens. Global- and SWA-specific head counts
        # and dims fall back to the shared values when absent. Draft layers
        # are not counted by this formula.
        full_layers = _field(model, "full_attention_layers")
        sliding_layers = _field(model, "sliding_attention_layers")
        kv_heads = _field(model, "num_key_value_heads")
        head_dim = _field(model, "head_dim")
        full_kv_heads = _optional_field(model, "num_global_key_value_heads", kv_heads)
        full_head_dim = _optional_field(model, "global_head_dim", head_dim)
        full_v_dim = _optional_field(model, "global_v_head_dim", _optional_field(model, "v_head_dim", full_head_dim))
        sliding_kv_heads = _optional_field(model, "swa_num_key_value_heads", _optional_field(model, "sliding_num_key_value_heads", kv_heads))
        sliding_head_dim = _optional_field(model, "swa_head_dim", _optional_field(model, "sliding_head_dim", head_dim))
        sliding_v_dim = _optional_field(model, "swa_v_head_dim", _optional_field(model, "sliding_v_head_dim", _optional_field(model, "v_head_dim", sliding_head_dim)))
        retained_sliding_tokens = min(tokens, int(_field(model, "sliding_window")))
        full_elements = tokens * full_layers * full_kv_heads * (full_head_dim + full_v_dim)
        sliding_elements = retained_sliding_tokens * sliding_layers * sliding_kv_heads * (sliding_head_dim + sliding_v_dim)
        # Per-token figure is an average because sliding layers saturate.
        return (full_elements + sliding_elements) / tokens, [
            {"role": "kv", "label": "Full-attention KV cache", "elements": full_elements},
            {"role": "kv", "label": "Sliding-window KV cache", "elements": sliding_elements},
        ]

    if formula == "minimax_msa":
        # Every layer keeps K and V (factor 2) per KV head; the
        # sparse_attention_layers additionally keep index_head_dim indexer
        # elements per token. Draft layers are not counted by this formula.
        layers = _field(model, "num_hidden_layers")
        sparse_layers = _field(model, "sparse_attention_layers")
        kv_elements_per_token = layers * 2 * _field(model, "num_key_value_heads") * _field(model, "head_dim")
        indexer_elements_per_token = sparse_layers * _field(model, "index_head_dim")
        return kv_elements_per_token + indexer_elements_per_token, [
            {"role": "kv", "label": "KV cache", "elements": kv_elements_per_token * tokens},
            {"role": "indexer", "label": "Indexer cache", "elements": indexer_elements_per_token * tokens},
        ]

    if formula == "deepseek_v4_hybrid":
        head_dim = _field(model, "head_dim")
        index_dim = _field(model, "index_head_dim")
        sliding_window = _field(model, "sliding_window")
        layers = int(_field(model, "num_hidden_layers"))
        # One compress ratio per layer. Entries beyond num_hidden_layers are
        # draft layers and are only counted when include_draft_kv_cache is
        # set (draft_layers from above is not used by this formula).
        # NOTE(review): a null compress_ratios (None) would raise TypeError here.
        ratios = [float(ratio) for ratio in fields.get("compress_ratios", [])]
        active_ratios = ratios[:layers] + (ratios[layers:] if include_draft_kv_cache else [])
        if not active_ratios:
            raise ValueError(f"Model {model.get('id', '')} is missing compress_ratios")
        window_elements = 0.0
        compressed_elements = 0.0
        indexer_elements = 0.0
        for ratio in active_ratios:
            # Every active layer pays for a full sliding window of head_dim
            # elements. NOTE(review): unlike the mixed formula, this is not
            # capped at tokens when tokens < sliding_window; confirm intended.
            window_elements += sliding_window * head_dim
            # Compressed entries: floor(tokens / ratio) of head_dim elements.
            if ratio > 0:
                compressed_elements += math.floor(tokens / ratio) * head_dim
            # Only ratio-4 layers carry an indexer cache.
            if ratio == 4:
                indexer_elements += math.floor(tokens / 4) * index_dim
        attention_elements = window_elements + compressed_elements
        return (attention_elements + indexer_elements) / tokens, [
            {"role": "kv", "label": "KV cache", "elements": attention_elements},
            {"role": "indexer", "label": "Indexer cache", "elements": indexer_elements},
        ]

    raise ValueError(f"Unsupported formula: {formula}")
369
+
370
+
371
+ def _bytes_for_group(group: dict[str, Any], kv_precision_bytes: float, indexer_precision_bytes: float | None) -> float:
372
+ if "bytes_per_sequence" in group:
373
+ return float(group["bytes_per_sequence"])
374
+ role = group["role"]
375
+ if role == "indexer" and indexer_precision_bytes is not None:
376
+ bytes_per_element = indexer_precision_bytes
377
+ else:
378
+ bytes_per_element = kv_precision_bytes
379
+ return float(group["elements"]) * bytes_per_element
380
+
381
+
382
def calculate_cache_size(
    model: dict[str, Any],
    *,
    tokens: int,
    precision: str | None = None,
    indexer_precision: str | None = None,
    block_size: int | None = None,
    include_draft_kv_cache: bool = False,
    include_linear_attention_state: bool = False,
    models_data: dict[str, Any] | None = None,
) -> CacheSizeResult:
    """Estimate the KV cache footprint of *model* for a sequence of *tokens*.

    Args:
        model: Catalog entry (``id``, ``formula``, ``fields``, ...).
        tokens: Token count. It is floored; a value that is not a positive
            number is replaced by the model's ``default_tokens`` (or 4096),
            at least 1.
        precision: KV precision id; defaults per ``default_precision_id``.
        indexer_precision: Indexer precision id, used only for models that
            define ``index_head_dim``. A model-pinned
            ``indexer_fixed_precision_id`` takes priority over it.
        block_size: When truthy, ``bytes_per_block`` is also reported.
        include_draft_kv_cache: Count draft/MTP layers when the model has them.
        include_linear_attention_state: Count per-sequence linear-attention
            state for formulas that define it.
        models_data: Parsed catalog supplying the precision options. The
            bundled catalog is loaded only when this is ``None``; an explicit
            (even empty) mapping is used as given.

    Returns:
        A ``CacheSizeResult`` with KV / indexer / total byte breakdowns.

    Raises:
        ValueError: for an unknown precision id or an unsupported or
            incomplete model definition.
    """
    data = load_models_data() if models_data is None else models_data
    precision_by_id = precision_options(data)
    indexer_precision_by_id = indexer_precision_options(data)
    precision_id = precision or default_precision_id(model, precision_by_id)
    if precision_id not in precision_by_id:
        raise ValueError(f"Unknown KV precision: {precision_id}")
    precision_profile = precision_by_id[precision_id]
    kv_precision_bytes = float(precision_profile["bytes_per_element"])

    indexer_precision_id: str | None = None
    indexer_precision_label: str | None = None
    indexer_precision_bytes: float | None = None
    if _has_indexer_cache(model):
        # A precision pinned by the model config overrides the caller's choice.
        fixed_indexer_precision = _fixed_indexer_precision_id(model)
        indexer_precision_id = fixed_indexer_precision or indexer_precision or default_indexer_precision_id(model, indexer_precision_by_id, precision_id)
        if indexer_precision_id not in indexer_precision_by_id:
            raise ValueError(f"Unknown indexer precision: {indexer_precision_id}")
        indexer_profile = indexer_precision_by_id[indexer_precision_id]
        indexer_precision_label = str(indexer_profile["label"])
        indexer_precision_bytes = float(indexer_profile["bytes_per_element"])

    active_draft = include_draft_kv_cache and has_draft_kv_cache(model)
    tokens = _positive_int(tokens, int(model.get("default_tokens") or 4096))
    _, groups = _calculate_byte_groups(model, tokens, active_draft, include_linear_attention_state)
    kv_bytes = 0.0
    indexer_bytes = 0.0
    total_bytes = 0.0
    for group in groups:
        group_bytes = _bytes_for_group(group, kv_precision_bytes, indexer_precision_bytes)
        total_bytes += group_bytes
        # Everything that is not indexer cache (incl. linear state) counts as KV.
        if group["role"] == "indexer":
            indexer_bytes += group_bytes
        else:
            kv_bytes += group_bytes

    bytes_per_token = total_bytes / tokens
    return CacheSizeResult(
        model_id=str(model["id"]),
        model_label=str(model.get("label") or model["id"]),
        precision=precision_id,
        precision_label=str(precision_profile["label"]),
        indexer_precision=indexer_precision_id,
        indexer_precision_label=indexer_precision_label,
        bytes_per_token=bytes_per_token,
        bytes_per_block=bytes_per_token * block_size if block_size else None,
        kv_bytes=kv_bytes,
        indexer_bytes=indexer_bytes,
        total_bytes=total_bytes,
        total_gib=total_bytes / BYTES_PER_GIB,
    )
kvcache_sim/cli.py ADDED
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ import sys
6
+
7
+ from .calculator import load_models_data, models_by_id
8
+ from .formatting import render_json, render_table
9
+ from .progress import ProgressBar
10
+ from .simulator import DEFAULT_BUDGETS_GIB, DEFAULT_POLICIES, run_sweep
11
+ from .trace import parse_trace_file
12
+
13
+
14
+ def _parse_csv_numbers(value: str) -> list[float]:
15
+ if not value:
16
+ return []
17
+ return [float(part.strip()) for part in value.split(",") if part.strip()]
18
+
19
+
20
+ def _parse_csv_strings(value: str) -> list[str]:
21
+ if not value:
22
+ return []
23
+ return [part.strip().lower() for part in value.split(",") if part.strip()]
24
+
25
+
26
def build_parser() -> argparse.ArgumentParser:
    """Build the ``kvcache-simulator`` argument parser and its subcommands.

    ``run`` and ``sweep`` share an identical option set; ``list-models``
    accepts only ``--models-yaml``.
    """
    parser = argparse.ArgumentParser(prog="kvcache-simulator", description="Analyze KV cache hit rate for JSONL traces.")
    subparsers = parser.add_subparsers(dest="command")

    def _add_analysis_options(target: argparse.ArgumentParser) -> None:
        add = target.add_argument
        add("--trace", required=True, help="Trace path (.jsonl or .jsonl.gz), or - for stdin")
        add("--model", required=True, help="Model id from the bundled KV Cache Size Calculator model catalog")
        add("--kv-precision", dest="kv_precision", default=None, help="KV precision id, e.g. bf16_fp16, fp8_int8, fp4_int4")
        # Hidden legacy spelling that writes to the same destination.
        add("--precision", dest="kv_precision", help=argparse.SUPPRESS)
        add("--indexer-precision", default=None, help="Indexer precision id for models with an indexer cache")
        add("--include-draft-kv-cache", action="store_true", help="Include draft/MTP KV cache where the model config supports it")
        add("--block-size", type=int, default=None, help="Fallback block size when trace records omit block_size; record block_size overrides this value")
        add("--estimate-tokens", type=int, default=None, help="Override calculator token count used for bytes/token")
        add("--budgets-gib", default=",".join(map(str, DEFAULT_BUDGETS_GIB)), help="Comma-separated GiB budgets")
        add("--policies", default=",".join(DEFAULT_POLICIES), help="Comma-separated policies: fifo,lru,optimal")
        add("--backend", choices=["cpp", "python"], default="cpp", help="Simulation backend (default: cpp)")
        add("--jobs", type=int, default=1, help="Worker processes for the Python backend; ignored by the C++ backend")
        add("--no-progress", action="store_true", help="Disable terminal progress output")
        add("--format", choices=["table", "json"], default="table", help="Output format (default: table)")
        add("--output", "-o", default="-", help="Output path, or - for stdout")
        add("--models-yaml", default=None, help="Override models.yaml path")
        add("--max-records", type=int, default=0, help="Stop after this many valid requests (debug/testing)")
        add("--max-events", type=int, default=0, help="Stop after this many trace blocks (debug/testing)")

    analysis_commands = (
        ("run", "Run hit-rate analysis over configured KV cache budgets"),
        ("sweep", "Alias for run; scans a set of KV cache memory budgets"),
    )
    for command_name, command_help in analysis_commands:
        _add_analysis_options(subparsers.add_parser(command_name, help=command_help))

    list_models = subparsers.add_parser("list-models", help="List supported model ids")
    list_models.add_argument("--models-yaml", default=None, help="Override models.yaml path")

    return parser
60
+
61
+
62
+ def _write_output(text: str, output: str) -> None:
63
+ if output == "-":
64
+ print(text)
65
+ return
66
+ Path(output).write_text(text + "\n", encoding="utf-8")
67
+
68
+
69
def run_sweep_command(args: argparse.Namespace) -> int:
    """Run the ``run``/``sweep`` subcommand and write the rendered report.

    Progress is enabled only when stderr is a TTY and ``--no-progress`` was
    not given. If anything fails, including Ctrl+C, ``progress.close()`` is
    called before the exception propagates.

    Returns:
        0 on success.
    """
    progress = ProgressBar(enabled=(not args.no_progress and sys.stderr.isatty()))
    try:
        data = load_models_data(args.models_yaml)
        progress.update(0, 4, "reading trace")
        trace = parse_trace_file(args.trace, block_size=args.block_size, max_records=args.max_records, max_events=args.max_events)
        progress.update(1, 4, "trace loaded")
        result = run_sweep(
            trace,
            model_id=args.model,
            precision=args.kv_precision,
            indexer_precision=args.indexer_precision,
            budgets_gib=_parse_csv_numbers(args.budgets_gib),
            policies=_parse_csv_strings(args.policies),
            jobs=args.jobs,
            backend=args.backend,
            progress=progress.update,
            estimate_tokens=args.estimate_tokens,
            include_draft_kv_cache=args.include_draft_kv_cache,
            models_data=data,
        )
        rendered = render_json(result) if args.format == "json" else render_table(result)
        _write_output(rendered, args.output)
        progress.finish()
        return 0
    except BaseException:
        # BaseException, not Exception, so progress.close() also runs when a
        # long sweep is interrupted (KeyboardInterrupt) or exits (SystemExit).
        progress.close()
        raise
97
+
98
+
99
def run_list_models(args: argparse.Namespace) -> int:
    """Print one tab-separated line per catalog model: id, label, family, formula.

    Models are ordered by family, then label; ties keep catalog order.
    Missing or null values are handled as follows:
    - a missing or null ``label`` prints the id;
    - a missing or null ``family``/``formula`` prints an empty string.
    Nulls sort as empty strings, so they never print as "None" or break
    sorting.

    Returns:
        0 on success.
    """
    data = load_models_data(args.models_yaml)
    models = models_by_id(data)

    def text(model: dict, key: str) -> str:
        # PyYAML yields None for bare keys; treat it like a missing value.
        value = model.get(key)
        return "" if value is None else str(value)

    def sort_key(model: dict) -> tuple[str, str]:
        return text(model, "family"), text(model, "label")

    for model in sorted(models.values(), key=sort_key):
        label = text(model, "label") or str(model["id"])
        print(f"{model['id']}\t{label}\t{text(model, 'family')}\t{text(model, 'formula')}")
    return 0
105
+
106
+
107
def main(argv: list[str] | None = None) -> int:
    """CLI entry point; return the process exit status.

    Dispatches to the handler for the chosen subcommand. When no subcommand
    was given, prints help to stderr and returns 2.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    handlers = {
        "run": run_sweep_command,
        "sweep": run_sweep_command,
        "list-models": run_list_models,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help(sys.stderr)
        return 2
    return handler(args)