invarlock-0.3.6-py3-none-any.whl → invarlock-0.3.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. invarlock/__init__.py +2 -2
  2. invarlock/adapters/__init__.py +10 -14
  3. invarlock/adapters/auto.py +35 -40
  4. invarlock/adapters/capabilities.py +2 -2
  5. invarlock/adapters/hf_causal.py +418 -0
  6. invarlock/adapters/{hf_onnx.py → hf_causal_onnx.py} +3 -3
  7. invarlock/adapters/hf_mixin.py +25 -4
  8. invarlock/adapters/{hf_bert.py → hf_mlm.py} +4 -11
  9. invarlock/adapters/{hf_t5.py → hf_seq2seq.py} +9 -9
  10. invarlock/cli/adapter_auto.py +31 -21
  11. invarlock/cli/app.py +73 -2
  12. invarlock/cli/commands/certify.py +600 -59
  13. invarlock/cli/commands/doctor.py +8 -10
  14. invarlock/cli/commands/plugins.py +13 -9
  15. invarlock/cli/commands/report.py +233 -69
  16. invarlock/cli/commands/run.py +907 -183
  17. invarlock/cli/commands/verify.py +76 -11
  18. invarlock/cli/config.py +1 -1
  19. invarlock/cli/doctor_helpers.py +4 -5
  20. invarlock/cli/output.py +193 -0
  21. invarlock/cli/provenance.py +1 -1
  22. invarlock/core/bootstrap.py +1 -1
  23. invarlock/core/registry.py +9 -11
  24. invarlock/core/runner.py +111 -25
  25. invarlock/edits/quant_rtn.py +65 -37
  26. invarlock/eval/bench.py +3 -3
  27. invarlock/eval/data.py +68 -23
  28. invarlock/eval/metrics.py +59 -1
  29. invarlock/eval/tasks/__init__.py +12 -0
  30. invarlock/eval/tasks/classification.py +48 -0
  31. invarlock/eval/tasks/qa.py +36 -0
  32. invarlock/eval/tasks/text_generation.py +102 -0
  33. invarlock/guards/invariants.py +19 -10
  34. invarlock/guards/rmt.py +2 -2
  35. invarlock/guards/variance.py +2 -2
  36. invarlock/model_profile.py +48 -27
  37. invarlock/observability/health.py +6 -6
  38. invarlock/observability/metrics.py +108 -0
  39. invarlock/reporting/certificate.py +159 -9
  40. invarlock/reporting/certificate_schema.py +1 -1
  41. invarlock/reporting/guards_analysis.py +154 -4
  42. invarlock/reporting/html.py +55 -5
  43. invarlock/reporting/normalizer.py +7 -0
  44. invarlock/reporting/render.py +791 -431
  45. invarlock/reporting/report.py +39 -3
  46. invarlock/reporting/report_types.py +6 -1
  47. invarlock/reporting/telemetry.py +86 -0
  48. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/METADATA +23 -9
  49. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/RECORD +53 -48
  50. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/WHEEL +1 -1
  51. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/entry_points.txt +5 -3
  52. invarlock/adapters/hf_gpt2.py +0 -404
  53. invarlock/adapters/hf_llama.py +0 -487
  54. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/licenses/LICENSE +0 -0
  55. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/top_level.txt +0 -0
invarlock/eval/tasks/text_generation.py ADDED
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+from collections import Counter
+from collections.abc import Iterable
+from typing import Any
+
+
+def _tokenize(text: str) -> list[str]:
+    return [tok for tok in str(text).strip().lower().split() if tok]
+
+
+def _bleu1(pred: str, ref: str) -> float:
+    pred_tokens = _tokenize(pred)
+    ref_tokens = _tokenize(ref)
+    if not pred_tokens or not ref_tokens:
+        return 0.0
+    pred_counts = Counter(pred_tokens)
+    ref_counts = Counter(ref_tokens)
+    overlap = sum(min(pred_counts[tok], ref_counts.get(tok, 0)) for tok in pred_counts)
+    precision = overlap / float(len(pred_tokens))
+    bp = 1.0
+    if len(pred_tokens) < len(ref_tokens):
+        bp = pow(2.718281828, 1.0 - (len(ref_tokens) / float(len(pred_tokens))))
+    return float(precision * bp)
+
+
+def bleu1_from_records(records: Iterable[dict[str, Any]]) -> float:
+    """Compute BLEU-1 from records with predictions and references."""
+    scores: list[float] = []
+    for record in records:
+        if not isinstance(record, dict):
+            continue
+        pred = record.get("prediction")
+        refs = record.get("references")
+        if pred is None:
+            continue
+        if refs is None and "reference" in record:
+            refs = [record.get("reference")]
+        if refs is None:
+            continue
+        ref_list = refs if isinstance(refs, list) else [refs]
+        best = 0.0
+        for ref in ref_list:
+            if ref is None:
+                continue
+            best = max(best, _bleu1(str(pred), str(ref)))
+        scores.append(best)
+    if not scores:
+        return float("nan")
+    return float(sum(scores) / float(len(scores)))
+
+
+def _lcs_len(a: list[str], b: list[str]) -> int:
+    if not a or not b:
+        return 0
+    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
+    for i, tok_a in enumerate(a, start=1):
+        for j, tok_b in enumerate(b, start=1):
+            if tok_a == tok_b:
+                dp[i][j] = dp[i - 1][j - 1] + 1
+            else:
+                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
+    return dp[-1][-1]
+
+
+def _rouge_l(pred: str, ref: str) -> float:
+    pred_tokens = _tokenize(pred)
+    ref_tokens = _tokenize(ref)
+    if not pred_tokens or not ref_tokens:
+        return 0.0
+    lcs = _lcs_len(pred_tokens, ref_tokens)
+    prec = lcs / float(len(pred_tokens))
+    rec = lcs / float(len(ref_tokens))
+    if prec + rec == 0:
+        return 0.0
+    return float(2 * prec * rec / (prec + rec))
+
+
+def rouge_l_from_records(records: Iterable[dict[str, Any]]) -> float:
+    """Compute ROUGE-L (F1) from records with predictions and references."""
+    scores: list[float] = []
+    for record in records:
+        if not isinstance(record, dict):
+            continue
+        pred = record.get("prediction")
+        refs = record.get("references")
+        if pred is None:
+            continue
+        if refs is None and "reference" in record:
+            refs = [record.get("reference")]
+        if refs is None:
+            continue
+        ref_list = refs if isinstance(refs, list) else [refs]
+        best = 0.0
+        for ref in ref_list:
+            if ref is None:
+                continue
+            best = max(best, _rouge_l(str(pred), str(ref)))
+        scores.append(best)
+    if not scores:
+        return float("nan")
+    return float(sum(scores) / float(len(scores)))
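
Reviewer note (not part of the diff): assuming the hunk above is the new invarlock/eval/tasks/text_generation.py listed in the file table, a minimal usage sketch of the record shape the two public helpers expect: "prediction" plus either "references" (a list) or "reference" (a single value). Records missing a prediction or any reference are skipped, and an empty input yields NaN.

from invarlock.eval.tasks.text_generation import bleu1_from_records, rouge_l_from_records

records = [
    {"prediction": "the cat sat on the mat", "references": ["the cat sat on the mat"]},
    {"prediction": "a dog ran", "reference": "the dog ran away"},   # single-reference form
    {"prediction": "unpaired output"},                              # skipped: no reference
]

print(bleu1_from_records(records))    # mean best BLEU-1 over the two scored records
print(rouge_l_from_records(records))  # mean best ROUGE-L F1 over the same records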
invarlock/guards/invariants.py CHANGED
@@ -5,6 +5,7 @@ InvarLock Guards - Invariants
 Invariant checking for model edits to ensure structural integrity.
 """
 
+import hashlib
 from typing import Any
 
 import torch
@@ -33,6 +34,7 @@ class InvariantsGuard(Guard):
         self.on_fail = on_fail
         self.prepared = False
         self.baseline_checks: dict[str, Any] = {}
+        self.last_current_checks: dict[str, Any] = {}
         self.profile_checks: tuple[str, ...] = ()
 
     def prepare(
@@ -102,6 +104,10 @@ class InvariantsGuard(Guard):
             "action": outcome.action,
             "violations": outcome.violations,
             "metrics": outcome.metrics,
+            "details": {
+                "baseline_checks": self.baseline_checks,
+                "current_checks": self.last_current_checks,
+            },
         }
 
     def finalize(self, model: Any) -> GuardOutcome:
@@ -125,6 +131,7 @@ class InvariantsGuard(Guard):
 
         # Check current invariants
         current_checks = self._capture_invariants(model, None)
+        self.last_current_checks = current_checks
         violations: list[dict[str, Any]] = []
        tokenizer_mismatches: list[dict[str, Any]] = []
 
@@ -354,14 +361,14 @@ class InvariantsGuard(Guard):
         except Exception:
             pass
 
-        # LLaMA style (model.embed_tokens <-> lm_head)
+        # Decoder embed_tokens style (model.embed_tokens <-> lm_head)
         try:
-            llama_model = getattr(model, "model", None)
-            embed_tokens = getattr(llama_model, "embed_tokens", None)
+            decoder_model = getattr(model, "model", None)
+            embed_tokens = getattr(decoder_model, "embed_tokens", None)
             embed_weight = getattr(embed_tokens, "weight", None)
-            llama_head_weight = getattr(getattr(model, "lm_head", None), "weight", None)
-            if embed_weight is not None and llama_head_weight is not None:
-                weight_tying_flags["llama"] = _is_tied(embed_weight, llama_head_weight)
+            head_weight = getattr(getattr(model, "lm_head", None), "weight", None)
+            if embed_weight is not None and head_weight is not None:
+                weight_tying_flags["embed_tokens"] = _is_tied(embed_weight, head_weight)
         except Exception:
             pass
 
@@ -376,8 +383,10 @@ class InvariantsGuard(Guard):
             structure_items = []
             for name, module in model.named_modules():
                 structure_items.append(f"{name}:{type(module).__name__}")
-            structure_hash = hash(tuple(structure_items))
-            checks["structure_hash"] = structure_hash
+            canonical = "\n".join(sorted(structure_items))
+            checks["structure_hash"] = hashlib.sha256(
+                canonical.encode("utf-8")
+            ).hexdigest()[:16]
         except Exception:
             checks["structure_hash"] = 0
 
@@ -424,7 +433,7 @@ class InvariantsGuard(Guard):
             return "bert" in model_type or has_cls_decoder
 
         if name in {"rope_rotary_embedding", "rotary_embedding"}:
-            # Detect rotary embeddings used by LLaMA-style models
+            # Detect rotary embeddings used by RoPE-style models
            if hasattr(model, "model") and hasattr(model.model, "layers"):
                first_layer = model.model.layers[0] if model.model.layers else None
            else:
@@ -443,7 +452,7 @@ class InvariantsGuard(Guard):
            model_type = getattr(config, "model_type", "") if config else ""
            return any(
                keyword in model_type
-                for keyword in ("gpt", "llama", "mistral", "opt", "phi")
+                for keyword in ("gpt", "mistral", "mixtral", "qwen", "opt", "phi")
            )
 
        return True
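
Reviewer note (not part of the diff): the structure_hash change above swaps Python's per-process, salted hash() for a stable SHA-256 digest over the sorted module structure, so the value survives interpreter restarts. A standalone sketch of the same computation, with hypothetical module names:

import hashlib

structure_items = ["transformer.h.0:GPT2Block", "lm_head:Linear"]  # hypothetical entries
canonical = "\n".join(sorted(structure_items))
print(hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16])  # identical on every run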
invarlock/guards/rmt.py CHANGED
@@ -387,7 +387,7 @@ def _iter_transformer_layers(model: nn.Module):
         except (TypeError, AttributeError):
             pass
     elif hasattr(model, "model") and hasattr(model.model, "layers"):
-        # LLaMA style
+        # RoPE decoder style
         layers = model.model.layers
         if hasattr(layers, "__iter__") and hasattr(layers, "__len__"):
             try:
@@ -746,7 +746,7 @@ def rmt_detect_with_names(
         for idx, layer in enumerate(h_layers):
             layer_modules.append((f"transformer.h.{idx}", layer))
     elif hasattr(model, "model") and hasattr(model.model, "layers"):
-        # LLaMA style
+        # RoPE decoder style
         layers = model.model.layers
         if hasattr(layers, "__iter__"):
             for idx, layer in enumerate(layers):
invarlock/guards/variance.py CHANGED
@@ -121,7 +121,7 @@ def _iter_transformer_layers(model: nn.Module):
         # GPT-2 style
         yield from model.transformer.h
     elif hasattr(model, "model") and hasattr(model.model, "layers"):
-        # LLaMA style
+        # RoPE decoder style
         yield from model.model.layers
     elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
         # BERT style
@@ -214,7 +214,7 @@ def equalise_residual_variance(
             hooks[name] = attn_proj.register_forward_hook(_branch_hook(name))
 
         if hasattr(blk, "mlp"):
-            # Check for c_proj (GPT-2) or down_proj (LLaMA) or fc2 (generic)
+            # Check for c_proj (GPT-2) or down_proj (RoPE decoder) or fc2 (generic)
             mlp_proj = (
                 getattr(blk.mlp, "c_proj", None)
                 or getattr(blk.mlp, "down_proj", None)
invarlock/model_profile.py CHANGED
@@ -106,7 +106,7 @@ def _gpt2_selectors() -> dict[str, list[str]]:
     }
 
 
-def _llama_selectors() -> dict[str, list[str]]:
+def _rope_decoder_selectors() -> dict[str, list[str]]:
     return {
         "attention": [
             "self_attn.q_proj",
@@ -191,11 +191,11 @@ def _make_gpt2_tokenizer(model_id: str):
     return factory
 
 
-def _make_llama_tokenizer(model_id: str):
+def _make_causal_auto_tokenizer(model_id: str):
     def factory() -> tuple[PreTrainedTokenizerBase, str]:
         if AutoTokenizer is None and GPT2Tokenizer is None:
             raise RuntimeError(
-                "LLaMA-style tokenizers require the 'transformers' extra. "
+                "Causal tokenizers require the 'transformers' extra. "
                 "Install it with: pip install 'invarlock[adapters]'."
             )
         # Try offline-first to respect InvarLock network guard; fall back to a
@@ -227,7 +227,7 @@ def _make_llama_tokenizer(model_id: str):
         eos_token = getattr(tokenizer, "eos_token", None)
         if eos_token is not None:
             tokenizer.pad_token = eos_token
-        # Some LLaMA tokenizers default to not adding a BOS token on encode;
+        # Some causal tokenizers default to not adding a BOS token on encode;
         # enable it to guarantee at least one non-pad, non-zero token id.
         if hasattr(tokenizer, "add_bos_token"):
             try:
@@ -289,7 +289,7 @@ def detect_model_profile(model_id: str, adapter: str | None = None) -> ModelProf
     model_lower = (model_id or "").lower()
 
     if any(
-        keyword in adapter_lower for keyword in ("bert", "roberta", "deberta")
+        keyword in adapter_lower for keyword in ("hf_mlm", "bert", "roberta", "deberta")
     ) or any(keyword in model_lower for keyword in ("bert", "roberta", "deberta")):
         return ModelProfile(
             family="bert",
@@ -302,57 +302,78 @@ def detect_model_profile(model_id: str, adapter: str | None = None) -> ModelProf
             cert_lints=(
                 {
                     "type": "equals",
-                    "path": "metrics.loss_type",
-                    "value": "mlm",
-                    "message": "BERT cert must record MLM loss type.",
+                    "path": "primary_metric.kind",
+                    "value": "ppl_mlm",
+                    "message": "BERT cert must use MLM metric.",
                 },
                 {
                     "type": "gte",
-                    "path": "metrics.masked_tokens_total",
+                    "path": "telemetry.masked_tokens_total",
                     "value": "1",
                     "message": "BERT cert must report masked tokens.",
                 },
             ),
         )
 
-    if any(keyword in adapter_lower for keyword in ("llama", "mistral", "qwen")) or any(
-        keyword in model_lower for keyword in ("llama", "mistral", "qwen")
+    if any(keyword in adapter_lower for keyword in ("hf_seq2seq", "t5", "bart")) or any(
+        keyword in model_lower for keyword in ("t5", "bart")
     ):
         return ModelProfile(
-            family="llama",
+            family="seq2seq",
+            default_loss="seq2seq",
+            make_tokenizer=_make_unknown_tokenizer(model_id),
+            default_metric="ppl_seq2seq",
+            default_provider="wikitext2",
+            module_selectors=_unknown_selectors(),
+            invariants=(),
+            cert_lints=(),
+        )
+
+    if any(
+        keyword in adapter_lower for keyword in ("gpt", "neox", "opt", "phi")
+    ) or any(keyword in model_lower for keyword in ("gpt", "neox", "opt", "phi")):
+        return ModelProfile(
+            family="gpt2",
             default_loss="causal",
-            make_tokenizer=_make_llama_tokenizer(model_id),
+            make_tokenizer=_make_gpt2_tokenizer(model_id),
             default_metric="ppl_causal",
             default_provider="wikitext2",
-            module_selectors=_llama_selectors(),
-            invariants=("rope_rotary_embedding",),
+            module_selectors=_gpt2_selectors(),
+            invariants=("causal_masking",),
             cert_lints=(
                 {
                     "type": "equals",
-                    "path": "metrics.loss_type",
-                    "value": "causal",
-                    "message": "LLaMA cert should report causal loss.",
+                    "path": "primary_metric.kind",
+                    "value": "ppl_causal",
+                    "message": "GPT-style cert must use causal ppl metric.",
                 },
             ),
         )
 
     if any(
-        keyword in adapter_lower for keyword in ("gpt", "neox", "opt", "phi")
-    ) or any(keyword in model_lower for keyword in ("gpt", "neox", "opt", "phi")):
+        keyword in adapter_lower for keyword in ("mistral", "mixtral", "qwen", "yi")
+    ) or any(
+        keyword in model_lower for keyword in ("mistral", "mixtral", "qwen", "yi")
+    ):
+        family = "causal"
+        for keyword in ("mixtral", "mistral", "qwen", "yi"):
+            if keyword in adapter_lower or keyword in model_lower:
+                family = keyword
+                break
         return ModelProfile(
-            family="gpt2",
+            family=family,
             default_loss="causal",
-            make_tokenizer=_make_gpt2_tokenizer(model_id),
+            make_tokenizer=_make_causal_auto_tokenizer(model_id),
             default_metric="ppl_causal",
             default_provider="wikitext2",
-            module_selectors=_gpt2_selectors(),
-            invariants=("causal_masking",),
+            module_selectors=_rope_decoder_selectors(),
+            invariants=("rope_rotary_embedding",),
             cert_lints=(
                 {
                     "type": "equals",
-                    "path": "metrics.loss_type",
-                    "value": "causal",
-                    "message": "GPT-style cert should record causal loss.",
+                    "path": "primary_metric.kind",
+                    "value": "ppl_causal",
+                    "message": "Causal cert must use causal ppl metric.",
                 },
             ),
         )
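
Reviewer note (not part of the diff): a rough sketch of how the reordered detection above routes a couple of hypothetical model ids, considering only the branches visible in this hunk (earlier branches of detect_model_profile not shown here could intercept other ids):

from invarlock.model_profile import detect_model_profile

profile = detect_model_profile("mistralai/Mistral-7B-v0.1")
# Matches the new RoPE-decoder branch: family "mistral", default_metric "ppl_causal",
# module_selectors from _rope_decoder_selectors(), invariant "rope_rotary_embedding".

profile = detect_model_profile("t5-small")
# Matches the new seq2seq branch: family "seq2seq", default_metric "ppl_seq2seq".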
invarlock/observability/health.py CHANGED
@@ -374,15 +374,15 @@ class InvarLockHealthChecker(HealthChecker):
         """Check adapter availability."""
         try:
             from invarlock.adapters import (
-                HF_BERT_Adapter,
-                HF_GPT2_Adapter,
-                HF_LLaMA_Adapter,
+                HF_Causal_Adapter,
+                HF_MLM_Adapter,
+                HF_Seq2Seq_Adapter,
             )
 
             adapters = {
-                "hf_gpt2": HF_GPT2_Adapter,
-                "hf_llama": HF_LLaMA_Adapter,
-                "hf_bert": HF_BERT_Adapter,
+                "hf_causal": HF_Causal_Adapter,
+                "hf_mlm": HF_MLM_Adapter,
+                "hf_seq2seq": HF_Seq2Seq_Adapter,
             }
 
             available_adapters = []
invarlock/observability/metrics.py CHANGED
@@ -455,3 +455,111 @@ def create_resource_metrics(registry: MetricsRegistry) -> dict[str, Any]:
         "gpu_memory": registry.register_gauge("invarlock.resource.gpu_memory_percent"),
         "disk_usage": registry.register_gauge("invarlock.resource.disk_percent"),
     }
+
+
+def reset_peak_memory_stats() -> None:
+    """Reset GPU peak memory stats when available."""
+    try:
+        import torch
+
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats()
+        mps = getattr(torch, "mps", None)
+        if mps is not None and hasattr(mps, "reset_peak_memory_stats"):
+            mps.reset_peak_memory_stats()
+    except Exception:
+        pass
+
+
+def capture_memory_snapshot(
+    phase: str, *, timestamp: float | None = None
+) -> dict[str, Any]:
+    """Capture a point-in-time memory snapshot for the current process."""
+    snapshot: dict[str, Any] = {"phase": str(phase)}
+    if timestamp is None:
+        timestamp = time.time()
+    snapshot["ts"] = float(timestamp)
+
+    try:
+        import os
+
+        import psutil
+
+        process = psutil.Process(os.getpid())
+        rss_mb = process.memory_info().rss / 1024 / 1024
+        snapshot["rss_mb"] = float(rss_mb)
+    except Exception:
+        pass
+
+    try:
+        import torch
+
+        if torch.cuda.is_available():
+            device_index = torch.cuda.current_device()
+            snapshot["gpu_device"] = f"cuda:{device_index}"
+            snapshot["gpu_mb"] = float(
+                torch.cuda.memory_allocated(device_index) / 1024 / 1024
+            )
+            snapshot["gpu_reserved_mb"] = float(
+                torch.cuda.memory_reserved(device_index) / 1024 / 1024
+            )
+            snapshot["gpu_peak_mb"] = float(
+                torch.cuda.max_memory_allocated(device_index) / 1024 / 1024
+            )
+            snapshot["gpu_peak_reserved_mb"] = float(
+                torch.cuda.max_memory_reserved(device_index) / 1024 / 1024
+            )
+        else:
+            mps = getattr(torch, "mps", None)
+            if mps is not None and hasattr(torch.backends, "mps"):
+                if torch.backends.mps.is_available():
+                    snapshot["gpu_device"] = "mps"
+                    if hasattr(mps, "current_allocated_memory"):
+                        snapshot["gpu_mb"] = float(
+                            mps.current_allocated_memory() / 1024 / 1024
+                        )
+                    if hasattr(mps, "driver_allocated_memory"):
+                        snapshot["gpu_reserved_mb"] = float(
+                            mps.driver_allocated_memory() / 1024 / 1024
+                        )
+    except Exception:
+        pass
+
+    if len(snapshot) <= 2:
+        return {}
+    return snapshot
+
+
+def summarize_memory_snapshots(
+    snapshots: list[dict[str, Any]],
+) -> dict[str, float]:
+    """Summarize memory snapshots into peak metrics."""
+
+    def _peak(key: str) -> float | None:
+        values: list[float] = []
+        for entry in snapshots:
+            if not isinstance(entry, dict):
+                continue
+            value = entry.get(key)
+            if isinstance(value, int | float):
+                values.append(float(value))
+        return max(values) if values else None
+
+    summary: dict[str, float] = {}
+    rss_peak = _peak("rss_mb")
+    if rss_peak is not None:
+        summary["memory_mb_peak"] = rss_peak
+
+    gpu_peak = _peak("gpu_peak_mb")
+    if gpu_peak is None:
+        gpu_peak = _peak("gpu_mb")
+    if gpu_peak is not None:
+        summary["gpu_memory_mb_peak"] = gpu_peak
+
+    gpu_reserved_peak = _peak("gpu_peak_reserved_mb")
+    if gpu_reserved_peak is None:
+        gpu_reserved_peak = _peak("gpu_reserved_mb")
+    if gpu_reserved_peak is not None:
+        summary["gpu_memory_reserved_mb_peak"] = gpu_reserved_peak
+
+    return summary
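
Reviewer note (not part of the diff): a minimal sketch of how the new memory-telemetry helpers above compose; the phase names are hypothetical, RSS capture needs psutil, and GPU fields only appear when torch reports a CUDA or MPS device:

from invarlock.observability.metrics import (
    capture_memory_snapshot,
    reset_peak_memory_stats,
    summarize_memory_snapshots,
)

reset_peak_memory_stats()                         # clear CUDA/MPS peak counters if present
snapshots = [capture_memory_snapshot("baseline")]
# ... run an edit / evaluation phase ...
snapshots.append(capture_memory_snapshot("after_edit"))

summary = summarize_memory_snapshots(snapshots)
# e.g. {"memory_mb_peak": ..., "gpu_memory_mb_peak": ...}; keys appear only when
# the snapshots captured the corresponding RSS or GPU figures.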