invarlock 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. invarlock/__init__.py +2 -2
  2. invarlock/_data/runtime/tiers.yaml +57 -30
  3. invarlock/adapters/__init__.py +11 -15
  4. invarlock/adapters/auto.py +35 -40
  5. invarlock/adapters/capabilities.py +2 -2
  6. invarlock/adapters/hf_causal.py +418 -0
  7. invarlock/adapters/{hf_onnx.py → hf_causal_onnx.py} +3 -3
  8. invarlock/adapters/hf_mixin.py +25 -4
  9. invarlock/adapters/{hf_bert.py → hf_mlm.py} +4 -11
  10. invarlock/adapters/{hf_t5.py → hf_seq2seq.py} +9 -9
  11. invarlock/calibration/spectral_null.py +15 -10
  12. invarlock/calibration/variance_ve.py +0 -2
  13. invarlock/cli/adapter_auto.py +31 -21
  14. invarlock/cli/app.py +73 -2
  15. invarlock/cli/commands/calibrate.py +6 -2
  16. invarlock/cli/commands/certify.py +651 -91
  17. invarlock/cli/commands/doctor.py +11 -11
  18. invarlock/cli/commands/explain_gates.py +57 -8
  19. invarlock/cli/commands/plugins.py +13 -9
  20. invarlock/cli/commands/report.py +233 -69
  21. invarlock/cli/commands/run.py +1066 -244
  22. invarlock/cli/commands/verify.py +154 -15
  23. invarlock/cli/config.py +22 -6
  24. invarlock/cli/doctor_helpers.py +4 -5
  25. invarlock/cli/output.py +193 -0
  26. invarlock/cli/provenance.py +1 -1
  27. invarlock/core/api.py +45 -5
  28. invarlock/core/auto_tuning.py +65 -20
  29. invarlock/core/bootstrap.py +1 -1
  30. invarlock/core/contracts.py +7 -1
  31. invarlock/core/registry.py +11 -13
  32. invarlock/core/runner.py +425 -75
  33. invarlock/edits/quant_rtn.py +65 -37
  34. invarlock/eval/bench.py +3 -16
  35. invarlock/eval/data.py +82 -51
  36. invarlock/eval/metrics.py +63 -2
  37. invarlock/eval/primary_metric.py +23 -0
  38. invarlock/eval/tail_stats.py +230 -0
  39. invarlock/eval/tasks/__init__.py +12 -0
  40. invarlock/eval/tasks/classification.py +48 -0
  41. invarlock/eval/tasks/qa.py +36 -0
  42. invarlock/eval/tasks/text_generation.py +102 -0
  43. invarlock/guards/_estimators.py +154 -0
  44. invarlock/guards/invariants.py +19 -10
  45. invarlock/guards/policies.py +16 -6
  46. invarlock/guards/rmt.py +627 -546
  47. invarlock/guards/spectral.py +348 -110
  48. invarlock/guards/tier_config.py +32 -30
  49. invarlock/guards/variance.py +7 -31
  50. invarlock/guards_ref/rmt_ref.py +23 -23
  51. invarlock/model_profile.py +90 -42
  52. invarlock/observability/health.py +6 -6
  53. invarlock/observability/metrics.py +108 -0
  54. invarlock/reporting/certificate.py +384 -55
  55. invarlock/reporting/certificate_schema.py +3 -2
  56. invarlock/reporting/dataset_hashing.py +15 -2
  57. invarlock/reporting/guards_analysis.py +350 -277
  58. invarlock/reporting/html.py +55 -5
  59. invarlock/reporting/normalizer.py +13 -0
  60. invarlock/reporting/policy_utils.py +38 -36
  61. invarlock/reporting/primary_metric_utils.py +71 -17
  62. invarlock/reporting/render.py +852 -431
  63. invarlock/reporting/report.py +40 -4
  64. invarlock/reporting/report_types.py +11 -3
  65. invarlock/reporting/telemetry.py +86 -0
  66. invarlock/reporting/validate.py +1 -18
  67. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/METADATA +27 -13
  68. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/RECORD +72 -65
  69. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/WHEEL +1 -1
  70. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/entry_points.txt +5 -3
  71. invarlock/adapters/hf_gpt2.py +0 -404
  72. invarlock/adapters/hf_llama.py +0 -487
  73. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/licenses/LICENSE +0 -0
  74. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@
2
2
  invarlock verify command
3
3
  ====================
4
4
 
5
- Validates generated safety certificates for internal consistency. The command
5
+ Validates generated evaluation certificates for internal consistency. The command
6
6
  ensures schema compliance, checks that the primary metric ratio agrees with the
7
7
  baseline reference, and enforces paired-window guarantees (match=1.0,
8
8
  overlap=0.0).
@@ -10,6 +10,7 @@ overlap=0.0).
10
10
 
11
11
  from __future__ import annotations
12
12
 
13
+ import hashlib
13
14
  import json
14
15
  import math
15
16
  from pathlib import Path
@@ -65,9 +66,25 @@ def _validate_primary_metric(certificate: dict[str, Any]) -> list[str]:
65
66
  errors.append("Certificate missing primary_metric block.")
66
67
  return errors
67
68
 
69
+ def _is_finite_number(value: Any) -> bool:
70
+ return isinstance(value, (int, float)) and math.isfinite(float(value))
71
+
72
+ def _declares_invalid_primary_metric(metric: dict[str, Any]) -> bool:
73
+ if bool(metric.get("invalid")):
74
+ return True
75
+ reason = metric.get("degraded_reason")
76
+ if isinstance(reason, str):
77
+ r = reason.strip().lower()
78
+ return r.startswith("non_finite") or r in {
79
+ "primary_metric_invalid",
80
+ "evaluation_error",
81
+ }
82
+ return False
83
+
68
84
  kind = str(pm.get("kind", "")).lower()
69
85
  ratio_vs_baseline = pm.get("ratio_vs_baseline")
70
86
  final = pm.get("final")
87
+ pm_invalid = _declares_invalid_primary_metric(pm)
71
88
 
72
89
  if kind.startswith("ppl"):
73
90
  baseline_ref = certificate.get("baseline_ref", {}) or {}
@@ -81,16 +98,14 @@ def _validate_primary_metric(certificate: dict[str, Any]) -> list[str]:
81
98
  bv = baseline_pm.get("final")
82
99
  if isinstance(bv, (int | float)):
83
100
  baseline_final = float(bv)
84
- if isinstance(final, int | float) and isinstance(baseline_final, int | float):
85
- if baseline_final <= 0.0:
101
+ if _is_finite_number(final) and _is_finite_number(baseline_final):
102
+ if float(baseline_final) <= 0.0:
86
103
  errors.append(
87
104
  f"Baseline final must be > 0.0 to compute ratio (found {baseline_final})."
88
105
  )
89
106
  else:
90
107
  expected_ratio = float(final) / float(baseline_final)
91
- if not isinstance(ratio_vs_baseline, int | float) or not math.isfinite(
92
- float(ratio_vs_baseline)
93
- ):
108
+ if not _is_finite_number(ratio_vs_baseline):
94
109
  errors.append(
95
110
  "Certificate is missing a finite primary_metric.ratio_vs_baseline value."
96
111
  )
@@ -101,7 +116,18 @@ def _validate_primary_metric(certificate: dict[str, Any]) -> list[str]:
101
116
  "Primary metric ratio mismatch: "
102
117
  f"recorded={float(ratio_vs_baseline):.12f}, expected={expected_ratio:.12f}"
103
118
  )
119
+ else:
120
+ # If the primary metric is non-finite, it must be explicitly marked invalid.
121
+ # This is expected for structural error-injection runs (NaN/Inf weights).
122
+ if (isinstance(final, (int | float)) and not _is_finite_number(final)) and (
123
+ not pm_invalid
124
+ ):
125
+ errors.append(
126
+ "Primary metric final is non-finite but primary_metric.invalid is not set."
127
+ )
104
128
  else:
129
+ if pm_invalid:
130
+ return errors
105
131
  if ratio_vs_baseline is None or not isinstance(ratio_vs_baseline, int | float):
106
132
  errors.append(
107
133
  "Certificate missing primary_metric.ratio_vs_baseline for non-ppl metric."
@@ -193,14 +219,29 @@ def _validate_counts(certificate: dict[str, Any]) -> list[str]:
193
219
 
194
220
 
195
221
  def _validate_drift_band(certificate: dict[str, Any]) -> list[str]:
196
- """Validate preview→final drift stays within the configured band (0.95–1.05)."""
222
+ """Validate preview→final drift stays within the configured band.
223
+
224
+ Defaults to 0.95–1.05 unless the certificate provides `primary_metric.drift_band`.
225
+ """
197
226
  errors: list[str] = []
198
227
  pm = certificate.get("primary_metric", {}) or {}
228
+ if not isinstance(pm, dict) or not pm:
229
+ errors.append("Certificate missing primary_metric block.")
230
+ return errors
231
+ if bool(pm.get("invalid")):
232
+ # Drift is undefined when the primary metric is invalid (e.g., NaN/Inf weights).
233
+ return errors
199
234
  drift_ratio = None
200
235
  try:
201
236
  prev = pm.get("preview")
202
237
  fin = pm.get("final")
203
- if isinstance(prev, int | float) and isinstance(fin, int | float) and prev > 0:
238
+ if (
239
+ isinstance(prev, int | float)
240
+ and isinstance(fin, int | float)
241
+ and math.isfinite(float(prev))
242
+ and math.isfinite(float(fin))
243
+ and prev > 0
244
+ ):
204
245
  drift_ratio = float(fin) / float(prev)
205
246
  except Exception:
206
247
  drift_ratio = None
@@ -209,9 +250,33 @@ def _validate_drift_band(certificate: dict[str, Any]) -> list[str]:
209
250
  errors.append("Certificate missing preview/final to compute drift ratio.")
210
251
  return errors
211
252
 
212
- if not 0.95 <= float(drift_ratio) <= 1.05:
253
+ drift_min = 0.95
254
+ drift_max = 1.05
255
+ band = pm.get("drift_band")
256
+ try:
257
+ if isinstance(band, dict):
258
+ lo = band.get("min")
259
+ hi = band.get("max")
260
+ if isinstance(lo, int | float) and isinstance(hi, int | float):
261
+ lo_f = float(lo)
262
+ hi_f = float(hi)
263
+ if math.isfinite(lo_f) and math.isfinite(hi_f) and 0 < lo_f < hi_f:
264
+ drift_min = lo_f
265
+ drift_max = hi_f
266
+ elif isinstance(band, list | tuple) and len(band) == 2:
267
+ lo_raw, hi_raw = band[0], band[1]
268
+ if isinstance(lo_raw, int | float) and isinstance(hi_raw, int | float):
269
+ lo_f = float(lo_raw)
270
+ hi_f = float(hi_raw)
271
+ if math.isfinite(lo_f) and math.isfinite(hi_f) and 0 < lo_f < hi_f:
272
+ drift_min = lo_f
273
+ drift_max = hi_f
274
+ except Exception:
275
+ pass
276
+
277
+ if not drift_min <= float(drift_ratio) <= drift_max:
213
278
  errors.append(
214
- f"Preview→final drift ratio out of band (0.951.05): observed {drift_ratio:.6f}."
279
+ f"Preview→final drift ratio out of band ({drift_min:.2f}{drift_max:.2f}): observed {drift_ratio:.6f}."
215
280
  )
216
281
 
217
282
  return errors
@@ -220,9 +285,8 @@ def _validate_drift_band(certificate: dict[str, Any]) -> list[str]:
220
285
  def _validate_tokenizer_hash(certificate: dict[str, Any]) -> list[str]:
221
286
  """Validate tokenizer hash consistency between baseline and edited runs.
222
287
 
223
- The check is enforced only when both hashes are present to preserve
224
- compatibility with legacy certificates. When present and different,
225
- the verification fails.
288
+ The check is enforced only when both hashes are present. When present and
289
+ different, the verification fails.
226
290
  """
227
291
  errors: list[str] = []
228
292
  meta = certificate.get("meta", {}) or {}
@@ -244,7 +308,7 @@ def _validate_tokenizer_hash(certificate: dict[str, Any]) -> list[str]:
244
308
  if isinstance(edited_hash, str) and isinstance(baseline_hash, str):
245
309
  if edited_hash and baseline_hash and edited_hash != baseline_hash:
246
310
  errors.append("Tokenizer hash mismatch between baseline and edited runs.")
247
- # If either hash is missing, skip the check for backward compatibility
311
+ # If either hash is missing, skip the check
248
312
  return errors
249
313
 
250
314
 
@@ -259,6 +323,74 @@ def _resolve_path(payload: Any, path: str) -> Any:
259
323
  return current
260
324
 
261
325
 
326
+ def _measurement_contract_digest(contract: Any) -> str | None:
327
+ if not isinstance(contract, dict) or not contract:
328
+ return None
329
+ try:
330
+ canonical = json.dumps(contract, sort_keys=True, default=str)
331
+ except Exception:
332
+ return None
333
+ return hashlib.sha256(canonical.encode()).hexdigest()[:16]
334
+
335
+
336
+ def _validate_measurement_contracts(
337
+ certificate: dict[str, Any], *, profile: str
338
+ ) -> list[str]:
339
+ """Enforce measurement-contract presence and baseline pairing for guards."""
340
+ errors: list[str] = []
341
+ prof = (profile or "").strip().lower()
342
+ resolved_policy = certificate.get("resolved_policy") or {}
343
+
344
+ for guard_key in ("spectral", "rmt"):
345
+ block = certificate.get(guard_key) or {}
346
+ if not isinstance(block, dict):
347
+ continue
348
+ evaluated = bool(block.get("evaluated", True))
349
+ if not evaluated:
350
+ continue
351
+
352
+ mc = block.get("measurement_contract")
353
+ mc_hash = _measurement_contract_digest(mc)
354
+ expected_hash = block.get("measurement_contract_hash")
355
+ if not isinstance(mc, dict) or not mc:
356
+ errors.append(f"Certificate missing {guard_key}.measurement_contract.")
357
+ elif isinstance(expected_hash, str) and expected_hash:
358
+ if mc_hash and mc_hash != expected_hash:
359
+ errors.append(
360
+ f"{guard_key}.measurement_contract_hash mismatch: expected={expected_hash}, computed={mc_hash}."
361
+ )
362
+ else:
363
+ errors.append(f"Certificate missing {guard_key}.measurement_contract_hash.")
364
+
365
+ rp_guard = (
366
+ resolved_policy.get(guard_key)
367
+ if isinstance(resolved_policy, dict)
368
+ else None
369
+ )
370
+ rp_mc = (
371
+ rp_guard.get("measurement_contract") if isinstance(rp_guard, dict) else None
372
+ )
373
+ rp_hash = _measurement_contract_digest(rp_mc)
374
+ if not isinstance(rp_mc, dict) or not rp_mc:
375
+ errors.append(
376
+ f"Certificate missing resolved_policy.{guard_key}.measurement_contract."
377
+ )
378
+ elif mc_hash and rp_hash and mc_hash != rp_hash:
379
+ errors.append(
380
+ f"{guard_key} measurement_contract differs between analysis and resolved_policy "
381
+ f"(analysis={mc_hash}, resolved_policy={rp_hash})."
382
+ )
383
+
384
+ if prof in {"ci", "release"}:
385
+ match = block.get("measurement_contract_match")
386
+ if match is not True:
387
+ errors.append(
388
+ f"{guard_key} measurement contract must match baseline for {prof} profile."
389
+ )
390
+
391
+ return errors
392
+
393
+
262
394
  def _apply_profile_lints(certificate: dict[str, Any]) -> list[str]:
263
395
  """Apply model-profile specific lint rules embedded in the certificate."""
264
396
  errors: list[str] = []
@@ -338,11 +470,18 @@ def _validate_certificate_payload(
338
470
  )
339
471
  except Exception:
340
472
  prof = "dev"
341
- # Enforce drift band only for CI/Release; skip in dev profile
473
+ # Drift band is a CI/Release enforcement check; dev profile should not
474
+ # fail verification due to preview→final drift.
342
475
  if prof in {"ci", "release"}:
343
476
  errors.extend(_validate_drift_band(certificate))
344
477
  errors.extend(_apply_profile_lints(certificate))
345
478
  errors.extend(_validate_tokenizer_hash(certificate))
479
+ if prof in {"ci", "release"}:
480
+ errors.extend(_validate_measurement_contracts(certificate, profile=prof))
481
+
482
+ # strict/fast assurance mode checks were removed; verification gates rely on
483
+ # structural schema + guard metric contracts instead.
484
+
346
485
  # Release-only enforcement: guard overhead must be measured or explicitly skipped.
347
486
  if prof == "release":
348
487
  go = certificate.get("guard_overhead")
invarlock/cli/config.py CHANGED
@@ -131,14 +131,9 @@ class EvalBootstrapConfig:
131
131
  @dataclass
132
132
  class SpectralGuardConfig:
133
133
  sigma_quantile: float | None = None
134
- contraction: float | None = None
135
134
  family_caps: dict[str, Any] = field(default_factory=dict)
136
135
 
137
136
  def __post_init__(self) -> None:
138
- # contraction is an alias for sigma_quantile
139
- if self.contraction is not None and self.sigma_quantile is None:
140
- self.sigma_quantile = float(self.contraction)
141
- self.contraction = None
142
137
  # normalize family_caps: scalar → {"kappa": value}
143
138
  caps = {}
144
139
  for k, v in (self.family_caps or {}).items():
@@ -244,6 +239,27 @@ def load_config(path: str | Path) -> InvarLockConfig:
244
239
  raise ValueError("defaults must be a mapping when present")
245
240
  if isinstance(defaults, dict):
246
241
  raw = _deep_merge(defaults, raw)
242
+
243
+ # "assurance" (strict/fast) was removed in the GPU/MPS-first measurement-contract
244
+ # world. Fail closed so outdated configs are updated explicitly.
245
+ if raw.get("assurance") is not None:
246
+ raise ValueError(
247
+ "assurance.* is deprecated; configure measurement contracts under guards.* "
248
+ "(e.g., guards.spectral.estimator, guards.rmt.activation.sampling)."
249
+ )
250
+
251
+ # Per-guard strict/fast mode overrides were also removed. Fail closed to avoid
252
+ # silently accepting configs that no longer apply.
253
+ guards_block = raw.get("guards")
254
+ if isinstance(guards_block, dict):
255
+ for guard_name in ("spectral", "rmt"):
256
+ node = guards_block.get(guard_name)
257
+ if isinstance(node, dict) and "mode" in node:
258
+ raise ValueError(
259
+ f"guards.{guard_name}.mode is deprecated; remove it and configure "
260
+ "measurement-contract knobs under guard policy fields instead."
261
+ )
262
+
247
263
  # Coerce known guard configs for friendlier attribute access
248
264
  guards = raw.get("guards")
249
265
  if isinstance(guards, dict):
@@ -399,7 +415,7 @@ def _deep_merge_dicts(a: dict, b: dict) -> dict: # pragma: no cover - trivial a
399
415
 
400
416
  def create_example_config() -> InvarLockConfig: # pragma: no cover - test helper
401
417
  return InvarLockConfig(
402
- model={"id": "gpt2", "adapter": "hf_gpt2", "device": "auto"},
418
+ model={"id": "gpt2", "adapter": "hf_causal", "device": "auto"},
403
419
  edit={"name": "quant_rtn", "plan": {}},
404
420
  dataset={"provider": "wikitext2", "seq_len": 512, "stride": 512},
405
421
  output={"dir": "runs"},
@@ -8,7 +8,7 @@ from typing import Any
8
8
  def get_adapter_rows() -> list[dict[str, Any]]:
9
9
  """Build adapter rows similar to doctor output for testing.
10
10
 
11
- Applies optional-extra detection for hf_onnx (optimum/onnxruntime) even if
11
+ Applies optional-extra detection for hf_causal_onnx (optimum/onnxruntime) even if
12
12
  registered as a core adapter, so missing extras are surfaced.
13
13
  """
14
14
  from invarlock.core.registry import get_registry
@@ -29,13 +29,12 @@ def get_adapter_rows() -> list[dict[str, Any]]:
29
29
  module = str(info.get("module") or "")
30
30
  support = (
31
31
  "auto"
32
- if module.startswith("invarlock.adapters")
33
- and name in {"hf_causal_auto", "hf_mlm_auto"}
32
+ if module.startswith("invarlock.adapters") and name in {"hf_auto"}
34
33
  else ("core" if module.startswith("invarlock.adapters") else "optional")
35
34
  )
36
35
  backend, status, enable = None, "ready", ""
37
36
 
38
- if name in {"hf_gpt2", "hf_bert", "hf_llama", "hf_causal_auto", "hf_mlm_auto"}:
37
+ if name in {"hf_causal", "hf_mlm", "hf_seq2seq", "hf_auto"}:
39
38
  backend = "transformers"
40
39
  elif name == "hf_gptq":
41
40
  backend = "auto-gptq"
@@ -49,7 +48,7 @@ def get_adapter_rows() -> list[dict[str, Any]]:
49
48
  backend = "bitsandbytes"
50
49
  if not has_cuda:
51
50
  status, enable = "unsupported", "Requires CUDA"
52
- elif name == "hf_onnx":
51
+ elif name == "hf_causal_onnx":
53
52
  backend = "onnxruntime"
54
53
  present = (
55
54
  importlib.util.find_spec("optimum.onnxruntime") is not None
@@ -0,0 +1,193 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ from collections.abc import Iterator
6
+ from contextlib import contextmanager
7
+ from dataclasses import dataclass
8
+ from typing import TextIO
9
+
10
+ from rich.console import Console
11
+
12
+ _STYLE_AUDIT = "audit"
13
+ _STYLE_FRIENDLY = "friendly"
14
+ _VALID_STYLES = {_STYLE_AUDIT, _STYLE_FRIENDLY}
15
+
16
+
17
+ def _safe_console_print(console: Console, *args: object, **kwargs: object) -> None:
18
+ try:
19
+ console.print(*args, **kwargs)
20
+ except TypeError:
21
+ console.print(*args)
22
+
23
+
24
def env_no_color() -> bool:
    """Return True when the ``NO_COLOR`` environment variable is non-empty.

    Follows the NO_COLOR convention: any non-empty value disables color,
    regardless of its content (a whitespace-only value included). An unset
    or empty variable leaves color enabled.
    """
    # Do not strip: the convention is "present and not an empty string",
    # so "  " must disable color just like "1".
    return bool(os.environ.get("NO_COLOR", ""))
27
+
28
+
29
def perf_counter() -> float:
    """Return the current value of :func:`time.perf_counter`, in seconds.

    Kept as a module-level indirection, presumably so timing can be
    substituted in one place (e.g. by tests). TODO confirm.
    """
    now = time.perf_counter()
    return now
31
+
32
+
33
+ @dataclass(frozen=True, slots=True)
34
+ class OutputStyle:
35
+ name: str
36
+ progress: bool = False
37
+ timing: bool = False
38
+ color: bool = True
39
+
40
+ @property
41
+ def emojis(self) -> bool:
42
+ return self.name != _STYLE_AUDIT
43
+
44
+ @property
45
+ def audit(self) -> bool:
46
+ return self.name == _STYLE_AUDIT
47
+
48
+
49
def normalize_style(style: str | None) -> str | None:
    """Canonicalize a user-supplied style name.

    Returns the trimmed, lower-cased name when it is a known style.
    Otherwise returns ``None``, including for ``None`` or blank input.
    """
    if style is None:
        return None
    candidate = str(style).strip().lower()
    if candidate and candidate in _VALID_STYLES:
        return candidate
    return None
56
+
57
+
58
def resolve_style_name(style: str | None, profile: str | None) -> str:
    """Pick the output style name.

    An explicit, valid ``style`` wins. Otherwise the ``ci``, ``ci_cpu`` and
    ``release`` profiles default to the audit style, and everything else
    defaults to the friendly style.
    """
    explicit = normalize_style(style)
    if explicit is None:
        profile_key = str(profile or "").strip().lower()
        audit_profiles = {"ci", "ci_cpu", "release"}
        return _STYLE_AUDIT if profile_key in audit_profiles else _STYLE_FRIENDLY
    return explicit
66
+
67
+
68
def resolve_output_style(
    *,
    style: str | None,
    profile: str | None,
    progress: bool = False,
    timing: bool = False,
    no_color: bool = False,
) -> OutputStyle:
    """Build an :class:`OutputStyle` from CLI options and the profile.

    Color is disabled when ``no_color`` is truthy or the ``NO_COLOR``
    environment variable is set (see ``env_no_color``).
    """
    style_name = resolve_style_name(style, profile)
    disable_color = bool(no_color) or env_no_color()
    return OutputStyle(
        name=style_name,
        progress=bool(progress),
        timing=bool(timing),
        color=not disable_color,
    )
84
+
85
+
86
def make_console(
    *,
    file: TextIO | None = None,
    force_terminal: bool | None = None,
    no_color: bool | None = None,
) -> Console:
    """Create a rich ``Console`` honoring color preferences.

    When ``no_color`` is None it falls back to the ``NO_COLOR`` environment
    check. With color disabled no color system is configured. Otherwise the
    ``"standard"`` color system is used when ``force_terminal`` is truthy,
    and ``"auto"`` detection is used when it is not.
    """
    if no_color is None:
        no_color = env_no_color()

    if no_color:
        palette = None
    elif force_terminal:
        palette = "standard"
    else:
        palette = "auto"

    return Console(
        file=file,
        force_terminal=force_terminal,
        no_color=bool(no_color),
        color_system=palette,
    )
104
+
105
+
106
def format_event_line(
    tag: str,
    message: str,
    *,
    style: OutputStyle,
    emoji: str | None = None,
) -> str:
    """Render one event line as ``"<prefix> <message>"``.

    The prefix is ``emoji`` when the style allows emojis and one is given.
    Otherwise it is the upper-cased tag in brackets, with ``[INFO]`` used
    for an empty tag. Trailing whitespace is removed from the message and
    from the final line.
    """
    label = str(tag or "").strip().upper() or "INFO"
    prefix = emoji if (style.emojis and emoji) else f"[{label}]"
    body = str(message or "").rstrip()
    return f"{prefix} {body}".rstrip()
120
+
121
+
122
def print_event(
    console: Console,
    tag: str,
    message: str,
    *,
    style: OutputStyle,
    emoji: str | None = None,
    console_style: str | None = None,
) -> None:
    """Format and print a tagged event line with markup disabled.

    When no explicit ``console_style`` is given and color is enabled, the
    color is chosen from the tag: PASS is green, FAIL/ERROR are red,
    WARN/WARNING are yellow, and METRIC is cyan. Other tags are uncolored.
    """
    line = format_event_line(tag, message, style=style, emoji=emoji)
    if console_style is None and style.color:
        tag_colors = {
            "PASS": "green",
            "FAIL": "red",
            "ERROR": "red",
            "WARN": "yellow",
            "WARNING": "yellow",
            "METRIC": "cyan",
        }
        console_style = tag_colors.get(str(tag or "").strip().upper())
    _safe_console_print(console, line, style=console_style, markup=False)
143
+
144
+
145
@contextmanager
def timed_step(
    *,
    console: Console,
    style: OutputStyle,
    timings: dict[str, float] | None,
    key: str,
    tag: str,
    message: str,
    emoji: str | None = None,
) -> Iterator[None]:
    """Context manager that measures how long the enclosed block takes.

    On exit, the elapsed seconds (clamped at 0.0) are stored in
    ``timings[key]`` when a ``timings`` dict is supplied. When
    ``style.progress`` is set, a ``"<message> done (<secs>s)"`` event is also
    printed. Both actions run in a ``finally`` clause, so they also happen
    when the enclosed block raises; the exception still propagates.
    """
    t0 = perf_counter()
    try:
        yield
    finally:
        seconds = max(0.0, float(perf_counter() - t0))
        if timings is not None:
            timings[key] = seconds
        if style.progress:
            done_msg = f"{message} done ({seconds:.2f}s)"
            print_event(console, tag, done_msg, style=style, emoji=emoji)
171
+
172
+
173
def print_timing_summary(
    console: Console,
    timings: dict[str, float],
    *,
    style: OutputStyle,
    order: list[tuple[str, str]],
    extra_lines: list[str] | None = None,
) -> None:
    """Print a "TIMING SUMMARY" section when ``style.timing`` is enabled.

    A blank line and the header are printed first. Then one row is printed
    per ``(label, key)`` in ``order`` whose key exists in ``timings``,
    followed by any ``extra_lines``. Lines are printed one at a time with
    markup disabled.
    """
    if not style.timing:
        return

    def _rows():
        # Generated lazily so printing interleaves with formatting.
        yield ""
        yield "TIMING SUMMARY"
        for label, key in order:
            if key in timings:
                yield f" {label:<11}: {timings[key]:.2f}s"
        yield from extra_lines or ()

    for text in _rows():
        _safe_console_print(console, text, markup=False)
@@ -31,7 +31,7 @@ _FAMILY_MAP: dict[str, tuple[str, str, list[str]]] = {
31
31
  "hf_awq": ("awq", "autoawq", []),
32
32
  "hf_bnb": ("bnb", "bitsandbytes", []),
33
33
  # ONNX stack (requires extras: invarlock[onnx])
34
- "hf_onnx": ("onnx", "onnxruntime", []),
34
+ "hf_causal_onnx": ("onnx", "onnxruntime", []),
35
35
  }
36
36
 
37
37
 
invarlock/core/api.py CHANGED
@@ -17,7 +17,7 @@ from __future__ import annotations
17
17
  from abc import ABC, abstractmethod
18
18
  from dataclasses import dataclass, field
19
19
  from pathlib import Path
20
- from typing import Any
20
+ from typing import Any, Protocol, runtime_checkable
21
21
 
22
22
 
23
23
  class ModelAdapter(ABC):
@@ -88,6 +88,15 @@ class ModelEdit(ABC):
88
88
  pass
89
89
 
90
90
 
91
@runtime_checkable
class EditLike(Protocol):
    """Structural type for edit objects.

    Any object exposing a ``name`` attribute together with ``can_edit`` and
    ``apply`` methods satisfies ``isinstance(obj, EditLike)`` at runtime.
    """

    name: str

    def can_edit(self, model_desc: dict[str, Any]) -> bool:
        """Return whether the edit can be applied to the described model."""

    def apply(self, model: Any, adapter: ModelAdapter, **kwargs: Any) -> dict[str, Any]:
        """Apply the edit to ``model`` through ``adapter``; return a result mapping."""
98
+
99
+
91
100
  class Guard(ABC):
92
101
  """
93
102
  Abstract interface for safety guards.
@@ -116,6 +125,37 @@ class Guard(ABC):
116
125
  pass
117
126
 
118
127
 
128
@runtime_checkable
class GuardWithContext(Protocol):
    """Structural type for guards that accept a run context object."""

    def set_run_context(self, report: Any) -> None:
        """Receive the run context (``report``)."""
131
+
132
+
133
@runtime_checkable
class GuardWithPrepare(Protocol):
    """Structural type for guards implementing the ``prepare`` hook.

    ``GuardChain`` checks ``isinstance(guard, GuardWithPrepare)`` before it
    calls ``prepare``.
    """

    def prepare(
        self,
        model: Any,
        adapter: ModelAdapter,
        calib: Any,
        policy_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Prepare the guard for ``model``; return a result mapping."""
142
+
143
+
144
@runtime_checkable
class GuardWithBeforeEdit(Protocol):
    """Structural type for guards implementing the ``before_edit`` hook.

    ``GuardChain`` checks ``isinstance(guard, GuardWithBeforeEdit)`` before it
    calls ``before_edit``.
    """

    def before_edit(self, model: Any) -> Any:
        """Run the pre-edit hook for ``model``."""
147
+
148
+
149
@runtime_checkable
class GuardWithAfterEdit(Protocol):
    """Structural type for guards implementing the ``after_edit`` hook.

    ``GuardChain`` checks ``isinstance(guard, GuardWithAfterEdit)`` before it
    calls ``after_edit``.
    """

    def after_edit(self, model: Any) -> Any:
        """Run the post-edit hook for ``model``."""
152
+
153
+
154
@runtime_checkable
class GuardWithFinalize(Protocol):
    """Structural type for guards implementing the ``finalize`` hook.

    ``GuardChain`` checks ``isinstance(guard, GuardWithFinalize)`` before it
    calls ``finalize`` and collects the result.
    """

    def finalize(self, model: Any) -> Any:
        """Finalize the guard for ``model`` and return its outcome."""
157
+
158
+
119
159
  class GuardChain:
120
160
  """
121
161
  Manages a chain of guards with policy-based execution.
@@ -145,7 +185,7 @@ class GuardChain:
145
185
  """Prepare all guards."""
146
186
  results = {}
147
187
  for guard in self.guards:
148
- if hasattr(guard, "prepare"):
188
+ if isinstance(guard, GuardWithPrepare):
149
189
  results[guard.name] = guard.prepare(
150
190
  model, adapter, calib, policy_config
151
191
  )
@@ -157,7 +197,7 @@ class GuardChain:
157
197
  """Execute before_edit on all guards."""
158
198
  results = []
159
199
  for guard in self.guards:
160
- if hasattr(guard, "before_edit"):
200
+ if isinstance(guard, GuardWithBeforeEdit):
161
201
  result = guard.before_edit(model)
162
202
  if result is not None:
163
203
  results.append(result)
@@ -167,7 +207,7 @@ class GuardChain:
167
207
  """Execute after_edit on all guards."""
168
208
  results = []
169
209
  for guard in self.guards:
170
- if hasattr(guard, "after_edit"):
210
+ if isinstance(guard, GuardWithAfterEdit):
171
211
  result = guard.after_edit(model)
172
212
  if result is not None:
173
213
  results.append(result)
@@ -177,7 +217,7 @@ class GuardChain:
177
217
  """Finalize all guards and return outcomes."""
178
218
  results = []
179
219
  for guard in self.guards:
180
- if hasattr(guard, "finalize"):
220
+ if isinstance(guard, GuardWithFinalize):
181
221
  result = guard.finalize(model)
182
222
  results.append(result)
183
223
  return results