invarlock 0.3.5__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. invarlock/__init__.py +1 -1
  2. invarlock/_data/runtime/tiers.yaml +57 -30
  3. invarlock/adapters/__init__.py +1 -1
  4. invarlock/calibration/spectral_null.py +15 -10
  5. invarlock/calibration/variance_ve.py +0 -2
  6. invarlock/cli/commands/calibrate.py +6 -2
  7. invarlock/cli/commands/certify.py +58 -39
  8. invarlock/cli/commands/doctor.py +3 -1
  9. invarlock/cli/commands/explain_gates.py +57 -8
  10. invarlock/cli/commands/report.py +1 -1
  11. invarlock/cli/commands/run.py +159 -61
  12. invarlock/cli/commands/verify.py +78 -4
  13. invarlock/cli/config.py +21 -5
  14. invarlock/core/api.py +45 -5
  15. invarlock/core/auto_tuning.py +65 -20
  16. invarlock/core/contracts.py +7 -1
  17. invarlock/core/registry.py +2 -2
  18. invarlock/core/runner.py +314 -50
  19. invarlock/eval/bench.py +0 -13
  20. invarlock/eval/data.py +14 -28
  21. invarlock/eval/metrics.py +4 -1
  22. invarlock/eval/primary_metric.py +23 -0
  23. invarlock/eval/tail_stats.py +230 -0
  24. invarlock/guards/_estimators.py +154 -0
  25. invarlock/guards/policies.py +16 -6
  26. invarlock/guards/rmt.py +625 -544
  27. invarlock/guards/spectral.py +348 -110
  28. invarlock/guards/tier_config.py +32 -30
  29. invarlock/guards/variance.py +5 -29
  30. invarlock/guards_ref/rmt_ref.py +23 -23
  31. invarlock/model_profile.py +42 -15
  32. invarlock/reporting/certificate.py +225 -46
  33. invarlock/reporting/certificate_schema.py +2 -1
  34. invarlock/reporting/dataset_hashing.py +15 -2
  35. invarlock/reporting/guards_analysis.py +197 -274
  36. invarlock/reporting/normalizer.py +6 -0
  37. invarlock/reporting/policy_utils.py +38 -36
  38. invarlock/reporting/primary_metric_utils.py +71 -17
  39. invarlock/reporting/render.py +61 -0
  40. invarlock/reporting/report.py +1 -1
  41. invarlock/reporting/report_types.py +5 -2
  42. invarlock/reporting/validate.py +1 -18
  43. {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
  44. {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
  45. {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
  46. {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
  47. {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
  48. {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
@@ -108,6 +108,64 @@ def _coerce_mapping(obj: object) -> dict[str, Any]:
108
108
  return {}
109
109
 
110
110
 
111
+ def _prune_none_values(value: Any) -> Any:
112
+ """Recursively drop keys/items whose value is None.
113
+
114
+ Used when serializing dataclass-style config sections that define many optional
115
+ fields defaulting to None; those should behave as "unset" rather than explicit
116
+ policy overrides.
117
+ """
118
+
119
+ if isinstance(value, dict):
120
+ return {
121
+ key: _prune_none_values(val)
122
+ for key, val in value.items()
123
+ if val is not None
124
+ }
125
+ if isinstance(value, list):
126
+ return [_prune_none_values(item) for item in value if item is not None]
127
+ if isinstance(value, tuple):
128
+ return tuple(_prune_none_values(item) for item in value if item is not None)
129
+ return value
130
+
131
+
132
+ def _to_serialisable_dict(section: object) -> dict[str, Any]:
133
+ """Coerce config fragments to plain dicts.
134
+
135
+ Handles InvarLockConfig sections (which wrap dicts in a private `_Obj` with
136
+ `_data`) so downstream components (core.runner) see canonical mappings,
137
+ e.g. `eval.bootstrap.replicates`.
138
+ """
139
+
140
+ # Prefer native dump methods
141
+ if hasattr(section, "model_dump"):
142
+ return section.model_dump() # type: ignore[return-value]
143
+ if hasattr(section, "dict"):
144
+ try:
145
+ return section.dict() # type: ignore[return-value]
146
+ except Exception:
147
+ pass
148
+ # Unwrap CLI _Obj wrapper used by InvarLockConfig for attribute access
149
+ try:
150
+ raw = getattr(section, "_data", None)
151
+ if isinstance(raw, dict):
152
+ return raw
153
+ except Exception:
154
+ pass
155
+ # Already a mapping
156
+ if isinstance(section, dict):
157
+ return section
158
+ # Best-effort attribute dump (prune None so "unset" does not override tier defaults)
159
+ try:
160
+ data = vars(section)
161
+ # Common case: {'_data': {...}}
162
+ if isinstance(data, dict) and isinstance(data.get("_data"), dict):
163
+ return data["_data"]
164
+ return _prune_none_values(data) # type: ignore[return-value]
165
+ except TypeError:
166
+ return {}
167
+
168
+
111
169
  def _resolve_pm_acceptance_range(
112
170
  cfg: InvarLockConfig | dict[str, Any] | None,
113
171
  ) -> dict[str, float]:
@@ -296,7 +354,7 @@ def _resolve_exit_code(exc: Exception, *, profile: str | None) -> int:
296
354
  return 1
297
355
 
298
356
 
299
- ## NOTE: Deprecated legacy helper `_check_pairability_or_abort` was removed.
357
+ ## NOTE: Deprecated helper `_check_pairability_or_abort` was removed.
300
358
  ## Provider parity and pairing guarantees are enforced via guard digests and
301
359
  ## invariant checks during run execution.
302
360
 
@@ -700,7 +758,7 @@ def _prepare_config_for_run(
700
758
  cfg = _load_config(config_path)
701
759
 
702
760
  # Apply profile if specified (dev is a no-op)
703
- if profile and str(profile).lower() in {"ci", "release"}:
761
+ if profile and str(profile).lower() not in {"dev"}:
704
762
  console.print(f"🎯 Applying profile: {profile}")
705
763
  try:
706
764
  cfg = _apply_profile(cfg, profile)
@@ -814,17 +872,14 @@ def _resolve_device_and_output(
814
872
  console.print(f"[red]❌ Device validation failed: {error_msg}[/red]")
815
873
  raise typer.Exit(1)
816
874
 
817
- # Determine output directory (support both 'output.dir' and legacy 'out.dir')
875
+ # Determine output directory
818
876
  if out:
819
877
  output_dir = Path(out)
820
878
  else:
821
879
  try:
822
880
  output_dir = Path(cfg.output.dir)
823
881
  except Exception:
824
- try:
825
- output_dir = Path(cfg.out.dir) # type: ignore[attr-defined]
826
- except Exception:
827
- output_dir = Path("runs")
882
+ output_dir = Path("runs")
828
883
  output_dir.mkdir(parents=True, exist_ok=True)
829
884
  return str(resolved_device), output_dir
830
885
 
@@ -1297,7 +1352,7 @@ def _validate_and_harvest_baseline_schedule(
1297
1352
  _fail_schedule(f"{label} input_ids empty at index {idx}")
1298
1353
  seqs.append(seq_ints)
1299
1354
 
1300
- # attention_masks are required for pairing, but legacy baselines may omit them.
1355
+ # attention_masks are required for pairing, but some baselines may omit them.
1301
1356
  # When absent, default to all-ones masks (cannot infer padding reliably).
1302
1357
  masks_rows: list[list[int]] = []
1303
1358
  masks_missing = masks is None or masks == []
@@ -1610,7 +1665,7 @@ def _resolve_metric_and_provider(
1610
1665
  ) -> tuple[str, str, dict[str, float]]:
1611
1666
  """Resolve metric kind, provider kind, and metric options from config with precedence.
1612
1667
 
1613
- Precedence: CLI args (not handled here) → config → ModelProfile defaults → legacy fallback.
1668
+ Precedence: CLI args (not handled here) → config → ModelProfile defaults → fallback.
1614
1669
  Primary metric (metric‑v1) is canonical in dev‑phase; no env flag toggles.
1615
1670
  """
1616
1671
  # Provider kind
@@ -1684,11 +1739,11 @@ def _resolve_metric_and_provider(
1684
1739
  else:
1685
1740
  metric_kind = None
1686
1741
 
1687
- # Fallback to model profile default or legacy resolution by loss type
1742
+ # Fallback to model profile default or loss-type mapping
1688
1743
  if not metric_kind and hasattr(model_profile, "default_metric"):
1689
1744
  metric_kind = model_profile.default_metric
1690
1745
  if not metric_kind:
1691
- # Legacy: map from loss kind
1746
+ # Map from loss kind
1692
1747
  lk = (resolved_loss_type or "causal").lower()
1693
1748
  if lk == "mlm":
1694
1749
  metric_kind = "ppl_mlm"
@@ -1832,7 +1887,9 @@ def run_command(
1832
1887
  None, "--device", help="Device override (auto|cuda|mps|cpu)"
1833
1888
  ),
1834
1889
  profile: str | None = typer.Option(
1835
- None, "--profile", help="Profile to apply (ci|release)"
1890
+ None,
1891
+ "--profile",
1892
+ help="Profile to apply (e.g. ci, release, ci_cpu; dev is a no-op)",
1836
1893
  ),
1837
1894
  out: str | None = typer.Option(None, "--out", help="Output directory override"),
1838
1895
  edit: str | None = typer.Option(None, "--edit", help="Edit kind (quant|mixed)"),
@@ -2099,11 +2156,28 @@ def run_command(
2099
2156
  if pairing_schedule:
2100
2157
  # Normalize baseline report in-memory so downstream digest/parity
2101
2158
  # computations see a consistent window_id + mask shape even for
2102
- # legacy baselines missing some fields.
2159
+ # baselines missing some fields.
2103
2160
  try:
2104
- baseline_report_data["evaluation_windows"] = (
2105
- pairing_schedule
2106
- )
2161
+ ew = baseline_report_data.get("evaluation_windows")
2162
+ if not isinstance(ew, dict):
2163
+ ew = {}
2164
+ baseline_report_data["evaluation_windows"] = ew
2165
+ # Merge the sanitized pairing schedule into existing
2166
+ # evaluation_windows without discarding logloss/token_counts.
2167
+ for arm in ("preview", "final"):
2168
+ src = (
2169
+ pairing_schedule.get(arm)
2170
+ if isinstance(pairing_schedule, dict)
2171
+ else None
2172
+ )
2173
+ if not isinstance(src, dict):
2174
+ continue
2175
+ dst = ew.get(arm)
2176
+ if not isinstance(dst, dict):
2177
+ ew[arm] = dict(src)
2178
+ continue
2179
+ for key, value in src.items():
2180
+ dst[key] = value
2107
2181
  except Exception:
2108
2182
  pass
2109
2183
  # Harvest tokenizer hash provenance from baseline when present.
@@ -2226,50 +2300,11 @@ def run_command(
2226
2300
  console.print(f"🔌 Adapter: {adapter.name}")
2227
2301
 
2228
2302
  # Create run configuration
2229
- def _to_serialisable_dict(section: object) -> dict[str, Any]:
2230
- """Coerce config fragments to plain dicts.
2231
-
2232
- Handles InvarLockConfig sections (which wrap dicts in a private `_Obj` with
2233
- `_data`) so downstream components (core.runner) see canonical mappings,
2234
- e.g. `eval.bootstrap.replicates`.
2235
- """
2236
- # Prefer native dump methods
2237
- if hasattr(section, "model_dump"):
2238
- return section.model_dump() # type: ignore[return-value]
2239
- if hasattr(section, "dict"):
2240
- try:
2241
- return section.dict() # type: ignore[return-value]
2242
- except Exception:
2243
- pass
2244
- # Unwrap CLI _Obj wrapper used by InvarLockConfig for attribute access
2245
- try:
2246
- raw = getattr(section, "_data", None)
2247
- if isinstance(raw, dict):
2248
- return raw
2249
- except Exception:
2250
- pass
2251
- # Already a mapping
2252
- if isinstance(section, dict):
2253
- return section
2254
- # Best-effort attribute dump
2255
- try:
2256
- data = vars(section)
2257
- # Common case: {'_data': {...}}
2258
- if isinstance(data, dict) and isinstance(data.get("_data"), dict):
2259
- return data["_data"]
2260
- return data # type: ignore[return-value]
2261
- except TypeError:
2262
- return {}
2263
-
2264
- def _dump_guard(section: object) -> dict[str, Any]:
2265
- data = _to_serialisable_dict(section)
2266
- return data if isinstance(data, dict) else {}
2267
-
2268
2303
  guard_overrides = {
2269
- "spectral": _dump_guard(getattr(cfg.guards, "spectral", {})),
2270
- "rmt": _dump_guard(getattr(cfg.guards, "rmt", {})),
2271
- "variance": _dump_guard(getattr(cfg.guards, "variance", {})),
2272
- "invariants": _dump_guard(getattr(cfg.guards, "invariants", {})),
2304
+ "spectral": _to_serialisable_dict(getattr(cfg.guards, "spectral", {})),
2305
+ "rmt": _to_serialisable_dict(getattr(cfg.guards, "rmt", {})),
2306
+ "variance": _to_serialisable_dict(getattr(cfg.guards, "variance", {})),
2307
+ "invariants": _to_serialisable_dict(getattr(cfg.guards, "invariants", {})),
2273
2308
  }
2274
2309
 
2275
2310
  if model_profile.invariants:
@@ -2297,6 +2332,31 @@ def run_command(
2297
2332
  "plugins": plugin_provenance,
2298
2333
  "run_id": run_id,
2299
2334
  }
2335
+ # Provide baseline per-window logloss to the CoreRunner for paired tail
2336
+ # evidence and (optionally) fail/rollback enforcement.
2337
+ try:
2338
+ if isinstance(baseline_report_data, dict):
2339
+ ew = baseline_report_data.get("evaluation_windows")
2340
+ if isinstance(ew, dict):
2341
+ final = ew.get("final")
2342
+ if (
2343
+ isinstance(final, dict)
2344
+ and isinstance(final.get("window_ids"), list)
2345
+ and isinstance(final.get("logloss"), list)
2346
+ ):
2347
+ base_eval: dict[str, Any] = {
2348
+ "final": {
2349
+ "window_ids": list(final.get("window_ids") or []),
2350
+ "logloss": list(final.get("logloss") or []),
2351
+ }
2352
+ }
2353
+ if isinstance(final.get("token_counts"), list):
2354
+ base_eval["final"]["token_counts"] = list(
2355
+ final.get("token_counts") or []
2356
+ )
2357
+ run_context["baseline_eval_windows"] = base_eval
2358
+ except Exception:
2359
+ pass
2300
2360
  run_context.setdefault("primary_metric", {})["acceptance_range"] = (
2301
2361
  pm_acceptance_range
2302
2362
  )
@@ -3461,6 +3521,16 @@ def run_command(
3461
3521
  # Convert CoreRunner report to evaluation report
3462
3522
  report = create_empty_report()
3463
3523
 
3524
+ # Persist minimal run context for certificate/report provenance.
3525
+ try:
3526
+ report["context"] = {
3527
+ "profile": profile_normalized,
3528
+ "auto": dict(auto_config),
3529
+ "assurance": dict(run_context.get("assurance") or {}),
3530
+ }
3531
+ except Exception:
3532
+ pass
3533
+
3464
3534
  # Code provenance: commit hash and InvarLock version
3465
3535
  commit_value = (
3466
3536
  getattr(cfg.meta, "commit", "") if hasattr(cfg, "meta") else ""
@@ -3696,6 +3766,7 @@ def run_command(
3696
3766
  "window_pairing_final",
3697
3767
  "paired_windows",
3698
3768
  "paired_delta_summary",
3769
+ "primary_metric_tail",
3699
3770
  "preview_total_tokens",
3700
3771
  "final_total_tokens",
3701
3772
  "masked_tokens_total",
@@ -4313,6 +4384,12 @@ def run_command(
4313
4384
  pm = compute_primary_metric_from_report(
4314
4385
  report, kind=metric_kind_resolved, baseline=baseline_report_data
4315
4386
  )
4387
+ core_primary_metric = None
4388
+ if hasattr(core_report, "metrics") and isinstance(
4389
+ core_report.metrics, dict
4390
+ ):
4391
+ core_primary_metric = core_report.metrics.get("primary_metric")
4392
+ pm = _merge_primary_metric_health(pm, core_primary_metric)
4316
4393
  report.setdefault("metrics", {})["primary_metric"] = pm
4317
4394
  # Attach configured reps/ci_level when provided
4318
4395
  if metric_opts:
@@ -4327,7 +4404,7 @@ def run_command(
4327
4404
  ) # type: ignore[index]
4328
4405
  except Exception:
4329
4406
  pass
4330
- # Shadow parity check against legacy ppl fields (best-effort)
4407
+ # Shadow parity check against ppl_* fields (best-effort)
4331
4408
  try:
4332
4409
  pm_blk = report.get("metrics", {}).get("primary_metric", {})
4333
4410
  ppl_final_v1 = float(pm_blk.get("final"))
@@ -4626,12 +4703,33 @@ def run_command(
4626
4703
  pass
4627
4704
 
4628
4705
 
4706
+ def _merge_primary_metric_health(
4707
+ primary_metric: dict[str, Any] | None,
4708
+ core_primary_metric: dict[str, Any] | None,
4709
+ ) -> dict[str, Any]:
4710
+ if not isinstance(primary_metric, dict):
4711
+ return {}
4712
+ merged = dict(primary_metric)
4713
+ if not isinstance(core_primary_metric, dict):
4714
+ return merged
4715
+ if core_primary_metric.get("invalid") is True:
4716
+ merged["invalid"] = True
4717
+ merged["degraded"] = True
4718
+ if core_primary_metric.get("degraded") is True:
4719
+ merged["degraded"] = True
4720
+ core_reason = core_primary_metric.get("degraded_reason")
4721
+ if isinstance(core_reason, str) and core_reason:
4722
+ merged["degraded_reason"] = core_reason
4723
+ merged["degraded"] = True
4724
+ return merged
4725
+
4726
+
4629
4727
  def _format_debug_metric_diffs(
4630
4728
  pm: dict[str, float] | None,
4631
4729
  metrics: dict[str, float] | None,
4632
4730
  baseline_report_data: dict | None,
4633
4731
  ) -> str:
4634
- """Build a compact DEBUG_METRIC_DIFFS line comparing current snapshot vs legacy ppl_*.
4732
+ """Build a compact DEBUG_METRIC_DIFFS line comparing current snapshot vs ppl_*.
4635
4733
 
4636
4734
  Returns a semicolon-separated string of deltas like
4637
4735
  "final: v1-v1 = +0.000000000; Δlog(final): +0.000000000; ...". Safe to call with
@@ -10,6 +10,7 @@ overlap=0.0).
10
10
 
11
11
  from __future__ import annotations
12
12
 
13
+ import hashlib
13
14
  import json
14
15
  import math
15
16
  from pathlib import Path
@@ -220,9 +221,8 @@ def _validate_drift_band(certificate: dict[str, Any]) -> list[str]:
220
221
  def _validate_tokenizer_hash(certificate: dict[str, Any]) -> list[str]:
221
222
  """Validate tokenizer hash consistency between baseline and edited runs.
222
223
 
223
- The check is enforced only when both hashes are present to preserve
224
- compatibility with legacy certificates. When present and different,
225
- the verification fails.
224
+ The check is enforced only when both hashes are present. When present and
225
+ different, the verification fails.
226
226
  """
227
227
  errors: list[str] = []
228
228
  meta = certificate.get("meta", {}) or {}
@@ -244,7 +244,7 @@ def _validate_tokenizer_hash(certificate: dict[str, Any]) -> list[str]:
244
244
  if isinstance(edited_hash, str) and isinstance(baseline_hash, str):
245
245
  if edited_hash and baseline_hash and edited_hash != baseline_hash:
246
246
  errors.append("Tokenizer hash mismatch between baseline and edited runs.")
247
- # If either hash is missing, skip the check for backward compatibility
247
+ # If either hash is missing, skip the check
248
248
  return errors
249
249
 
250
250
 
@@ -259,6 +259,74 @@ def _resolve_path(payload: Any, path: str) -> Any:
259
259
  return current
260
260
 
261
261
 
262
+ def _measurement_contract_digest(contract: Any) -> str | None:
263
+ if not isinstance(contract, dict) or not contract:
264
+ return None
265
+ try:
266
+ canonical = json.dumps(contract, sort_keys=True, default=str)
267
+ except Exception:
268
+ return None
269
+ return hashlib.sha256(canonical.encode()).hexdigest()[:16]
270
+
271
+
272
def _validate_measurement_contracts(
    certificate: dict[str, Any], *, profile: str
) -> list[str]:
    """Enforce measurement-contract presence and baseline pairing for guards.

    For each of the ``spectral`` and ``rmt`` certificate blocks that is a dict
    and not marked ``evaluated: false``, the following are checked:

    * ``measurement_contract`` is a non-empty dict;
    * ``measurement_contract_hash`` is present and, when a digest of the
      contract can be computed, equals it;
    * ``resolved_policy.<guard>.measurement_contract`` is a non-empty dict
      whose digest matches the analysis contract's digest (compared only when
      both digests are available);
    * for the ``ci`` and ``release`` profiles, ``measurement_contract_match``
      is exactly ``True``.

    Args:
        certificate: Certificate payload to inspect.
        profile: Profile name; compared case-insensitively after stripping.

    Returns:
        Human-readable error strings, empty when every check passes.
    """
    problems: list[str] = []
    prof = (profile or "").strip().lower()
    needs_baseline_match = prof in {"ci", "release"}

    policy_root = certificate.get("resolved_policy") or {}
    if not isinstance(policy_root, dict):
        # A malformed resolved_policy behaves like a missing one.
        policy_root = {}

    for guard_key in ("spectral", "rmt"):
        analysis = certificate.get(guard_key) or {}
        if not isinstance(analysis, dict):
            continue
        # A guard counts as evaluated unless the block explicitly says otherwise.
        if not analysis.get("evaluated", True):
            continue

        # --- contract recorded by the guard analysis -----------------------
        contract = analysis.get("measurement_contract")
        mc_hash = _measurement_contract_digest(contract)
        expected_hash = analysis.get("measurement_contract_hash")
        has_contract = isinstance(contract, dict) and bool(contract)
        has_expected = isinstance(expected_hash, str) and bool(expected_hash)

        if not has_contract:
            problems.append(f"Certificate missing {guard_key}.measurement_contract.")
        elif not has_expected:
            problems.append(
                f"Certificate missing {guard_key}.measurement_contract_hash."
            )
        elif mc_hash and mc_hash != expected_hash:
            problems.append(
                f"{guard_key}.measurement_contract_hash mismatch: expected={expected_hash}, computed={mc_hash}."
            )

        # --- contract recorded in the resolved policy ----------------------
        policy_guard = policy_root.get(guard_key)
        policy_contract = (
            policy_guard.get("measurement_contract")
            if isinstance(policy_guard, dict)
            else None
        )
        rp_hash = _measurement_contract_digest(policy_contract)
        if not (isinstance(policy_contract, dict) and policy_contract):
            problems.append(
                f"Certificate missing resolved_policy.{guard_key}.measurement_contract."
            )
        elif mc_hash and rp_hash and mc_hash != rp_hash:
            problems.append(
                f"{guard_key} measurement_contract differs between analysis and resolved_policy "
                f"(analysis={mc_hash}, resolved_policy={rp_hash})."
            )

        # --- baseline pairing (ci / release only) --------------------------
        if needs_baseline_match:
            if analysis.get("measurement_contract_match") is not True:
                problems.append(
                    f"{guard_key} measurement contract must match baseline for {prof} profile."
                )

    return problems
328
+
329
+
262
330
  def _apply_profile_lints(certificate: dict[str, Any]) -> list[str]:
263
331
  """Apply model-profile specific lint rules embedded in the certificate."""
264
332
  errors: list[str] = []
@@ -343,6 +411,12 @@ def _validate_certificate_payload(
343
411
  errors.extend(_validate_drift_band(certificate))
344
412
  errors.extend(_apply_profile_lints(certificate))
345
413
  errors.extend(_validate_tokenizer_hash(certificate))
414
+ if prof in {"ci", "release"}:
415
+ errors.extend(_validate_measurement_contracts(certificate, profile=prof))
416
+
417
+ # strict/fast assurance mode checks were removed; verification gates rely on
418
+ # structural schema + guard metric contracts instead.
419
+
346
420
  # Release-only enforcement: guard overhead must be measured or explicitly skipped.
347
421
  if prof == "release":
348
422
  go = certificate.get("guard_overhead")
invarlock/cli/config.py CHANGED
@@ -131,14 +131,9 @@ class EvalBootstrapConfig:
131
131
  @dataclass
132
132
  class SpectralGuardConfig:
133
133
  sigma_quantile: float | None = None
134
- contraction: float | None = None
135
134
  family_caps: dict[str, Any] = field(default_factory=dict)
136
135
 
137
136
  def __post_init__(self) -> None:
138
- # contraction is an alias for sigma_quantile
139
- if self.contraction is not None and self.sigma_quantile is None:
140
- self.sigma_quantile = float(self.contraction)
141
- self.contraction = None
142
137
  # normalize family_caps: scalar → {"kappa": value}
143
138
  caps = {}
144
139
  for k, v in (self.family_caps or {}).items():
@@ -244,6 +239,27 @@ def load_config(path: str | Path) -> InvarLockConfig:
244
239
  raise ValueError("defaults must be a mapping when present")
245
240
  if isinstance(defaults, dict):
246
241
  raw = _deep_merge(defaults, raw)
242
+
243
+ # "assurance" (strict/fast) was removed in the GPU/MPS-first measurement-contract
244
+ # world. Fail closed so outdated configs are updated explicitly.
245
+ if raw.get("assurance") is not None:
246
+ raise ValueError(
247
+ "assurance.* is deprecated; configure measurement contracts under guards.* "
248
+ "(e.g., guards.spectral.estimator, guards.rmt.activation.sampling)."
249
+ )
250
+
251
+ # Per-guard strict/fast mode overrides were also removed. Fail closed to avoid
252
+ # silently accepting configs that no longer apply.
253
+ guards_block = raw.get("guards")
254
+ if isinstance(guards_block, dict):
255
+ for guard_name in ("spectral", "rmt"):
256
+ node = guards_block.get(guard_name)
257
+ if isinstance(node, dict) and "mode" in node:
258
+ raise ValueError(
259
+ f"guards.{guard_name}.mode is deprecated; remove it and configure "
260
+ "measurement-contract knobs under guard policy fields instead."
261
+ )
262
+
247
263
  # Coerce known guard configs for friendlier attribute access
248
264
  guards = raw.get("guards")
249
265
  if isinstance(guards, dict):
invarlock/core/api.py CHANGED
@@ -17,7 +17,7 @@ from __future__ import annotations
17
17
  from abc import ABC, abstractmethod
18
18
  from dataclasses import dataclass, field
19
19
  from pathlib import Path
20
- from typing import Any
20
+ from typing import Any, Protocol, runtime_checkable
21
21
 
22
22
 
23
23
  class ModelAdapter(ABC):
@@ -88,6 +88,15 @@ class ModelEdit(ABC):
88
88
  pass
89
89
 
90
90
 
91
@runtime_checkable
class EditLike(Protocol):
    """Structural (duck-typed) description of an edit object.

    An object satisfies this protocol when it exposes a ``name`` attribute and
    the ``can_edit`` / ``apply`` methods. Because the protocol is
    ``runtime_checkable``, ``isinstance(obj, EditLike)`` works, but it only
    verifies that these members exist, not their signatures or types.
    """

    # Identifier of the edit (presumably its registry/display name; confirm
    # against implementers).
    name: str

    def can_edit(self, model_desc: dict[str, Any]) -> bool:
        """Report whether the edit applies to the model described by ``model_desc``."""

    def apply(
        self, model: Any, adapter: ModelAdapter, **kwargs
    ) -> dict[str, Any]:
        """Apply the edit to ``model`` through ``adapter`` and return a result mapping."""
98
+
99
+
91
100
  class Guard(ABC):
92
101
  """
93
102
  Abstract interface for safety guards.
@@ -116,6 +125,37 @@ class Guard(ABC):
116
125
  pass
117
126
 
118
127
 
128
@runtime_checkable
class GuardWithContext(Protocol):
    """Structural type for guards exposing a ``set_run_context`` hook.

    ``isinstance(obj, GuardWithContext)`` is true for any object that has a
    ``set_run_context`` member; the signature itself is not verified at
    runtime.
    """

    def set_run_context(self, report: Any) -> None:
        """Receive the run ``report`` as context; nothing is returned."""
131
+
132
+
133
@runtime_checkable
class GuardWithPrepare(Protocol):
    """Structural type for guards exposing a ``prepare`` hook.

    ``isinstance(obj, GuardWithPrepare)`` is true for any object that has a
    ``prepare`` member; the signature itself is not verified at runtime.
    """

    def prepare(
        self,
        model: Any,
        adapter: ModelAdapter,
        calib: Any,
        policy_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Prepare the guard for ``model`` and return a result mapping.

        ``calib`` is presumably calibration data and ``policy_config`` the
        guard policy settings; confirm against implementers.
        """
142
+
143
+
144
@runtime_checkable
class GuardWithBeforeEdit(Protocol):
    """Structural type for guards exposing a ``before_edit`` hook.

    ``isinstance(obj, GuardWithBeforeEdit)`` is true for any object that has
    a ``before_edit`` member; the signature itself is not verified at runtime.
    """

    def before_edit(self, model: Any) -> Any:
        """Run the guard's pre-edit step on ``model`` and return its result, if any."""
147
+
148
+
149
@runtime_checkable
class GuardWithAfterEdit(Protocol):
    """Structural type for guards exposing an ``after_edit`` hook.

    ``isinstance(obj, GuardWithAfterEdit)`` is true for any object that has
    an ``after_edit`` member; the signature itself is not verified at runtime.
    """

    def after_edit(self, model: Any) -> Any:
        """Run the guard's post-edit step on ``model`` and return its result, if any."""
152
+
153
+
154
@runtime_checkable
class GuardWithFinalize(Protocol):
    """Structural type for guards exposing a ``finalize`` hook.

    ``isinstance(obj, GuardWithFinalize)`` is true for any object that has a
    ``finalize`` member; the signature itself is not verified at runtime.
    """

    def finalize(self, model: Any) -> Any:
        """Finalize the guard for ``model`` and return its outcome."""
157
+
158
+
119
159
  class GuardChain:
120
160
  """
121
161
  Manages a chain of guards with policy-based execution.
@@ -145,7 +185,7 @@ class GuardChain:
145
185
  """Prepare all guards."""
146
186
  results = {}
147
187
  for guard in self.guards:
148
- if hasattr(guard, "prepare"):
188
+ if isinstance(guard, GuardWithPrepare):
149
189
  results[guard.name] = guard.prepare(
150
190
  model, adapter, calib, policy_config
151
191
  )
@@ -157,7 +197,7 @@ class GuardChain:
157
197
  """Execute before_edit on all guards."""
158
198
  results = []
159
199
  for guard in self.guards:
160
- if hasattr(guard, "before_edit"):
200
+ if isinstance(guard, GuardWithBeforeEdit):
161
201
  result = guard.before_edit(model)
162
202
  if result is not None:
163
203
  results.append(result)
@@ -167,7 +207,7 @@ class GuardChain:
167
207
  """Execute after_edit on all guards."""
168
208
  results = []
169
209
  for guard in self.guards:
170
- if hasattr(guard, "after_edit"):
210
+ if isinstance(guard, GuardWithAfterEdit):
171
211
  result = guard.after_edit(model)
172
212
  if result is not None:
173
213
  results.append(result)
@@ -177,7 +217,7 @@ class GuardChain:
177
217
  """Finalize all guards and return outcomes."""
178
218
  results = []
179
219
  for guard in self.guards:
180
- if hasattr(guard, "finalize"):
220
+ if isinstance(guard, GuardWithFinalize):
181
221
  result = guard.finalize(model)
182
222
  results.append(result)
183
223
  return results