haoline-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. haoline/.streamlit/config.toml +10 -0
  2. haoline/__init__.py +248 -0
  3. haoline/analyzer.py +935 -0
  4. haoline/cli.py +2712 -0
  5. haoline/compare.py +811 -0
  6. haoline/compare_visualizations.py +1564 -0
  7. haoline/edge_analysis.py +525 -0
  8. haoline/eval/__init__.py +131 -0
  9. haoline/eval/adapters.py +844 -0
  10. haoline/eval/cli.py +390 -0
  11. haoline/eval/comparison.py +542 -0
  12. haoline/eval/deployment.py +633 -0
  13. haoline/eval/schemas.py +833 -0
  14. haoline/examples/__init__.py +15 -0
  15. haoline/examples/basic_inspection.py +74 -0
  16. haoline/examples/compare_models.py +117 -0
  17. haoline/examples/hardware_estimation.py +78 -0
  18. haoline/format_adapters.py +1001 -0
  19. haoline/formats/__init__.py +123 -0
  20. haoline/formats/coreml.py +250 -0
  21. haoline/formats/gguf.py +483 -0
  22. haoline/formats/openvino.py +255 -0
  23. haoline/formats/safetensors.py +273 -0
  24. haoline/formats/tflite.py +369 -0
  25. haoline/hardware.py +2307 -0
  26. haoline/hierarchical_graph.py +462 -0
  27. haoline/html_export.py +1573 -0
  28. haoline/layer_summary.py +769 -0
  29. haoline/llm_summarizer.py +465 -0
  30. haoline/op_icons.py +618 -0
  31. haoline/operational_profiling.py +1492 -0
  32. haoline/patterns.py +1116 -0
  33. haoline/pdf_generator.py +265 -0
  34. haoline/privacy.py +250 -0
  35. haoline/pydantic_models.py +241 -0
  36. haoline/report.py +1923 -0
  37. haoline/report_sections.py +539 -0
  38. haoline/risks.py +521 -0
  39. haoline/schema.py +523 -0
  40. haoline/streamlit_app.py +2024 -0
  41. haoline/tests/__init__.py +4 -0
  42. haoline/tests/conftest.py +123 -0
  43. haoline/tests/test_analyzer.py +868 -0
  44. haoline/tests/test_compare_visualizations.py +293 -0
  45. haoline/tests/test_edge_analysis.py +243 -0
  46. haoline/tests/test_eval.py +604 -0
  47. haoline/tests/test_format_adapters.py +460 -0
  48. haoline/tests/test_hardware.py +237 -0
  49. haoline/tests/test_hardware_recommender.py +90 -0
  50. haoline/tests/test_hierarchical_graph.py +326 -0
  51. haoline/tests/test_html_export.py +180 -0
  52. haoline/tests/test_layer_summary.py +428 -0
  53. haoline/tests/test_llm_patterns.py +540 -0
  54. haoline/tests/test_llm_summarizer.py +339 -0
  55. haoline/tests/test_patterns.py +774 -0
  56. haoline/tests/test_pytorch.py +327 -0
  57. haoline/tests/test_report.py +383 -0
  58. haoline/tests/test_risks.py +398 -0
  59. haoline/tests/test_schema.py +417 -0
  60. haoline/tests/test_tensorflow.py +380 -0
  61. haoline/tests/test_visualizations.py +316 -0
  62. haoline/universal_ir.py +856 -0
  63. haoline/visualizations.py +1086 -0
  64. haoline/visualize_yolo.py +44 -0
  65. haoline/web.py +110 -0
  66. haoline-0.3.0.dist-info/METADATA +471 -0
  67. haoline-0.3.0.dist-info/RECORD +70 -0
  68. haoline-0.3.0.dist-info/WHEEL +4 -0
  69. haoline-0.3.0.dist-info/entry_points.txt +5 -0
  70. haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
haoline/compare.py ADDED
@@ -0,0 +1,811 @@
+#!/usr/bin/env python
+# Copyright (c) 2025 HaoLine Contributors
+# SPDX-License-Identifier: MIT
+
+"""
+HaoLine Compare CLI - Multi-model comparison.
+
+This CLI takes multiple model variants plus corresponding eval/perf
+metrics JSON files and produces a comparison JSON + Markdown
+summary focused on quantization impact / multi-variant trade-offs.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+from collections.abc import Sequence
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from . import ModelInspector
+from .compare_visualizations import (
+    build_enhanced_markdown,
+    compute_tradeoff_points,
+    generate_compare_html,
+    generate_compare_pdf,
+    generate_memory_savings_chart,
+    generate_tradeoff_chart,
+)
+from .compare_visualizations import (
+    is_available as viz_available,
+)
+
+LOGGER = logging.getLogger("haoline.compare")
+
+
+@dataclass
+class VariantInputs:
+    """Inputs required to build a single variant entry in the compare report."""
+
+    model_path: Path
+    eval_metrics_path: Path
+    precision: str | None = None
+
+
+@dataclass
+class VariantReport:
+    """Bundle of inspection + eval metrics for a variant."""
+
+    model_path: Path
+    precision: str
+    quantization_scheme: str
+    size_bytes: int
+    report: Any  # InspectionReport
+    metrics: dict[str, Any]
+
+
+@dataclass
+class ArchCompatResult:
+    """Result of architecture compatibility check between model variants."""
+
+    is_compatible: bool
+    warnings: list[str]
+    details: dict[str, Any]
+
+
+def _parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        prog="model_inspect_compare",
+        description=(
+            "Compare multiple model variants (e.g., fp32/fp16/int8) and "
+            "summarize quantization / architecture trade-offs."
+        ),
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""\
+Examples:
+  # Basic quantization impact comparison
+  python -m onnxruntime.tools.model_inspect_compare \\
+      --models resnet_fp32.onnx resnet_fp16.onnx resnet_int8.onnx \\
+      --eval-metrics eval_fp32.json eval_fp16.json eval_int8.json \\
+      --baseline-precision fp32 \\
+      --out-json quant_impact.json \\
+      --out-md quant_impact.md
+""",
+    )
+
+    parser.add_argument(
+        "--models",
+        type=Path,
+        nargs="+",
+        required=True,
+        help="List of ONNX model paths to compare (e.g., fp32/fp16/int8 variants).",
+    )
+    parser.add_argument(
+        "--eval-metrics",
+        type=Path,
+        nargs="+",
+        required=True,
+        help=(
+            "List of eval/perf metrics JSON files, one per model, "
+            "produced by a batch-eval or perf script."
+        ),
+    )
+    parser.add_argument(
+        "--precisions",
+        type=str,
+        nargs="+",
+        default=None,
+        help=(
+            "Optional list of precision labels for each model "
+            "(e.g., fp32 fp16 int8). If omitted, precisions are inferred "
+            "from filenames where possible."
+        ),
+    )
+    parser.add_argument(
+        "--baseline-precision",
+        type=str,
+        default=None,
+        help=(
+            "Precision label to use as baseline for delta computation "
+            "(e.g., fp32). If omitted, the first model is the baseline."
+        ),
+    )
+    parser.add_argument(
+        "--out-json",
+        type=Path,
+        default=None,
+        help="Output path for comparison JSON report.",
+    )
+    parser.add_argument(
+        "--out-md",
+        type=Path,
+        default=None,
+        help="Output path for human-readable Markdown summary.",
+    )
+
+    parser.add_argument(
+        "--quiet",
+        action="store_true",
+        help="Reduce logging noise; only errors are printed.",
+    )
+    parser.add_argument(
+        "--with-charts",
+        action="store_true",
+        help="Generate accuracy vs speedup and memory savings charts (requires matplotlib).",
+    )
+    parser.add_argument(
+        "--assets-dir",
+        type=Path,
+        default=None,
+        help="Directory for chart assets. Defaults to same directory as --out-md.",
+    )
+    parser.add_argument(
+        "--out-html",
+        type=Path,
+        default=None,
+        help="Output path for HTML comparison report with engine summary panel.",
+    )
+    parser.add_argument(
+        "--out-pdf",
+        type=Path,
+        default=None,
+        help="Output path for PDF comparison report (requires Playwright).",
+    )
+
+    args = parser.parse_args(argv)
+
+    if len(args.models) != len(args.eval_metrics):
+        parser.error(
+            f"--models ({len(args.models)}) and --eval-metrics "
+            f"({len(args.eval_metrics)}) must have the same length."
+        )
+
+    if args.precisions is not None and len(args.precisions) != len(args.models):
+        parser.error(
+            f"--precisions length ({len(args.precisions)}) must match "
+            f"--models length ({len(args.models)})."
+        )
+
+    return args
+
+
+def _infer_precision_from_name(path: Path) -> str | None:
+    """Heuristic to infer precision from filename."""
+    name = path.stem.lower()
+    if "int8" in name or "qdq" in name or "int_8" in name:
+        return "int8"
+    if "int4" in name or "int_4" in name:
+        return "int4"
+    if "fp16" in name or "half" in name:
+        return "fp16"
+    if "bf16" in name:
+        return "bf16"
+    if "fp32" in name or "float32" in name:
+        return "fp32"
+    return None
+
+
+def _load_eval_metrics(path: Path) -> dict[str, Any]:
+    """
+    Load eval/perf metrics JSON.
+
+    We accept two common layouts:
+      1) Root-level metrics dict:
+         {"f1_macro": 0.93, "latency_ms_p50": 14.5, ...}
+      2) Wrapped metrics:
+         {"metrics": {...}}
+    """
+    with path.open("r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    if isinstance(data, dict) and "metrics" in data and isinstance(data["metrics"], dict):
+        return data["metrics"]
+
+    if isinstance(data, dict):
+        return data
+
+    raise ValueError(f"Unsupported metrics JSON structure in {path}")
+
+
+def _build_variant_inputs(args: argparse.Namespace) -> list[VariantInputs]:
+    variants: list[VariantInputs] = []
+    for idx, (model_path, metrics_path) in enumerate(
+        zip(args.models, args.eval_metrics, strict=True)
+    ):
+        if not model_path.is_file():
+            raise FileNotFoundError(f"Model not found: {model_path}")
+        if not metrics_path.is_file():
+            raise FileNotFoundError(f"Eval metrics JSON not found: {metrics_path}")
+
+        precision: str | None = None
+        if args.precisions is not None:
+            precision = args.precisions[idx]
+        else:
+            precision = _infer_precision_from_name(model_path)
+
+        variants.append(
+            VariantInputs(
+                model_path=model_path,
+                eval_metrics_path=metrics_path,
+                precision=precision,
+            )
+        )
+    return variants
+
+
+def _run_inspection(model_path: Path, logger: logging.Logger) -> Any:
+    """
+    Run ModelInspector for a single model and return an InspectionReport.
+
+    We deliberately reuse the same inspector class used by model_inspect.py.
+    """
+    inspector = ModelInspector(logger=logger)
+    report = inspector.inspect(str(model_path))
+    return report
+
+
+def _determine_baseline_index(
+    variants: Sequence[VariantReport], baseline_precision: str | None
+) -> int:
+    if baseline_precision is None:
+        return 0
+
+    for idx, v in enumerate(variants):
+        if v.precision.lower() == baseline_precision.lower():
+            return idx
+
+    # Fallback to first if requested precision not found
+    LOGGER.warning(
+        "Requested baseline precision '%s' not found among variants; "
+        "using first variant as baseline.",
+        baseline_precision,
+    )
+    return 0
+
+
+def _check_architecture_compatibility(
+    variants: Sequence[VariantReport],
+) -> ArchCompatResult:
+    """
+    Check if all model variants share the same architecture.
+
+    Task 6.3.3: Verify architecture compatibility
+    - Same architecture_type (cnn, transformer, mlp, etc.)
+    - Compatible input/output shapes
+    - Similar detected block patterns
+    """
+    warnings: list[str] = []
+    details: dict[str, Any] = {}
+
+    if len(variants) < 2:
+        return ArchCompatResult(is_compatible=True, warnings=[], details={})
+
+    baseline = variants[0]
+    baseline_report = baseline.report
+
+    # Extract baseline characteristics
+    baseline_arch_type = getattr(baseline_report, "architecture_type", "unknown")
+    baseline_graph = getattr(baseline_report, "graph_summary", None)
+    baseline_blocks = getattr(baseline_report, "detected_blocks", [])
+
+    # Get baseline I/O shapes
+    baseline_inputs: dict[str, Any] = {}
+    baseline_outputs: dict[str, Any] = {}
+    if baseline_graph:
+        baseline_inputs = getattr(baseline_graph, "input_shapes", {}) or {}
+        baseline_outputs = getattr(baseline_graph, "output_shapes", {}) or {}
+
+    details["baseline_architecture"] = baseline_arch_type
+    details["baseline_inputs"] = baseline_inputs
+    details["baseline_outputs"] = baseline_outputs
+
+    # Count block types in baseline
+    baseline_block_counts: dict[str, int] = {}
+    for block in baseline_blocks:
+        block_type = getattr(block, "block_type", "unknown")
+        baseline_block_counts[block_type] = baseline_block_counts.get(block_type, 0) + 1
+
+    details["baseline_block_counts"] = baseline_block_counts
+
+    # Check each variant against baseline
+    is_compatible = True
+    for idx, variant in enumerate(variants[1:], start=1):
+        v_report = variant.report
+        v_arch_type = getattr(v_report, "architecture_type", "unknown")
+        v_graph = getattr(v_report, "graph_summary", None)
+        v_blocks = getattr(v_report, "detected_blocks", [])
+
+        # Check architecture type
+        if v_arch_type != baseline_arch_type:
+            warnings.append(
+                f"Variant {idx} ({variant.precision}): architecture_type mismatch "
+                f"({v_arch_type} vs baseline {baseline_arch_type})"
+            )
+            is_compatible = False
+
+        # Check input shapes (flexible: same keys, same ranks)
+        if v_graph:
+            v_inputs = getattr(v_graph, "input_shapes", {}) or {}
+            v_outputs = getattr(v_graph, "output_shapes", {}) or {}
+
+            # Compare input count and ranks
+            if len(v_inputs) != len(baseline_inputs):
+                warnings.append(
+                    f"Variant {idx} ({variant.precision}): different number of inputs "
+                    f"({len(v_inputs)} vs baseline {len(baseline_inputs)})"
+                )
+
+            # Compare output count
+            if len(v_outputs) != len(baseline_outputs):
+                warnings.append(
+                    f"Variant {idx} ({variant.precision}): different number of outputs "
+                    f"({len(v_outputs)} vs baseline {len(baseline_outputs)})"
+                )
+
+        # Check block pattern counts (some variation is OK for quantization)
+        v_block_counts: dict[str, int] = {}
+        for block in v_blocks:
+            block_type = getattr(block, "block_type", "unknown")
+            v_block_counts[block_type] = v_block_counts.get(block_type, 0) + 1
+
+        # Major block type differences are a warning
+        for block_type, count in baseline_block_counts.items():
+            v_count = v_block_counts.get(block_type, 0)
+            # Allow some variation but flag major differences
+            if abs(v_count - count) > max(count * 0.2, 2):
+                warnings.append(
+                    f"Variant {idx} ({variant.precision}): block count mismatch for "
+                    f"{block_type} ({v_count} vs baseline {count})"
+                )
+
+    return ArchCompatResult(
+        is_compatible=is_compatible,
+        warnings=warnings,
+        details=details,
+    )
+
+
+def _get_numeric_metric(report: Any, *path: str) -> float | None:
+    """Safely extract a numeric metric from a nested report structure."""
+    obj = report
+    for key in path:
+        if obj is None:
+            return None
+        if hasattr(obj, key):
+            obj = getattr(obj, key)
+        elif isinstance(obj, dict) and key in obj:
+            obj = obj[key]
+        else:
+            return None
+    if isinstance(obj, (int, float)):
+        return float(obj)
+    return None
+
+
+def _compute_deltas(baseline: VariantReport, other: VariantReport) -> dict[str, Any]:
+    """
+    Compute deltas between baseline and another variant.
+
+    Task 6.3.4: Compute comprehensive deltas including:
+    - File size
+    - Parameter counts
+    - FLOPs
+    - Memory estimates
+    - Hardware estimates (if available)
+    - Eval/perf metrics
+    """
+    deltas: dict[str, Any] = {}
+
+    # Model file size
+    deltas["size_bytes"] = other.size_bytes - baseline.size_bytes
+
+    # Parameter counts
+    base_params = _get_numeric_metric(baseline.report, "param_counts", "total")
+    other_params = _get_numeric_metric(other.report, "param_counts", "total")
+    if base_params is not None and other_params is not None:
+        deltas["total_params"] = int(other_params - base_params)
+
+    # FLOPs
+    base_flops = _get_numeric_metric(baseline.report, "flop_counts", "total")
+    other_flops = _get_numeric_metric(other.report, "flop_counts", "total")
+    if base_flops is not None and other_flops is not None:
+        deltas["total_flops"] = int(other_flops - base_flops)
+
+    # Memory estimates (model size)
+    base_mem = _get_numeric_metric(baseline.report, "memory_estimates", "model_size_bytes")
+    other_mem = _get_numeric_metric(other.report, "memory_estimates", "model_size_bytes")
+    if base_mem is not None and other_mem is not None:
+        deltas["memory_bytes"] = int(other_mem - base_mem)
+
+    # Peak activation memory
+    base_peak = _get_numeric_metric(baseline.report, "memory_estimates", "peak_activation_bytes")
+    other_peak = _get_numeric_metric(other.report, "memory_estimates", "peak_activation_bytes")
+    if base_peak is not None and other_peak is not None:
+        deltas["peak_activation_bytes"] = int(other_peak - base_peak)
+
+    # Hardware estimates (if available)
+    base_hw = getattr(baseline.report, "hardware_estimates", None)
+    other_hw = getattr(other.report, "hardware_estimates", None)
+    if base_hw is not None and other_hw is not None:
+        # Latency
+        base_lat = _get_numeric_metric(base_hw, "estimated_latency_ms")
+        other_lat = _get_numeric_metric(other_hw, "estimated_latency_ms")
+        if base_lat is not None and other_lat is not None:
+            deltas["latency_ms"] = other_lat - base_lat
+
+        # VRAM
+        base_vram = _get_numeric_metric(base_hw, "vram_required_bytes")
+        other_vram = _get_numeric_metric(other_hw, "vram_required_bytes")
+        if base_vram is not None and other_vram is not None:
+            deltas["vram_required_bytes"] = int(other_vram - base_vram)
+
+        # Compute utilization
+        base_util = _get_numeric_metric(base_hw, "compute_utilization")
+        other_util = _get_numeric_metric(other_hw, "compute_utilization")
+        if base_util is not None and other_util is not None:
+            deltas["compute_utilization"] = other_util - base_util
+
+    # Metric-wise deltas from eval/perf JSON (only for overlapping numeric fields)
+    for key, base_val in baseline.metrics.items():
+        other_val = other.metrics.get(key)
+        if isinstance(base_val, (int, float)) and isinstance(other_val, (int, float)):
+            # Don't overwrite structural deltas we already computed
+            if key not in deltas:
+                deltas[key] = other_val - base_val
+
+    return deltas
+
+
+def _build_compare_json(
+    variants: Sequence[VariantReport],
+    baseline_index: int,
+    arch_compat: ArchCompatResult,
+) -> dict[str, Any]:
+    baseline = variants[baseline_index]
+
+    # Derive a simple model_family_id from baseline metadata
+    report = baseline.report
+    metadata = getattr(report, "metadata", None)
+    if metadata and getattr(metadata, "name", None):
+        model_family_id = metadata.name
+    else:
+        model_family_id = baseline.model_path.stem
+
+    out: dict[str, Any] = {
+        "model_family_id": model_family_id,
+        "baseline_precision": baseline.precision,
+        "architecture_compatible": arch_compat.is_compatible,
+        "compatibility_warnings": arch_compat.warnings,
+        "variants": [],
+    }
+
+    for idx, v in enumerate(variants):
+        deltas_vs_baseline: dict[str, Any] | None
+        if idx == baseline_index:
+            deltas_vs_baseline = None
+        else:
+            deltas_vs_baseline = _compute_deltas(baseline, v)
+
+        hw_estimates = getattr(v.report, "hardware_estimates", None)
+        hw_profile = getattr(v.report, "hardware_profile", None)
+
+        # Extract key metrics from inspection report for the variant summary
+        param_counts = getattr(v.report, "param_counts", None)
+        flop_counts = getattr(v.report, "flop_counts", None)
+        memory_estimates = getattr(v.report, "memory_estimates", None)
+
+        out_variant: dict[str, Any] = {
+            "precision": v.precision,
+            "quantization_scheme": v.quantization_scheme,
+            "model_path": str(v.model_path),
+            "size_bytes": int(v.size_bytes),
+            # Structural metrics from inspection
+            "total_params": (param_counts.total if param_counts is not None else None),
+            "total_flops": (flop_counts.total if flop_counts is not None else None),
+            "memory_bytes": (
+                memory_estimates.model_size_bytes if memory_estimates is not None else None
+            ),
+            # Eval/perf metrics from JSON
+            "metrics": v.metrics,
+            "hardware_estimates": (
+                hw_estimates.to_dict()
+                if hw_estimates is not None and hasattr(hw_estimates, "to_dict")
+                else None
+            ),
+            "hardware_profile": (
+                hw_profile.to_dict()
+                if hw_profile is not None and hasattr(hw_profile, "to_dict")
+                else None
+            ),
+            "deltas_vs_baseline": deltas_vs_baseline,
+        }
+        out["variants"].append(out_variant)
+
+    return out
+
+
+def _format_number(n: float | None, suffix: str = "") -> str:
+    """Format a number with K/M/G suffixes for readability."""
+    if n is None:
+        return "-"
+    if abs(n) >= 1e9:
+        return f"{n / 1e9:.2f}G{suffix}"
+    if abs(n) >= 1e6:
+        return f"{n / 1e6:.2f}M{suffix}"
+    if abs(n) >= 1e3:
+        return f"{n / 1e3:.2f}K{suffix}"
+    if isinstance(n, float) and not n.is_integer():
+        return f"{n:.2f}{suffix}"
+    return f"{int(n)}{suffix}"
+
+
+def _format_delta(val: float | None, suffix: str = "") -> str:
+    """Format a delta with +/- prefix and K/M/G suffixes."""
+    if val is None:
+        return "-"
+    sign = "+" if val >= 0 else ""
+    if abs(val) >= 1e9:
+        return f"{sign}{val / 1e9:.2f}G{suffix}"
+    if abs(val) >= 1e6:
+        return f"{sign}{val / 1e6:.2f}M{suffix}"
+    if abs(val) >= 1e3:
+        return f"{sign}{val / 1e3:.2f}K{suffix}"
+    if isinstance(val, float) and not val.is_integer():
+        return f"{sign}{val:.2f}{suffix}"
+    return f"{sign}{int(val)}{suffix}"
+
+
+def _build_markdown_summary(
+    compare_json: dict[str, Any],
+) -> str:
+    """Generate a Markdown summary for compare mode with rich metrics."""
+    lines: list[str] = []
+
+    model_family_id = compare_json.get("model_family_id", "unknown_model")
+    baseline_precision = compare_json.get("baseline_precision", "unknown")
+    arch_compatible = compare_json.get("architecture_compatible", True)
+    warnings = compare_json.get("compatibility_warnings", [])
+
+    lines.append(f"# Quantization Impact: {model_family_id}")
+    lines.append("")
+    lines.append(
+        f"Baseline precision: **{baseline_precision}** (deltas are relative to this variant)."
+    )
+    lines.append("")
+
+    # Architecture compatibility notice
+    if not arch_compatible:
+        lines.append("## Compatibility Warnings")
+        lines.append("")
+        lines.append(
+            "> **Warning**: Model variants may not be directly comparable due to "
+            "architecture differences."
+        )
+        lines.append("")
+        for warn in warnings:
+            lines.append(f"- {warn}")
+        lines.append("")
+
+    # Summary table with comprehensive metrics
+    lines.append("## Variant Comparison")
+    lines.append("")
+    lines.append("| Precision | Size | Params | FLOPs | Δ Size | Δ Params | Δ FLOPs |")
+    lines.append("|-----------|------|--------|-------|--------|----------|---------|")
+
+    for v in compare_json.get("variants", []):
+        precision = v.get("precision", "unknown")
+        size_bytes = v.get("size_bytes", 0)
+        total_params = v.get("total_params")
+        total_flops = v.get("total_flops")
+        deltas = v.get("deltas_vs_baseline")
+
+        # Format absolute values
+        size_str = _format_number(size_bytes, "B")
+        params_str = _format_number(total_params)
+        flops_str = _format_number(total_flops)
+
+        # Format deltas
+        if deltas is None:
+            delta_size = "-"
+            delta_params = "-"
+            delta_flops = "-"
+        else:
+            delta_size = _format_delta(deltas.get("size_bytes"), "B")
+            delta_params = _format_delta(deltas.get("total_params"))
+            delta_flops = _format_delta(deltas.get("total_flops"))
+
+        lines.append(
+            f"| {precision} | {size_str} | {params_str} | {flops_str} | "
+            f"{delta_size} | {delta_params} | {delta_flops} |"
+        )
+
+    lines.append("")
+
+    # Performance metrics table (if available in eval metrics)
+    has_perf_metrics = False
+    for v in compare_json.get("variants", []):
+        metrics = v.get("metrics", {})
+        if any(k in metrics for k in ["latency_ms_p50", "throughput_qps", "f1_macro", "accuracy"]):
+            has_perf_metrics = True
+            break
+
+    if has_perf_metrics:
+        lines.append("## Performance Metrics")
+        lines.append("")
+        lines.append(
+            "| Precision | Latency (ms) | Throughput | Accuracy | Δ Latency | Δ Accuracy |"
+        )
+        lines.append(
+            "|-----------|--------------|------------|----------|-----------|------------|"
+        )
+
+        for v in compare_json.get("variants", []):
+            precision = v.get("precision", "unknown")
+            metrics = v.get("metrics", {})
+            deltas = v.get("deltas_vs_baseline")
+
+            latency = metrics.get("latency_ms_p50") or metrics.get("latency_ms")
+            throughput = metrics.get("throughput_qps") or metrics.get("throughput")
+            accuracy = metrics.get("f1_macro") or metrics.get("accuracy")
+
+            lat_str = f"{latency:.2f}" if latency is not None else "-"
+            tput_str = _format_number(throughput, " qps") if throughput is not None else "-"
+            acc_str = f"{accuracy:.4f}" if accuracy is not None else "-"
+
+            if deltas is None:
+                delta_lat = "-"
+                delta_acc = "-"
+            else:
+                d_lat = deltas.get("latency_ms_p50") or deltas.get("latency_ms")
+                d_acc = deltas.get("f1_macro") or deltas.get("accuracy")
+                delta_lat = f"{d_lat:+.2f}ms" if d_lat is not None else "-"
+                delta_acc = f"{d_acc:+.4f}" if d_acc is not None else "-"
+
+            lines.append(
+                f"| {precision} | {lat_str} | {tput_str} | {acc_str} | {delta_lat} | {delta_acc} |"
+            )
+
+        lines.append("")
+
+    lines.append("> Full details including hardware estimates are in the JSON report.")
+    lines.append("")
+
+    return "\n".join(lines)
+
+
+def _build_variants(
+    variant_inputs: Sequence[VariantInputs],
+    logger: logging.Logger,
+) -> list[VariantReport]:
+    variants: list[VariantReport] = []
+
+    for v in variant_inputs:
+        metrics = _load_eval_metrics(v.eval_metrics_path)
+        report = _run_inspection(v.model_path, logger=logger)
+        size_bytes = v.model_path.stat().st_size
+
+        precision = v.precision or "unknown"
+        quant_scheme = precision if precision != "fp32" else "none"
+
+        variants.append(
+            VariantReport(
+                model_path=v.model_path,
+                precision=precision,
+                quantization_scheme=quant_scheme,
+                size_bytes=size_bytes,
+                report=report,
+                metrics=metrics,
+            )
+        )
+
+    return variants
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+    args = _parse_args(argv)
+
+    logging.basicConfig(
+        level=logging.ERROR if args.quiet else logging.INFO,
+        format="%(levelname)s: %(message)s",
+    )
+
+    logger = LOGGER
+
+    try:
+        variant_inputs = _build_variant_inputs(args)
+        variants = _build_variants(variant_inputs, logger=logger)
+
+        baseline_index = _determine_baseline_index(variants, args.baseline_precision)
+
+        # Task 6.3.3: Check architecture compatibility
+        arch_compat = _check_architecture_compatibility(variants)
+        if not arch_compat.is_compatible:
+            logger.warning(
+                "Model variants have architecture differences; comparison may not be meaningful."
+            )
+            for warn in arch_compat.warnings:
+                logger.warning(" - %s", warn)
+
+        compare_json = _build_compare_json(variants, baseline_index, arch_compat)
+
+        # Write JSON if requested
+        if args.out_json:
+            args.out_json.parent.mkdir(parents=True, exist_ok=True)
+            with args.out_json.open("w", encoding="utf-8") as f:
+                json.dump(compare_json, f, indent=2)
+            logger.info("Comparison JSON written to %s", args.out_json)
+
+        # Write Markdown if requested
+        if args.out_md:
+            assets_dir = args.assets_dir or args.out_md.parent
+            if args.with_charts:
+                md = build_enhanced_markdown(
+                    compare_json,
+                    include_charts=True,
+                    assets_dir=assets_dir,
+                )
+            else:
+                md = _build_markdown_summary(compare_json)
+            args.out_md.parent.mkdir(parents=True, exist_ok=True)
+            args.out_md.write_text(md, encoding="utf-8")
+            logger.info("Comparison Markdown written to %s", args.out_md)
+
+            # Generate standalone charts if requested
+            if args.with_charts and viz_available():
+                points = compute_tradeoff_points(compare_json)
+                if points:
+                    chart_path = assets_dir / "tradeoff_chart.png"
+                    generate_tradeoff_chart(points, chart_path)
+                    logger.info("Tradeoff chart written to %s", chart_path)
+
+                mem_path = assets_dir / "memory_savings.png"
+                generate_memory_savings_chart(compare_json, mem_path)
+                logger.info("Memory savings chart written to %s", mem_path)
+
+        # Write HTML if requested
+        if args.out_html:
+            generate_compare_html(
+                compare_json,
+                output_path=args.out_html,
+                include_charts=True,
+            )
+            logger.info("Comparison HTML written to %s", args.out_html)
+
+        # Write PDF if requested (Task 6.10.9)
+        if args.out_pdf:
+            pdf_path = generate_compare_pdf(
+                compare_json,
+                output_path=args.out_pdf,
+                include_charts=True,
+            )
+            if pdf_path:
+                logger.info("Comparison PDF written to %s", pdf_path)
+            else:
+                logger.warning("PDF generation failed (Playwright may not be installed)")
+
+        if not args.out_json and not args.out_md and not args.out_html and not args.out_pdf:
+            # Default to printing JSON to stdout if no outputs specified
+            print(json.dumps(compare_json, indent=2))
+
+        return 0
+
+    except Exception as exc:  # pragma: no cover - top-level safety
+        logger.error("Compare-mode failed: %s", exc)
+        return 1
+
+
+if __name__ == "__main__":  # pragma: no cover
+    raise SystemExit(main())
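
A minimal usage sketch of the new compare module, assuming the wheel is installed and importable as haoline. Only the argparse flags and the two metrics-JSON layouts documented in _load_eval_metrics above are relied on; the model filenames and metric values are placeholders, and the ONNX files must actually exist for the run to succeed.

# Illustrative only: model paths and metric values are placeholders.
import json
from pathlib import Path

from haoline.compare import main

# Layout 1 accepted by _load_eval_metrics: root-level metrics dict.
Path("eval_fp32.json").write_text(
    json.dumps({"f1_macro": 0.93, "latency_ms_p50": 14.5}), encoding="utf-8"
)
# Layout 2: metrics wrapped under a "metrics" key.
Path("eval_int8.json").write_text(
    json.dumps({"metrics": {"f1_macro": 0.92, "latency_ms_p50": 6.1}}), encoding="utf-8"
)

# Programmatic equivalent of the CLI invocation; --precisions could be
# omitted here, since "fp32"/"int8" would be inferred from the filenames.
exit_code = main([
    "--models", "resnet_fp32.onnx", "resnet_int8.onnx",  # placeholder model files
    "--eval-metrics", "eval_fp32.json", "eval_int8.json",
    "--precisions", "fp32", "int8",
    "--baseline-precision", "fp32",
    "--out-json", "quant_impact.json",
    "--out-md", "quant_impact.md",
])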
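
For reference, the JSON written via --out-json follows the structure assembled in _build_compare_json: top-level family/baseline/compatibility fields plus one entry per variant, where the baseline's deltas_vs_baseline is null and every other variant carries deltas relative to it. The sketch below uses made-up numbers and trims some per-variant keys for brevity.

# Shape of the comparison JSON (values are illustrative, not real measurements).
{
    "model_family_id": "resnet",
    "baseline_precision": "fp32",
    "architecture_compatible": True,
    "compatibility_warnings": [],
    "variants": [
        {
            "precision": "fp32",
            "quantization_scheme": "none",
            "model_path": "resnet_fp32.onnx",
            "size_bytes": 102400000,
            "total_params": 25600000,
            "total_flops": 4100000000,
            "metrics": {"f1_macro": 0.93, "latency_ms_p50": 14.5},
            "deltas_vs_baseline": None,  # baseline variant has no deltas
            # memory_bytes, hardware_estimates, hardware_profile omitted here
        },
        {
            "precision": "int8",
            "quantization_scheme": "int8",
            "model_path": "resnet_int8.onnx",
            "size_bytes": 25700000,
            "metrics": {"f1_macro": 0.92, "latency_ms_p50": 6.1},
            "deltas_vs_baseline": {
                "size_bytes": -76700000,
                "f1_macro": -0.01,
                "latency_ms_p50": -8.4,
            },
        },
    ],
}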