invarlock 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. invarlock/__init__.py +1 -1
  2. invarlock/_data/runtime/tiers.yaml +57 -30
  3. invarlock/adapters/__init__.py +1 -1
  4. invarlock/calibration/spectral_null.py +15 -10
  5. invarlock/calibration/variance_ve.py +0 -2
  6. invarlock/cli/commands/calibrate.py +6 -2
  7. invarlock/cli/commands/certify.py +58 -39
  8. invarlock/cli/commands/doctor.py +3 -1
  9. invarlock/cli/commands/explain_gates.py +57 -8
  10. invarlock/cli/commands/report.py +1 -1
  11. invarlock/cli/commands/run.py +159 -61
  12. invarlock/cli/commands/verify.py +78 -4
  13. invarlock/cli/config.py +21 -5
  14. invarlock/core/api.py +45 -5
  15. invarlock/core/auto_tuning.py +65 -20
  16. invarlock/core/contracts.py +7 -1
  17. invarlock/core/registry.py +2 -2
  18. invarlock/core/runner.py +314 -50
  19. invarlock/eval/bench.py +0 -13
  20. invarlock/eval/data.py +73 -283
  21. invarlock/eval/metrics.py +134 -4
  22. invarlock/eval/primary_metric.py +23 -0
  23. invarlock/eval/tail_stats.py +230 -0
  24. invarlock/guards/_estimators.py +154 -0
  25. invarlock/guards/policies.py +16 -6
  26. invarlock/guards/rmt.py +625 -544
  27. invarlock/guards/spectral.py +348 -110
  28. invarlock/guards/tier_config.py +32 -30
  29. invarlock/guards/variance.py +5 -29
  30. invarlock/guards_ref/rmt_ref.py +23 -23
  31. invarlock/model_profile.py +42 -15
  32. invarlock/reporting/certificate.py +225 -46
  33. invarlock/reporting/certificate_schema.py +2 -1
  34. invarlock/reporting/dataset_hashing.py +15 -2
  35. invarlock/reporting/guards_analysis.py +197 -274
  36. invarlock/reporting/normalizer.py +6 -0
  37. invarlock/reporting/policy_utils.py +38 -36
  38. invarlock/reporting/primary_metric_utils.py +71 -17
  39. invarlock/reporting/render.py +61 -0
  40. invarlock/reporting/report.py +1 -1
  41. invarlock/reporting/report_types.py +5 -2
  42. invarlock/reporting/validate.py +1 -18
  43. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
  44. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
  45. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
  46. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
  47. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
  48. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
invarlock/guards/rmt.py CHANGED
@@ -22,7 +22,6 @@ import torch
22
22
  import torch.linalg as tla
23
23
  import torch.nn as nn
24
24
 
25
- from invarlock.cli._evidence import maybe_dump_guard_evidence
26
25
  from invarlock.core.api import Guard
27
26
 
28
27
  from ._contracts import guard_assert
@@ -278,7 +277,9 @@ def layer_svd_stats(
278
277
  return result
279
278
 
280
279
 
281
- def capture_baseline_mp_stats(model: nn.Module) -> dict[str, dict[str, float]]:
280
+ def capture_baseline_mp_stats(
281
+ model: nn.Module, *, allowed_module_names: list[str] | None = None
282
+ ) -> dict[str, dict[str, float]]:
282
283
  """
283
284
  Capture baseline MP statistics for linear layers only.
284
285
 
@@ -321,9 +322,14 @@ def capture_baseline_mp_stats(model: nn.Module) -> dict[str, dict[str, float]]:
321
322
 
322
323
  # Define allowlist for RMT analysis - only linear layers where MP makes sense
323
324
  allowed_suffixes = [".attn.c_attn", ".attn.c_proj", ".mlp.c_fc", ".mlp.c_proj"]
325
+ allowed_set = None
326
+ if isinstance(allowed_module_names, list) and allowed_module_names:
327
+ allowed_set = {str(name).strip() for name in allowed_module_names if name}
324
328
 
325
329
  for name, module in model.named_modules():
326
330
  if isinstance(module, module_types) and hasattr(module, "weight"):
331
+ if allowed_set is not None and name not in allowed_set:
332
+ continue
327
333
  # CRITICAL: Restrict to only linear layers where MP analysis is meaningful
328
334
  # Skip embeddings, LM head, layer norms - MP heuristics don't apply there
329
335
  if any(name.endswith(suffix) for suffix in allowed_suffixes):
@@ -412,6 +418,7 @@ def rmt_detect(
412
418
  correction_factor: float | None = None,
413
419
  layer_indices: list[int] | None = None,
414
420
  target_layers: list[str] | None = None, # Alternative layer specification
421
+ allowed_module_names: list[str] | None = None, # Exact module allowlist
415
422
  verbose: bool = False,
416
423
  max_iterations: int = 2, # Add iteration guard
417
424
  baseline_sigmas: dict[str, float]
@@ -431,6 +438,7 @@ def rmt_detect(
431
438
  correction_factor: Factor to apply for correction (if not detect_only)
432
439
  layer_indices: Specific layers to analyze by index (None = all)
433
440
  target_layers: Specific layers to analyze by name (None = all)
441
+ allowed_module_names: Exact module names to analyze (None = derived default scope)
434
442
  verbose: Whether to print warnings and details
435
443
  max_iterations: Maximum iterations for correction (default 2)
436
444
  baseline_sigmas: Baseline sigmas for baseline-aware checking
@@ -473,9 +481,14 @@ def rmt_detect(
473
481
  else:
474
482
  # CRITICAL: Only analyze modules where MP analysis makes sense
475
483
  # Exclude embeddings, LM head, layer norms - they have different spectral properties
484
+ allowed_set = None
485
+ if isinstance(allowed_module_names, list) and allowed_module_names:
486
+ allowed_set = {str(name).strip() for name in allowed_module_names if name}
476
487
  for name, module in model.named_modules():
477
488
  # Check if this is an allowed module type with 2D weights
478
489
  if any(name.endswith(suffix) for suffix in allowed_suffixes):
490
+ if allowed_set is not None and name not in allowed_set:
491
+ continue
479
492
  has_2d_weights = any(
480
493
  param.ndim == 2 and "weight" in param_name
481
494
  for param_name, param in module.named_parameters(recurse=False)
@@ -671,7 +684,6 @@ def rmt_detect(
671
684
  return {
672
685
  "has_outliers": has_outliers,
673
686
  "n_layers_flagged": n_outliers,
674
- "outlier_count": n_outliers, # Alias for compatibility
675
687
  "max_ratio": max_ratio,
676
688
  "threshold": threshold,
677
689
  "correction_iterations": correction_iterations,
@@ -702,8 +714,6 @@ def rmt_detect_report(
702
714
  "has_outliers": result["has_outliers"],
703
715
  "n_layers_flagged": result["n_layers_flagged"],
704
716
  "max_ratio": result["max_ratio"],
705
- "rmt_max_ratio": result["max_ratio"], # Alias for compatibility
706
- "rmt_has_outliers": result["has_outliers"], # Alias
707
717
  }
708
718
 
709
719
  return summary, result["per_layer"]
@@ -819,7 +829,6 @@ def rmt_detect_with_names(
819
829
  return {
820
830
  "has_outliers": has_outliers,
821
831
  "n_layers_flagged": n_outliers,
822
- "outlier_count": n_outliers,
823
832
  "max_ratio": max_ratio,
824
833
  "threshold": threshold,
825
834
  "per_layer": per_layer,
@@ -1104,14 +1113,18 @@ class RMTPolicy:
1104
1113
  correct: bool = True # Enable automatic correction
1105
1114
 
1106
1115
 
1107
- class RMTPolicyDict(TypedDict):
1108
- """TypedDict version of RMTPolicy for compatibility."""
1116
+ class RMTPolicyDict(TypedDict, total=False):
1117
+ """TypedDict version of the RMT guard policy."""
1109
1118
 
1110
1119
  q: float | Literal["auto"]
1111
1120
  deadband: float
1112
1121
  margin: float
1113
1122
  correct: bool
1114
- epsilon: float | dict[str, float] | None
1123
+ epsilon_default: float
1124
+ epsilon_by_family: dict[str, float]
1125
+ activation_required: bool
1126
+ estimator: dict[str, Any]
1127
+ activation: dict[str, Any]
1115
1128
 
1116
1129
 
1117
1130
  class RMTGuard(Guard):
@@ -1144,7 +1157,9 @@ class RMTGuard(Guard):
1144
1157
  deadband: float = 0.10,
1145
1158
  margin: float = 1.5,
1146
1159
  correct: bool = True,
1147
- epsilon: float | dict[str, float] | None = None,
1160
+ *,
1161
+ epsilon_default: float = 0.10,
1162
+ epsilon_by_family: dict[str, float] | None = None,
1148
1163
  ):
1149
1164
  """
1150
1165
  Initialize RMT Guard.
@@ -1159,13 +1174,23 @@ class RMTGuard(Guard):
1159
1174
  self.deadband = deadband
1160
1175
  self.margin = margin
1161
1176
  self.correct = correct
1162
- self.epsilon_default = 0.10
1177
+ self.epsilon_default = float(epsilon_default)
1163
1178
  self.epsilon_by_family: dict[str, float] = {}
1164
- self._set_epsilon(epsilon)
1179
+ self._set_epsilon_by_family(epsilon_by_family)
1165
1180
  for family_key in ("attn", "ffn", "embed", "other"):
1166
1181
  self.epsilon_by_family.setdefault(family_key, self.epsilon_default)
1167
1182
 
1168
- # Internal state
1183
+ # Measurement contract knobs (vNext)
1184
+ self.estimator: dict[str, Any] = {
1185
+ "type": "power_iter",
1186
+ "iters": 3,
1187
+ "init": "ones",
1188
+ }
1189
+ self.activation_sampling: dict[str, Any] = {
1190
+ "windows": {"count": 8, "indices_policy": "evenly_spaced"}
1191
+ }
1192
+
1193
+ # Internal state (activation edge-risk scoring)
1169
1194
  self._calibration_batches: list[Any] = []
1170
1195
  self._activation_ready = False
1171
1196
  self._require_activation = False
@@ -1173,8 +1198,6 @@ class RMTGuard(Guard):
1173
1198
  self._activation_required_reason: str | None = None
1174
1199
  self._run_profile: str | None = None
1175
1200
  self._run_tier: str | None = None
1176
- self.baseline_mp_stats: dict[str, dict[str, float]] | None = None
1177
- self.baseline_sigmas: dict[str, float] | None = None
1178
1201
  self.prepared = False
1179
1202
  self.events: list[dict[str, Any]] = []
1180
1203
  self._last_result: dict[str, Any] | None = None
@@ -1187,10 +1210,10 @@ class RMTGuard(Guard):
1187
1210
  ".mlp.c_fc",
1188
1211
  ".mlp.c_proj",
1189
1212
  ]
1190
- self.baseline_outliers_per_family: dict[str, int] = {}
1191
- self.baseline_total_outliers: int = 0
1192
- self.outliers_per_family: dict[str, int] = {}
1193
- self.outliers_total: int = 0
1213
+ self.baseline_edge_risk_by_family: dict[str, float] = {}
1214
+ self.baseline_edge_risk_by_module: dict[str, float] = {}
1215
+ self.edge_risk_by_family: dict[str, float] = {}
1216
+ self.edge_risk_by_module: dict[str, float] = {}
1194
1217
  self.epsilon_violations: list[dict[str, Any]] = []
1195
1218
 
1196
1219
  def _log_event(
@@ -1219,45 +1242,47 @@ class RMTGuard(Guard):
1219
1242
  tier = str(auto.get("tier", tier) or tier).strip().lower()
1220
1243
  self._run_profile = profile or None
1221
1244
  self._run_tier = tier or None
1222
- self._require_activation = bool(
1223
- profile in {"ci", "release"} and tier in {"balanced", "conservative"}
1224
- )
1245
+ self._require_activation = bool(profile in {"ci", "release"})
1225
1246
 
1226
- def _set_epsilon(self, epsilon: float | dict[str, float] | None) -> None:
1227
- """Configure epsilon defaults and per-family overrides."""
1228
- if isinstance(epsilon, dict):
1229
- mapped: dict[str, float] = {}
1230
- for family, value in epsilon.items():
1231
- try:
1232
- mapped[str(family)] = float(value)
1233
- except (TypeError, ValueError):
1234
- continue
1235
- if mapped:
1236
- self.epsilon_by_family.update(mapped)
1237
- self.epsilon_default = max(mapped.values())
1238
- elif isinstance(epsilon, int | float):
1239
- self.epsilon_default = float(epsilon)
1240
- if self.epsilon_by_family:
1241
- for family in list(self.epsilon_by_family):
1242
- self.epsilon_by_family[family] = self.epsilon_default
1247
+ def _set_epsilon_default(self, epsilon: Any) -> None:
1248
+ """Set the default ε used when a family value is missing."""
1249
+ if epsilon is None:
1250
+ return
1251
+ try:
1252
+ eps = float(epsilon)
1253
+ except (TypeError, ValueError):
1254
+ return
1255
+ if eps >= 0.0 and math.isfinite(eps):
1256
+ self.epsilon_default = eps
1257
+
1258
+ def _set_epsilon_by_family(self, epsilon: Any) -> None:
1259
+ """Set per-family ε values."""
1260
+ if not isinstance(epsilon, dict):
1261
+ return
1262
+ for family, value in epsilon.items():
1263
+ try:
1264
+ eps = float(value)
1265
+ except (TypeError, ValueError):
1266
+ continue
1267
+ if eps >= 0.0 and math.isfinite(eps):
1268
+ self.epsilon_by_family[str(family)] = eps
1243
1269
 
1244
1270
  @staticmethod
1245
1271
  def _classify_family(module_name: str) -> str:
1246
- """Classify module name into a guard family."""
1272
+ """Classify module name into a guard family (vNext: {attn, ffn, embed, other})."""
1247
1273
  lower = module_name.lower()
1248
- # MoE
1274
+ if any(tok in lower for tok in ("attn", "attention", "self_attn")):
1275
+ return "attn"
1249
1276
  if any(
1250
1277
  tok in lower
1251
1278
  for tok in ("router", "routing", "gate", "gating", "dispatch", "switch")
1252
1279
  ):
1253
- return "router"
1280
+ return "ffn"
1254
1281
  if any(
1255
1282
  tok in lower for tok in ("experts", "expert", "moe", "mixture_of_experts")
1256
1283
  ):
1257
- return "expert_ffn"
1258
- if ".attn." in lower or "attention" in lower:
1259
- return "attn"
1260
- if ".mlp." in lower or "ffn" in lower or ".c_fc" in lower:
1284
+ return "ffn"
1285
+ if any(tok in lower for tok in ("mlp", "ffn", "c_fc", "feed_forward")):
1261
1286
  return "ffn"
1262
1287
  if "embed" in lower or "wte" in lower or "wpe" in lower:
1263
1288
  return "embed"
@@ -1287,24 +1312,28 @@ class RMTGuard(Guard):
1287
1312
  return counts
1288
1313
 
1289
1314
  def _compute_epsilon_violations(self) -> list[dict[str, Any]]:
1290
- """Compute epsilon-rule violations per family."""
1315
+ """Compute ε-band violations per family on activation edge-risk scores."""
1291
1316
  violations: list[dict[str, Any]] = []
1292
- families = set(self.outliers_per_family) | set(
1293
- self.baseline_outliers_per_family
1317
+ families = set(self.edge_risk_by_family) | set(
1318
+ self.baseline_edge_risk_by_family
1294
1319
  )
1295
1320
  for family in families:
1296
- bare = int(self.baseline_outliers_per_family.get(family, 0) or 0)
1297
- guarded = int(self.outliers_per_family.get(family, 0) or 0)
1321
+ base = float(self.baseline_edge_risk_by_family.get(family, 0.0) or 0.0)
1322
+ cur = float(self.edge_risk_by_family.get(family, 0.0) or 0.0)
1323
+ if base <= 0.0:
1324
+ continue
1298
1325
  epsilon_val = float(
1299
1326
  self.epsilon_by_family.get(family, self.epsilon_default)
1300
1327
  )
1301
- allowed = math.ceil(bare * (1 + epsilon_val))
1302
- if guarded > allowed:
1328
+ allowed = (1.0 + epsilon_val) * base
1329
+ if cur > allowed:
1330
+ delta = (cur / base) - 1.0 if base > 0 else float("inf")
1303
1331
  violations.append(
1304
1332
  {
1305
1333
  "family": family,
1306
- "bare": bare,
1307
- "guarded": guarded,
1334
+ "edge_base": base,
1335
+ "edge_cur": cur,
1336
+ "delta": float(delta),
1308
1337
  "allowed": allowed,
1309
1338
  "epsilon": epsilon_val,
1310
1339
  }
@@ -1321,8 +1350,6 @@ class RMTGuard(Guard):
1321
1350
  Returns:
1322
1351
  List of (name, module) tuples for linear layers in scope
1323
1352
  """
1324
- modules = []
1325
-
1326
1353
  # Get module types
1327
1354
  try:
1328
1355
  from transformers.pytorch_utils import Conv1D
@@ -1338,20 +1365,61 @@ class RMTGuard(Guard):
1338
1365
  )
1339
1366
  module_types = module_types_without_conv1d_2
1340
1367
 
1341
- modules: list[tuple[str, nn.Module]] = []
1368
+ candidates: list[tuple[str, nn.Module]] = []
1342
1369
  for name, module in model.named_modules():
1343
- if isinstance(module, module_types) and hasattr(module, "weight"):
1344
- # Strict scope enforcement - only allowed linear layers
1345
- if any(name.endswith(suffix) for suffix in self.allowed_suffixes):
1346
- modules.append((name, module))
1347
-
1348
- return modules
1370
+ if not (isinstance(module, module_types) and hasattr(module, "weight")):
1371
+ continue
1372
+ # Strict scope enforcement - only allowed linear layers
1373
+ if any(name.endswith(suffix) for suffix in self.allowed_suffixes):
1374
+ candidates.append((name, module))
1375
+ candidates.sort(key=lambda t: t[0])
1376
+ return candidates
1349
1377
 
1350
1378
  def _collect_calibration_batches(self, calib: Any, max_windows: int) -> list[Any]:
1351
1379
  """Collect a deterministic slice of calibration batches."""
1352
1380
  if calib is None or max_windows <= 0:
1353
1381
  return []
1354
1382
  source = getattr(calib, "dataloader", None) or calib
1383
+ # Prefer index-based selection when possible so we can support simple
1384
+ # deterministic policies (first/last/evenly_spaced) without consuming
1385
+ # the entire iterator.
1386
+ try:
1387
+ if hasattr(source, "__len__") and hasattr(source, "__getitem__"):
1388
+ n = int(len(source)) # type: ignore[arg-type]
1389
+ if n <= 0:
1390
+ return []
1391
+ count = min(int(max_windows), n)
1392
+ policy = (
1393
+ (self.activation_sampling.get("windows") or {}).get(
1394
+ "indices_policy", "evenly_spaced"
1395
+ )
1396
+ if isinstance(self.activation_sampling, dict)
1397
+ else "evenly_spaced"
1398
+ )
1399
+ policy = str(policy or "evenly_spaced").strip().lower()
1400
+ if policy == "last":
1401
+ idxs = list(range(max(0, n - count), n))
1402
+ elif policy == "evenly_spaced":
1403
+ if count <= 1:
1404
+ idxs = [0]
1405
+ else:
1406
+ idxs = [
1407
+ int(round(i * (n - 1) / float(count - 1)))
1408
+ for i in range(count)
1409
+ ]
1410
+ else:
1411
+ idxs = list(range(count))
1412
+ batches: list[Any] = []
1413
+ for idx in idxs:
1414
+ try:
1415
+ batches.append(source[idx]) # type: ignore[index]
1416
+ except Exception:
1417
+ continue
1418
+ return batches
1419
+ except Exception:
1420
+ pass
1421
+
1422
+ # Iterable fallback: first-N only.
1355
1423
  try:
1356
1424
  iterator = iter(source)
1357
1425
  except TypeError:
@@ -1459,8 +1527,210 @@ class RMTGuard(Guard):
1459
1527
  ):
1460
1528
  modules.append((name, module))
1461
1529
 
1530
+ modules.sort(key=lambda t: t[0])
1462
1531
  return modules
1463
1532
 
1533
+ def _activation_edge_risk(
1534
+ self, activations: Any
1535
+ ) -> tuple[float, float, float] | None:
1536
+ """Compute activation edge-risk score r = σ̂max(A') / σ_MP(m,n).
1537
+
1538
+ A' is a centered + standardised view of the activation matrix. The σ̂max
1539
+ estimator is matvec-based and avoids full SVD.
1540
+ """
1541
+ if isinstance(activations, tuple | list):
1542
+ activations = activations[0] if activations else None
1543
+ if not isinstance(activations, torch.Tensor):
1544
+ return None
1545
+ if activations.dim() < 2:
1546
+ return None
1547
+ if activations.dim() > 2:
1548
+ activations = activations.reshape(-1, activations.shape[-1])
1549
+ if activations.numel() == 0:
1550
+ return None
1551
+
1552
+ mat = activations.detach()
1553
+ if mat.shape[0] <= 0 or mat.shape[1] <= 0:
1554
+ return None
1555
+ if not torch.isfinite(mat).all():
1556
+ return None
1557
+
1558
+ eps = 1e-12
1559
+ try:
1560
+ mu = mat.mean(dtype=torch.float32)
1561
+ norm = torch.linalg.vector_norm(mat.reshape(-1), ord=2, dtype=torch.float32)
1562
+ mean_sq = (norm * norm) / float(mat.numel())
1563
+ var = mean_sq - (mu * mu)
1564
+ std = torch.sqrt(var.clamp_min(eps))
1565
+ except Exception:
1566
+ return None
1567
+ if not torch.isfinite(mu) or not torch.isfinite(std):
1568
+ return None
1569
+ std_val = float(std.item())
1570
+ if not math.isfinite(std_val) or std_val <= 0.0:
1571
+ return None
1572
+
1573
+ m, n = int(mat.shape[0]), int(mat.shape[1])
1574
+ mp_edge_val = mp_bulk_edge(m, n, whitened=False)
1575
+ if not (math.isfinite(mp_edge_val) and mp_edge_val > 0.0):
1576
+ return None
1577
+
1578
+ try:
1579
+ iters = int((self.estimator or {}).get("iters", 3) or 3)
1580
+ except Exception:
1581
+ iters = 3
1582
+ if iters < 1:
1583
+ iters = 1
1584
+ init = str((self.estimator or {}).get("init", "ones") or "ones").strip().lower()
1585
+ if init not in {"ones", "e0"}:
1586
+ init = "ones"
1587
+
1588
+ device = mat.device
1589
+ dtype = mat.dtype
1590
+
1591
+ with torch.no_grad():
1592
+ if init == "ones":
1593
+ v = torch.ones((n,), device=device, dtype=dtype)
1594
+ else:
1595
+ v = torch.zeros((n,), device=device, dtype=dtype)
1596
+ v[0] = 1
1597
+ v = v / torch.linalg.vector_norm(v.float()).clamp_min(eps).to(dtype)
1598
+
1599
+ mu_d = mu.to(dtype)
1600
+ inv_std_d = (1.0 / std.clamp_min(eps)).to(dtype)
1601
+ ones_n = torch.ones((n,), device=device, dtype=dtype)
1602
+
1603
+ sigma = 0.0
1604
+ for _ in range(iters):
1605
+ v_sum = torch.sum(v.float())
1606
+ u = mat @ v
1607
+ u = (u - (mu_d * v_sum.to(dtype))) * inv_std_d
1608
+ u_norm = torch.linalg.vector_norm(u.float()).clamp_min(eps)
1609
+ sigma_val = float(u_norm.item())
1610
+ if not math.isfinite(sigma_val):
1611
+ return None
1612
+ u = u / u_norm.to(dtype)
1613
+
1614
+ u_sum = torch.sum(u.float())
1615
+ v = mat.T @ u
1616
+ v = (v - (mu_d * u_sum.to(dtype) * ones_n)) * inv_std_d
1617
+ v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
1618
+ v = v / v_norm.to(dtype)
1619
+ sigma = sigma_val
1620
+
1621
+ risk = float(sigma) / max(float(mp_edge_val), eps)
1622
+ return float(risk), float(sigma), float(mp_edge_val)
1623
+
1624
+ def _compute_activation_edge_risk(
1625
+ self, model: nn.Module, batches: list[Any]
1626
+ ) -> dict[str, Any] | None:
1627
+ """Compute token-weighted activation edge-risk scores per module/family."""
1628
+ if not batches:
1629
+ return None
1630
+
1631
+ modules = self._get_activation_modules(model)
1632
+ if not modules:
1633
+ return None
1634
+
1635
+ acc: dict[str, dict[str, float]] = {}
1636
+ for name, _module in modules:
1637
+ acc[name] = {"weighted_sum": 0.0, "weight": 0.0, "max_risk": 0.0}
1638
+
1639
+ batch_weight_holder = {"weight": 1}
1640
+ handles: list[Any] = []
1641
+
1642
+ def _make_hook(name: str):
1643
+ def _hook(_module: nn.Module, _inputs: tuple[Any, ...], output: Any):
1644
+ out = self._activation_edge_risk(output)
1645
+ if out is None:
1646
+ return
1647
+ risk, _sigma, _edge = out
1648
+ try:
1649
+ weight = int(batch_weight_holder.get("weight", 1) or 1)
1650
+ except Exception:
1651
+ weight = 1
1652
+ row = acc.get(name)
1653
+ if row is None:
1654
+ return
1655
+ row["weighted_sum"] = float(row.get("weighted_sum", 0.0)) + float(
1656
+ risk
1657
+ ) * float(weight)
1658
+ row["weight"] = float(row.get("weight", 0.0)) + float(weight)
1659
+ row["max_risk"] = max(float(row.get("max_risk", 0.0)), float(risk))
1660
+
1661
+ return _hook
1662
+
1663
+ for name, module in modules:
1664
+ try:
1665
+ handles.append(module.register_forward_hook(_make_hook(name)))
1666
+ except Exception:
1667
+ continue
1668
+
1669
+ model_was_training = model.training
1670
+ model.eval()
1671
+ device = next(model.parameters()).device
1672
+ batches_used = 0
1673
+ token_weight_total = 0
1674
+
1675
+ try:
1676
+ with torch.no_grad():
1677
+ for batch in batches:
1678
+ inputs, attention_mask = self._prepare_activation_inputs(
1679
+ batch, device
1680
+ )
1681
+ if inputs is None:
1682
+ continue
1683
+ batch_weight = self._batch_token_weight(inputs, attention_mask)
1684
+ batch_weight_holder["weight"] = batch_weight
1685
+ try:
1686
+ if attention_mask is not None:
1687
+ model(inputs, attention_mask=attention_mask)
1688
+ else:
1689
+ model(inputs)
1690
+ except TypeError:
1691
+ model(inputs)
1692
+ batches_used += 1
1693
+ token_weight_total += batch_weight
1694
+ finally:
1695
+ for handle in handles:
1696
+ try:
1697
+ handle.remove()
1698
+ except Exception:
1699
+ pass
1700
+ if model_was_training:
1701
+ model.train()
1702
+
1703
+ if batches_used <= 0:
1704
+ return None
1705
+
1706
+ edge_risk_by_module: dict[str, float] = {}
1707
+ for name, row in acc.items():
1708
+ w = float(row.get("weight", 0.0) or 0.0)
1709
+ if w <= 0.0:
1710
+ continue
1711
+ edge_risk_by_module[name] = float(row.get("weighted_sum", 0.0) or 0.0) / w
1712
+
1713
+ if not edge_risk_by_module:
1714
+ return None
1715
+
1716
+ edge_risk_by_family: dict[str, float] = {}
1717
+ for name, risk in edge_risk_by_module.items():
1718
+ family = self._classify_family(name)
1719
+ edge_risk_by_family[family] = max(
1720
+ float(edge_risk_by_family.get(family, 0.0)), float(risk)
1721
+ )
1722
+
1723
+ for family_key in ("attn", "ffn", "embed", "other"):
1724
+ edge_risk_by_family.setdefault(family_key, 0.0)
1725
+
1726
+ return {
1727
+ "analysis_source": "activations_edge_risk",
1728
+ "edge_risk_by_module": edge_risk_by_module,
1729
+ "edge_risk_by_family": edge_risk_by_family,
1730
+ "token_weight_total": int(token_weight_total),
1731
+ "batches_used": int(batches_used),
1732
+ }
1733
+
1464
1734
  def _activation_svd_outliers(
1465
1735
  self, activations: Any, margin: float, deadband: float
1466
1736
  ) -> tuple[int, float, float]:
@@ -1780,150 +2050,175 @@ class RMTGuard(Guard):
1780
2050
  calib=None,
1781
2051
  policy: dict[str, Any] | None = None,
1782
2052
  ) -> dict[str, Any]:
1783
- """
1784
- Prepare RMT guard by capturing baseline MP statistics.
1785
-
1786
- Args:
1787
- model: The model that will be edited
1788
- adapter: ModelAdapter (optional, for tying map access)
1789
- calib: Calibration data for activation-based outlier counting
1790
- policy: Guard policy parameters (optional)
1791
-
1792
- Returns:
1793
- Dictionary with preparation results and baseline metrics
1794
- """
2053
+ """Prepare RMT guard by capturing baseline activation edge-risk scores."""
1795
2054
  import time
1796
2055
 
1797
2056
  start_time = time.time()
1798
2057
  self._activation_required_failed = False
1799
2058
  self._activation_required_reason = None
1800
2059
 
1801
- # Store adapter for tying map access during correction
2060
+ # Store adapter for tying map access (if used by downstream code)
1802
2061
  self.adapter = adapter
1803
2062
 
1804
- # Update parameters from policy if provided
1805
- if policy:
1806
- self.q = policy.get("q", self.q)
1807
- self.deadband = policy.get("deadband", self.deadband)
1808
- self.margin = policy.get("margin", self.margin)
1809
- self.correct = policy.get("correct", self.correct)
2063
+ # Policy overrides (vNext contract)
2064
+ if isinstance(policy, dict) and policy:
1810
2065
  if "epsilon" in policy:
1811
- self._set_epsilon(policy["epsilon"])
2066
+ from invarlock.core.exceptions import ValidationError
2067
+
2068
+ raise ValidationError(
2069
+ code="E501",
2070
+ message="POLICY-PARAM-INVALID",
2071
+ details={
2072
+ "param": "epsilon",
2073
+ "hint": "Use rmt.epsilon_default and rmt.epsilon_by_family instead.",
2074
+ },
2075
+ )
2076
+ if "q" in policy:
2077
+ q_val = policy.get("q")
2078
+ if q_val == "auto":
2079
+ self.q = "auto"
2080
+ else:
2081
+ try:
2082
+ self.q = float(q_val)
2083
+ except (TypeError, ValueError):
2084
+ self.q = "auto"
2085
+ if "deadband" in policy:
2086
+ self.deadband = float(policy.get("deadband", self.deadband))
2087
+ if "margin" in policy:
2088
+ try:
2089
+ self.margin = float(policy.get("margin", self.margin))
2090
+ except (TypeError, ValueError):
2091
+ pass
2092
+ if "correct" in policy:
2093
+ self.correct = bool(policy.get("correct"))
1812
2094
  if "epsilon_by_family" in policy:
1813
- self._set_epsilon(policy["epsilon_by_family"])
2095
+ self._set_epsilon_by_family(policy["epsilon_by_family"])
2096
+ if "epsilon_default" in policy:
2097
+ self._set_epsilon_default(policy["epsilon_default"])
2098
+ for family_key in ("attn", "ffn", "embed", "other"):
2099
+ self.epsilon_by_family.setdefault(family_key, self.epsilon_default)
1814
2100
  if "activation_required" in policy:
1815
2101
  self._require_activation = bool(policy.get("activation_required"))
1816
2102
 
2103
+ estimator_policy = policy.get("estimator")
2104
+ if isinstance(estimator_policy, dict):
2105
+ try:
2106
+ iters = int(estimator_policy.get("iters", 3) or 3)
2107
+ except Exception:
2108
+ iters = 3
2109
+ if iters < 1:
2110
+ iters = 1
2111
+ init = (
2112
+ str(estimator_policy.get("init", "ones") or "ones").strip().lower()
2113
+ )
2114
+ if init not in {"ones", "e0"}:
2115
+ init = "ones"
2116
+ self.estimator = {"type": "power_iter", "iters": iters, "init": init}
2117
+
2118
+ activation_policy = policy.get("activation")
2119
+ if isinstance(activation_policy, dict):
2120
+ sampling = activation_policy.get("sampling")
2121
+ if isinstance(sampling, dict):
2122
+ windows = sampling.get("windows")
2123
+ if isinstance(windows, dict):
2124
+ cfg = dict(self.activation_sampling.get("windows") or {})
2125
+ if windows.get("count") is not None:
2126
+ try:
2127
+ cfg["count"] = int(windows.get("count") or 0)
2128
+ except Exception:
2129
+ pass
2130
+ if windows.get("indices_policy") is not None:
2131
+ cfg["indices_policy"] = str(
2132
+ windows.get("indices_policy")
2133
+ or cfg.get("indices_policy")
2134
+ )
2135
+ self.activation_sampling["windows"] = cfg
2136
+
1817
2137
  self._log_event(
1818
2138
  "prepare",
1819
- message=f"Preparing RMT guard with q={self.q}, deadband={self.deadband}, margin={self.margin}, correct={self.correct}",
2139
+ message="Preparing RMT guard baseline activation edge-risk metrics",
1820
2140
  )
1821
2141
 
1822
2142
  try:
1823
- # Capture baseline MP statistics for linear layers
1824
- self.baseline_mp_stats = capture_baseline_mp_stats(model)
1825
-
1826
- # Extract baseline sigmas for compatibility with existing detection
1827
- self.baseline_sigmas = {}
1828
- for name, stats in self.baseline_mp_stats.items():
1829
- self.baseline_sigmas[name] = stats.get("sigma_base", 0.0)
1830
-
1831
- # Get linear modules in scope
1832
- linear_modules = self._get_linear_modules(model)
1833
- self._activation_ready = False
1834
- self._calibration_batches = []
1835
- max_windows = 0
1836
- if calib is not None:
1837
- source = getattr(calib, "dataloader", None) or calib
1838
- if hasattr(source, "__len__"):
1839
- try:
1840
- max_windows = int(len(source))
1841
- except Exception:
1842
- max_windows = 0
1843
- self._calibration_batches = self._collect_calibration_batches(
1844
- calib, max_windows
1845
- )
2143
+ windows_cfg = self.activation_sampling.get("windows") or {}
2144
+ try:
2145
+ window_count = int(windows_cfg.get("count", 0) or 0)
2146
+ except Exception:
2147
+ window_count = 0
2148
+ self._calibration_batches = (
2149
+ self._collect_calibration_batches(calib, window_count)
2150
+ if calib is not None and window_count > 0
2151
+ else []
2152
+ )
1846
2153
 
1847
- activation_baseline = None
1848
- if self._calibration_batches:
1849
- activation_baseline = self._compute_activation_outliers(
1850
- model, self._calibration_batches
1851
- )
2154
+ self.baseline_edge_risk_by_family = {}
2155
+ self.baseline_edge_risk_by_module = {}
2156
+ self.edge_risk_by_family = {}
2157
+ self.edge_risk_by_module = {}
2158
+ self.epsilon_violations = []
1852
2159
 
1853
- if activation_baseline:
1854
- self._activation_ready = True
1855
- self.baseline_total_outliers = int(
1856
- activation_baseline.get("outlier_count", 0) or 0
1857
- )
1858
- self.baseline_outliers_per_family = self._count_outliers_per_family(
1859
- activation_baseline.get("per_layer", [])
1860
- )
1861
- else:
2160
+ if self._require_activation and not self._calibration_batches:
2161
+ self._activation_required_failed = True
2162
+ self._activation_required_reason = "activation_required"
2163
+ self._activation_ready = False
2164
+ self.prepared = False
2165
+ return {
2166
+ "ready": False,
2167
+ "baseline_metrics": {},
2168
+ "policy_applied": policy or {},
2169
+ "preparation_time": time.time() - start_time,
2170
+ "error": "Activation batches required but unavailable",
2171
+ }
2172
+
2173
+ baseline = (
2174
+ self._compute_activation_edge_risk(model, self._calibration_batches)
2175
+ if self._calibration_batches
2176
+ else None
2177
+ )
2178
+ if baseline is None:
1862
2179
  if self._require_activation:
1863
2180
  self._activation_required_failed = True
1864
2181
  self._activation_required_reason = "activation_baseline_unavailable"
1865
- self._log_event(
1866
- "activation_required_missing",
1867
- level="ERROR",
1868
- message="Activation baseline unavailable for RMT",
1869
- profile=self._run_profile,
1870
- tier=self._run_tier,
1871
- )
1872
- baseline_detection = rmt_detect(
1873
- model=model,
1874
- threshold=self.margin,
1875
- detect_only=True,
1876
- baseline_sigmas=self.baseline_sigmas,
1877
- baseline_mp_stats=self.baseline_mp_stats,
1878
- deadband=self.deadband,
1879
- )
1880
- self.baseline_total_outliers = baseline_detection.get(
1881
- "n_layers_flagged", 0
1882
- )
1883
- self.baseline_outliers_per_family = self._count_outliers_per_family(
1884
- baseline_detection.get("per_layer", [])
1885
- )
1886
- for family_key in ("attn", "ffn", "embed", "other"):
1887
- self.baseline_outliers_per_family.setdefault(family_key, 0)
1888
- self.outliers_per_family = {}
1889
- self.outliers_total = 0
1890
- self.epsilon_violations = []
1891
-
1892
- self.prepared = True
1893
- preparation_time = time.time() - start_time
2182
+ self._activation_ready = False
2183
+ self.prepared = False
2184
+ return {
2185
+ "ready": False,
2186
+ "baseline_metrics": {},
2187
+ "policy_applied": policy or {},
2188
+ "preparation_time": time.time() - start_time,
2189
+ "error": "Activation baseline unavailable",
2190
+ }
2191
+ # Non-required: treat as not ready and allow pipeline to continue.
2192
+ self._activation_ready = False
2193
+ self.prepared = True
2194
+ return {
2195
+ "ready": True,
2196
+ "baseline_metrics": {},
2197
+ "policy_applied": policy or {},
2198
+ "preparation_time": time.time() - start_time,
2199
+ }
1894
2200
 
1895
- self._log_event(
1896
- "prepare_success",
1897
- message=f"Captured {len(self.baseline_mp_stats)} baseline MP statistics",
1898
- baseline_count=len(self.baseline_mp_stats),
1899
- linear_modules_count=len(linear_modules),
1900
- preparation_time=preparation_time,
2201
+ self.baseline_edge_risk_by_module = dict(
2202
+ baseline.get("edge_risk_by_module") or {}
1901
2203
  )
2204
+ self.baseline_edge_risk_by_family = dict(
2205
+ baseline.get("edge_risk_by_family") or {}
2206
+ )
2207
+ self._activation_ready = True
2208
+ self.prepared = True
1902
2209
 
2210
+ preparation_time = time.time() - start_time
1903
2211
  return {
2212
+ "ready": True,
1904
2213
  "baseline_metrics": {
1905
- "mp_stats_sample": dict(list(self.baseline_mp_stats.items())[:3]),
1906
- "total_layers": len(self.baseline_mp_stats),
1907
- "linear_modules_in_scope": len(linear_modules),
1908
- "scope_suffixes": self.allowed_suffixes,
1909
- "average_baseline_sigma": np.mean(
1910
- list(self.baseline_sigmas.values())
1911
- ),
1912
- "max_baseline_sigma": max(self.baseline_sigmas.values())
1913
- if self.baseline_sigmas
1914
- else 0.0,
1915
- "min_baseline_sigma": min(self.baseline_sigmas.values())
1916
- if self.baseline_sigmas
1917
- else 0.0,
1918
- },
1919
- "policy_applied": {
1920
- "q": self.q,
1921
- "deadband": self.deadband,
1922
- "margin": self.margin,
1923
- "correct": self.correct,
2214
+ "edge_risk_by_family": dict(self.baseline_edge_risk_by_family),
2215
+ "measurement_contract": {
2216
+ "estimator": self.estimator,
2217
+ "activation_sampling": self.activation_sampling,
2218
+ },
1924
2219
  },
2220
+ "policy_applied": policy or {},
1925
2221
  "preparation_time": preparation_time,
1926
- "ready": True,
1927
2222
  }
1928
2223
 
1929
2224
  except Exception as e:
@@ -1957,12 +2252,7 @@ class RMTGuard(Guard):
1957
2252
  )
1958
2253
 
1959
2254
  def after_edit(self, model: nn.Module) -> None:
1960
- """
1961
- Execute after edit - perform RMT detection and optional correction.
1962
-
1963
- Args:
1964
- model: The model that was just edited
1965
- """
2255
+ """Execute after edit: compute activation edge-risk on sampled batches."""
1966
2256
  if not self.prepared:
1967
2257
  self._log_event(
1968
2258
  "after_edit_skipped",
@@ -1971,137 +2261,40 @@ class RMTGuard(Guard):
1971
2261
  )
1972
2262
  return
1973
2263
 
1974
- self._log_event("after_edit", message="Applying RMT detection and correction")
1975
-
1976
2264
  try:
1977
- detection_result: dict[str, Any] | None = None
1978
- corrected_layers = 0
1979
- correction_iterations = 0
1980
- use_activation = bool(self._activation_ready and self._calibration_batches)
1981
-
1982
- if self._require_activation and not use_activation:
2265
+ if self._require_activation and not self._calibration_batches:
1983
2266
  self._activation_required_failed = True
1984
2267
  self._activation_required_reason = "activation_unavailable"
1985
2268
  self._last_result = {
1986
- "has_outliers": False,
1987
- "n_layers_flagged": 0,
1988
- "per_layer": [],
1989
- "max_ratio": 0.0,
1990
- "analysis_source": "activations",
2269
+ "analysis_source": "activations_edge_risk",
2270
+ "edge_risk_by_module": {},
2271
+ "edge_risk_by_family": {},
1991
2272
  }
1992
- self._log_event(
1993
- "activation_required_missing",
1994
- level="ERROR",
1995
- message="Activation outlier analysis required but unavailable",
1996
- profile=self._run_profile,
1997
- tier=self._run_tier,
1998
- )
1999
2273
  return
2000
2274
 
2001
- if use_activation:
2002
- if self.correct:
2003
- correction_result = self._apply_rmt_detection_and_correction(model)
2004
- correction_iterations = int(
2005
- correction_result.get("correction_iterations", 0) or 0
2006
- )
2007
- corrected_layers = int(
2008
- correction_result.get("corrected_layers", 0) or 0
2009
- )
2010
- detection_result = self._compute_activation_outliers(
2011
- model, self._calibration_batches
2012
- )
2013
- if not detection_result:
2014
- if self._require_activation:
2015
- self._activation_required_failed = True
2016
- self._activation_required_reason = (
2017
- "activation_outliers_unavailable"
2018
- )
2019
- self._last_result = {
2020
- "has_outliers": False,
2021
- "n_layers_flagged": 0,
2022
- "per_layer": [],
2023
- "max_ratio": 0.0,
2024
- "analysis_source": "activations",
2025
- }
2026
- self._log_event(
2027
- "activation_required_missing",
2028
- level="ERROR",
2029
- message="Activation outlier analysis failed",
2030
- profile=self._run_profile,
2031
- tier=self._run_tier,
2032
- )
2033
- return
2034
- use_activation = False
2035
-
2036
- if not use_activation:
2037
- if self.correct:
2038
- # Apply correction using enhanced logic with adapter support
2039
- detection_result = self._apply_rmt_detection_and_correction(model)
2040
- else:
2041
- # Detection only
2042
- detection_result = rmt_detect(
2043
- model=model,
2044
- threshold=self.margin, # Use margin as threshold
2045
- detect_only=True,
2046
- verbose=False,
2047
- baseline_sigmas=self.baseline_sigmas,
2048
- baseline_mp_stats=self.baseline_mp_stats,
2049
- deadband=self.deadband,
2275
+ current = (
2276
+ self._compute_activation_edge_risk(model, self._calibration_batches)
2277
+ if self._calibration_batches
2278
+ else None
2279
+ )
2280
+ if current is None:
2281
+ if self._require_activation:
2282
+ self._activation_required_failed = True
2283
+ self._activation_required_reason = (
2284
+ "activation_edge_risk_unavailable"
2050
2285
  )
2286
+ self._last_result = {
2287
+ "analysis_source": "activations_edge_risk",
2288
+ "edge_risk_by_module": {},
2289
+ "edge_risk_by_family": {},
2290
+ }
2291
+ return
2051
2292
 
2052
- if detection_result is None:
2053
- raise RuntimeError("RMT detection failed to produce results")
2054
-
2055
- if use_activation:
2056
- detection_result["correction_iterations"] = correction_iterations
2057
- detection_result["corrected_layers"] = corrected_layers
2058
-
2059
- # Store results
2060
- self._last_result = detection_result
2061
- self.outliers_per_family = self._count_outliers_per_family(
2062
- detection_result.get("per_layer", [])
2063
- )
2064
- for family_key in ("attn", "ffn", "embed", "other"):
2065
- self.outliers_per_family.setdefault(family_key, 0)
2066
- outlier_total = detection_result.get("outlier_count")
2067
- if outlier_total is None:
2068
- outlier_total = detection_result.get(
2069
- "n_layers_flagged", len(self.outliers_per_family)
2070
- )
2071
- self.outliers_total = int(outlier_total or 0)
2293
+ self.edge_risk_by_module = dict(current.get("edge_risk_by_module") or {})
2294
+ self.edge_risk_by_family = dict(current.get("edge_risk_by_family") or {})
2295
+ self._last_result = dict(current)
2072
2296
  self.epsilon_violations = self._compute_epsilon_violations()
2073
2297
 
2074
- flagged_layers = detection_result.get("n_layers_flagged", 0)
2075
- corrected_layers = detection_result.get("correction_iterations", 0)
2076
-
2077
- self._log_event(
2078
- "rmt_detection_complete",
2079
- message=f"Detected {flagged_layers} outlier layers, correction enabled: {self.correct}",
2080
- layers_flagged=flagged_layers,
2081
- correction_iterations=corrected_layers,
2082
- has_outliers=detection_result.get("has_outliers", False),
2083
- max_ratio=detection_result.get("max_ratio", 0.0),
2084
- )
2085
-
2086
- # Log individual layer results
2087
- for layer_info in detection_result.get("per_layer", []):
2088
- if layer_info.get("has_outlier", False):
2089
- self._log_event(
2090
- "outlier_detected",
2091
- message=f"Outlier detected in {layer_info.get('module_name', 'unknown')}",
2092
- layer_name=layer_info.get("module_name"),
2093
- ratio=layer_info.get("worst_ratio", 0.0),
2094
- sigma_max=layer_info.get("sigma_max", 0.0),
2095
- corrected=self.correct,
2096
- )
2097
- elif layer_info.get("skip_reason"):
2098
- self._log_event(
2099
- "layer_skipped",
2100
- message=f"Layer {layer_info.get('module_name', 'unknown')} skipped: {layer_info.get('skip_reason')}",
2101
- layer_name=layer_info.get("module_name"),
2102
- skip_reason=layer_info.get("skip_reason"),
2103
- )
2104
-
2105
2298
  except Exception as e:
2106
2299
  self._log_event(
2107
2300
  "after_edit_failed",
@@ -2109,15 +2302,11 @@ class RMTGuard(Guard):
2109
2302
  message=f"RMT detection failed: {str(e)}",
2110
2303
  error=str(e),
2111
2304
  )
2112
- # Store empty result for finalize
2113
2305
  self._last_result = {
2114
- "has_outliers": False,
2115
- "n_layers_flagged": 0,
2116
- "per_layer": [],
2117
- "max_ratio": 0.0,
2306
+ "analysis_source": "activations_edge_risk",
2307
+ "edge_risk_by_module": {},
2308
+ "edge_risk_by_family": {},
2118
2309
  }
2119
- self.outliers_per_family = {}
2120
- self.outliers_total = 0
2121
2310
  self.epsilon_violations = []
2122
2311
 
2123
2312
  def validate(
@@ -2163,27 +2352,13 @@ class RMTGuard(Guard):
2163
2352
  }
2164
2353
 
2165
2354
  def finalize(self, model: nn.Module, adapter=None) -> GuardOutcome | dict[str, Any]:
2166
- """
2167
- Finalize RMT guard and return comprehensive results.
2168
-
2169
- Args:
2170
- model: The final edited model
2171
- adapter: Optional adapter for tying map access
2172
-
2173
- Returns:
2174
- GuardOutcome or dict with RMT detection and correction results
2175
- """
2355
+ """Finalize RMT guard and return activation edge-risk ε-band outcome."""
2176
2356
  import time
2177
2357
 
2178
2358
  start_time = time.time()
2359
+ _ = adapter
2179
2360
 
2180
2361
  if not self.prepared:
2181
- self._log_event(
2182
- "finalize_failed",
2183
- level="ERROR",
2184
- message="RMT guard not properly prepared",
2185
- )
2186
-
2187
2362
  if HAS_GUARD_OUTCOME:
2188
2363
  return GuardOutcome(
2189
2364
  name=self.name,
@@ -2202,35 +2377,28 @@ class RMTGuard(Guard):
2202
2377
  "finalize_time": time.time() - start_time,
2203
2378
  },
2204
2379
  )
2205
- else:
2206
- return {
2207
- "passed": False,
2208
- "metrics": {
2209
- "prepared": False,
2210
- "finalize_time": time.time() - start_time,
2211
- },
2212
- "warnings": ["RMT guard not properly prepared"],
2213
- "errors": ["Preparation failed or baseline MP stats not captured"],
2214
- "events": self.events,
2215
- }
2380
+ return {
2381
+ "passed": False,
2382
+ "metrics": {
2383
+ "prepared": False,
2384
+ "finalize_time": time.time() - start_time,
2385
+ },
2386
+ "errors": ["RMT guard not properly prepared"],
2387
+ }
2216
2388
 
2217
- # Get results from after_edit
2218
- if self._require_activation and (
2219
- self._activation_required_failed or not self._activation_ready
2220
- ):
2389
+ if self._require_activation and self._activation_required_failed:
2221
2390
  reason = self._activation_required_reason or "activation_required"
2222
- message = "Activation outlier analysis required but unavailable"
2223
2391
  finalize_time = time.time() - start_time
2224
2392
  if HAS_GUARD_OUTCOME:
2225
2393
  return GuardOutcome(
2226
2394
  name=self.name,
2227
2395
  passed=False,
2228
- action="rollback",
2396
+ action="abort",
2229
2397
  violations=[
2230
2398
  {
2231
2399
  "type": "activation_required",
2232
2400
  "severity": "error",
2233
- "message": message,
2401
+ "message": "Activation edge-risk analysis required but unavailable",
2234
2402
  "module_name": None,
2235
2403
  "reason": reason,
2236
2404
  }
@@ -2252,218 +2420,74 @@ class RMTGuard(Guard):
2252
2420
  "activation_reason": reason,
2253
2421
  "finalize_time": finalize_time,
2254
2422
  },
2255
- "warnings": [],
2256
- "errors": [message],
2257
- "violations": [],
2258
- "events": self.events,
2423
+ "errors": ["Activation edge-risk analysis required but unavailable"],
2259
2424
  }
2260
2425
 
2261
- result = self._last_result or {
2262
- "has_outliers": False,
2263
- "n_layers_flagged": 0,
2264
- "per_layer": [],
2265
- "max_ratio": 0.0,
2266
- }
2267
-
2268
- if result and not self.outliers_per_family:
2269
- self.outliers_per_family = self._count_outliers_per_family(
2270
- result.get("per_layer", [])
2426
+ if not self.edge_risk_by_family and self._calibration_batches:
2427
+ current = self._compute_activation_edge_risk(
2428
+ model, self._calibration_batches
2271
2429
  )
2272
- for family_key in ("attn", "ffn", "embed", "other"):
2273
- self.outliers_per_family.setdefault(family_key, 0)
2274
- self.baseline_outliers_per_family.setdefault(family_key, 0)
2275
- outlier_total = result.get("outlier_count")
2276
- if outlier_total is None:
2277
- outlier_total = result.get("n_layers_flagged", self.outliers_total or 0)
2278
- self.outliers_total = int(outlier_total or 0)
2430
+ if current is not None:
2431
+ self.edge_risk_by_family = dict(
2432
+ current.get("edge_risk_by_family") or {}
2433
+ )
2434
+ self.edge_risk_by_module = dict(
2435
+ current.get("edge_risk_by_module") or {}
2436
+ )
2437
+ self._last_result = dict(current)
2438
+
2279
2439
  self.epsilon_violations = self._compute_epsilon_violations()
2280
- # Contracts: epsilon non-negative, counts non-negative
2281
2440
  for fam, eps in self.epsilon_by_family.items():
2282
2441
  guard_assert(eps >= 0.0, f"rmt.epsilon[{fam}] must be >= 0")
2283
- for fam in set(self.outliers_per_family) | set(
2284
- self.baseline_outliers_per_family
2285
- ):
2286
- guard_assert(
2287
- self.outliers_per_family.get(fam, 0) >= 0,
2288
- "rmt.outliers_per_family negative",
2289
- )
2290
- guard_assert(
2291
- self.baseline_outliers_per_family.get(fam, 0) >= 0,
2292
- "rmt.baseline_outliers negative",
2293
- )
2294
-
2295
- # Calculate metrics
2296
- flagged_layers = result.get("n_layers_flagged", 0)
2297
- total_layers = (
2298
- len(result.get("per_layer", []))
2299
- if result.get("per_layer")
2300
- else len(self.baseline_mp_stats)
2301
- if self.baseline_mp_stats
2302
- else 0
2303
- )
2304
- flagged_rate = flagged_layers / total_layers if total_layers > 0 else 0.0
2305
-
2306
- # Acceptance gate: pass when epsilon-rule holds per family.
2307
- passed = not bool(self.epsilon_violations)
2308
-
2309
- # Generate violations for GuardOutcome
2310
- violations = []
2311
- warnings = []
2312
- errors = []
2313
-
2314
- # Create violations for each flagged layer
2315
- for layer_info in result.get("per_layer", []):
2316
- if layer_info.get("has_outlier", False):
2317
- violations.append(
2318
- {
2319
- "type": "rmt_outlier",
2320
- "severity": "warning" if self.correct else "error",
2321
- "message": f"RMT outlier detected: ratio={layer_info.get('worst_ratio', 0.0):.2f}",
2322
- "module_name": layer_info.get("module_name"),
2323
- "ratio": layer_info.get("worst_ratio", 0.0),
2324
- "threshold": (1.0 + self.deadband) * self.margin,
2325
- "corrected": self.correct,
2326
- }
2327
- )
2328
-
2329
- if flagged_rate > 0.3: # Warning threshold at 30%
2330
- warnings.append(
2331
- f"High RMT outlier rate: {flagged_layers}/{total_layers} layers flagged ({flagged_rate:.1%})"
2332
- )
2333
-
2334
- if flagged_rate > 0.7: # Escalate to warning for unusually high rates
2335
- warnings.append(
2336
- f"Excessive RMT outliers: {flagged_layers}/{total_layers} layers flagged"
2337
- )
2338
-
2339
- if self.epsilon_violations:
2340
- passed = False
2341
- for failure in self.epsilon_violations:
2342
- errors.append(
2343
- "RMT ε-rule violation: "
2344
- f"{failure['family']} bare={failure['bare']} "
2345
- f"guarded={failure['guarded']} allowed={failure['allowed']} "
2346
- f"(ε={failure['epsilon']:.3f})"
2347
- )
2348
2442
 
2443
+ stable = not self.epsilon_violations
2444
+ action = "continue" if stable else "abort"
2349
2445
  finalize_time = time.time() - start_time
2350
2446
 
2351
- # Final metrics
2352
- final_metrics = {
2353
- "layers_flagged": flagged_layers,
2354
- "total_layers": total_layers,
2355
- "flagged_rate": flagged_rate,
2356
- "rmt_outliers": flagged_layers,
2357
- "max_ratio": result.get("max_ratio", 0.0),
2358
- "correction_enabled": self.correct,
2359
- "correction_iterations": result.get("correction_iterations", 0),
2360
- "q_used": self.q,
2361
- "deadband_used": self.deadband,
2362
- "margin_used": self.margin,
2363
- "detection_threshold": (1.0 + self.deadband) * self.margin,
2364
- "baseline_layers_captured": len(self.baseline_mp_stats)
2365
- if self.baseline_mp_stats
2366
- else 0,
2367
- "finalize_time": finalize_time,
2368
- "activation_required": bool(self._require_activation),
2369
- "activation_ready": bool(self._activation_ready),
2370
- "analysis_source": result.get("analysis_source"),
2371
- "token_weight_total": result.get("token_weight_total"),
2372
- "baseline_outliers_per_family": {
2373
- k: int(v) for k, v in self.baseline_outliers_per_family.items()
2447
+ metrics: dict[str, Any] = {
2448
+ "prepared": True,
2449
+ "stable": stable,
2450
+ "edge_risk_by_family_base": dict(self.baseline_edge_risk_by_family),
2451
+ "edge_risk_by_family": dict(self.edge_risk_by_family),
2452
+ "epsilon_by_family": dict(self.epsilon_by_family),
2453
+ "epsilon_violations": list(self.epsilon_violations),
2454
+ "measurement_contract": {
2455
+ "estimator": self.estimator,
2456
+ "activation_sampling": self.activation_sampling,
2374
2457
  },
2375
- "outliers_per_family": {
2376
- k: int(v) for k, v in self.outliers_per_family.items()
2377
- },
2378
- "baseline_outliers_total": int(self.baseline_total_outliers),
2379
- "outliers_total": int(self.outliers_total),
2380
- "epsilon_by_family": {
2381
- k: float(v) for k, v in self.epsilon_by_family.items()
2382
- },
2383
- "epsilon_default": float(self.epsilon_default),
2384
- "epsilon_violations": self.epsilon_violations,
2458
+ "finalize_time": finalize_time,
2385
2459
  }
2386
2460
 
2387
- self._log_event(
2388
- "finalize_complete",
2389
- message=f"RMT guard finalized - {'PASSED' if passed else 'FAILED'}",
2390
- passed=passed,
2391
- flagged_rate=flagged_rate,
2392
- finalize_time=finalize_time,
2393
- )
2394
-
2395
- # Return GuardOutcome if available, otherwise legacy dict
2396
- # Env-gated tiny evidence dump for auditors
2397
- try:
2398
- payload = {
2399
- "rmt": {
2400
- "epsilon_by_family": {
2401
- k: float(v) for k, v in self.epsilon_by_family.items()
2402
- },
2403
- "deadband": float(self.deadband),
2404
- "margin": float(self.margin),
2405
- "evaluated": True,
2406
- }
2407
- }
2408
- maybe_dump_guard_evidence(".", payload)
2409
- except Exception:
2410
- pass
2411
-
2412
- if HAS_GUARD_OUTCOME:
2413
- # Add details to metrics since GuardOutcome doesn't have a details field
2414
- final_metrics.update(
2461
+ violations: list[dict[str, Any]] = []
2462
+ for v in self.epsilon_violations:
2463
+ violations.append(
2415
2464
  {
2416
- "guard_type": "rmt",
2417
- "baseline_captured": self.baseline_mp_stats is not None,
2418
- "baseline_count": len(self.baseline_mp_stats)
2419
- if self.baseline_mp_stats
2420
- else 0,
2421
- "flagged_layer_names": [v["module_name"] for v in violations],
2422
- "per_layer_results": result.get("per_layer", []),
2423
- "policy": {
2424
- "q": self.q,
2425
- "deadband": self.deadband,
2426
- "margin": self.margin,
2427
- "correct": self.correct,
2428
- "epsilon": self.epsilon_by_family.copy(),
2429
- },
2430
- "scope_suffixes": self.allowed_suffixes,
2465
+ "type": "epsilon_band",
2466
+ "severity": "error",
2467
+ "family": v.get("family"),
2468
+ "edge_base": v.get("edge_base"),
2469
+ "edge_cur": v.get("edge_cur"),
2470
+ "allowed": v.get("allowed"),
2471
+ "epsilon": v.get("epsilon"),
2472
+ "delta": v.get("delta"),
2473
+ "message": f"ε-band violation in {v.get('family')}",
2431
2474
  }
2432
2475
  )
2433
2476
 
2477
+ if HAS_GUARD_OUTCOME:
2434
2478
  return GuardOutcome(
2435
2479
  name=self.name,
2436
- passed=passed,
2437
- action="none" if passed else "rollback",
2480
+ passed=stable,
2481
+ action=action,
2438
2482
  violations=violations,
2439
- metrics=final_metrics,
2483
+ metrics=metrics,
2440
2484
  )
2441
- else:
2442
- return {
2443
- "passed": passed,
2444
- "metrics": final_metrics,
2445
- "warnings": warnings,
2446
- "errors": errors,
2447
- "violations": violations,
2448
- "events": self.events,
2449
- "details": {
2450
- "guard_type": "rmt",
2451
- "baseline_captured": self.baseline_mp_stats is not None,
2452
- "baseline_count": len(self.baseline_mp_stats)
2453
- if self.baseline_mp_stats
2454
- else 0,
2455
- "flagged_layer_names": [v["module_name"] for v in violations],
2456
- "per_layer_results": result.get("per_layer", []),
2457
- "policy": {
2458
- "q": self.q,
2459
- "deadband": self.deadband,
2460
- "margin": self.margin,
2461
- "correct": self.correct,
2462
- "epsilon": self.epsilon_by_family.copy(),
2463
- },
2464
- "scope_suffixes": self.allowed_suffixes,
2465
- },
2466
- }
2485
+ return {
2486
+ "passed": stable,
2487
+ "action": action,
2488
+ "metrics": metrics,
2489
+ "violations": violations,
2490
+ }
2467
2491
 
2468
2492
  def policy(self) -> RMTPolicyDict:
2469
2493
  """
@@ -2477,7 +2501,8 @@ class RMTGuard(Guard):
2477
2501
  deadband=self.deadband,
2478
2502
  margin=self.margin,
2479
2503
  correct=self.correct,
2480
- epsilon=self.epsilon_by_family.copy(),
2504
+ epsilon_default=float(self.epsilon_default),
2505
+ epsilon_by_family=self.epsilon_by_family.copy(),
2481
2506
  )
2482
2507
 
2483
2508
 
@@ -2494,28 +2519,31 @@ def get_rmt_policy(name: str = "balanced") -> RMTPolicyDict:
2494
2519
  Returns:
2495
2520
  RMTPolicyDict configuration
2496
2521
  """
2497
- # Per-family ε values match tiers.yaml (November 2025 calibration)
2522
+ # Per-family ε values match runtime tiers.yaml.
2498
2523
  policies = {
2499
2524
  "conservative": RMTPolicyDict(
2500
2525
  q="auto",
2501
2526
  deadband=0.05,
2502
2527
  margin=1.3,
2503
2528
  correct=True,
2504
- epsilon={"ffn": 0.06, "attn": 0.05, "embed": 0.07, "other": 0.07},
2529
+ epsilon_default=0.06,
2530
+ epsilon_by_family={"ffn": 0.06, "attn": 0.05, "embed": 0.07, "other": 0.07},
2505
2531
  ),
2506
2532
  "balanced": RMTPolicyDict(
2507
2533
  q="auto",
2508
2534
  deadband=0.10,
2509
2535
  margin=1.5,
2510
2536
  correct=True,
2511
- epsilon={"ffn": 0.10, "attn": 0.08, "embed": 0.12, "other": 0.12},
2537
+ epsilon_default=0.10,
2538
+ epsilon_by_family={"ffn": 0.10, "attn": 0.08, "embed": 0.12, "other": 0.12},
2512
2539
  ),
2513
2540
  "aggressive": RMTPolicyDict(
2514
2541
  q="auto",
2515
2542
  deadband=0.15,
2516
2543
  margin=1.8,
2517
2544
  correct=True,
2518
- epsilon={"ffn": 0.14, "attn": 0.12, "embed": 0.18, "other": 0.18},
2545
+ epsilon_default=0.15,
2546
+ epsilon_by_family={"ffn": 0.15, "attn": 0.15, "embed": 0.15, "other": 0.15},
2519
2547
  ),
2520
2548
  }
2521
2549
 
@@ -2537,7 +2565,9 @@ def create_custom_rmt_policy(
2537
2565
  deadband: float = 0.10,
2538
2566
  margin: float = 1.5,
2539
2567
  correct: bool = True,
2540
- epsilon: float | dict[str, float] | None = None,
2568
+ *,
2569
+ epsilon_default: float = 0.1,
2570
+ epsilon_by_family: dict[str, float] | None = None,
2541
2571
  ) -> RMTPolicyDict:
2542
2572
  """
2543
2573
  Create a custom RMT policy.
@@ -2578,10 +2608,61 @@ def create_custom_rmt_policy(
2578
2608
  details={"param": "margin", "value": margin},
2579
2609
  )
2580
2610
 
2611
+ from invarlock.core.exceptions import ValidationError
2612
+
2613
+ try:
2614
+ eps_default_val = float(epsilon_default)
2615
+ except (TypeError, ValueError) as exc:
2616
+ raise ValidationError(
2617
+ code="E501",
2618
+ message="POLICY-PARAM-INVALID",
2619
+ details={"param": "epsilon_default", "value": epsilon_default},
2620
+ ) from exc
2621
+ if not (math.isfinite(eps_default_val) and eps_default_val >= 0.0):
2622
+ raise ValidationError(
2623
+ code="E501",
2624
+ message="POLICY-PARAM-INVALID",
2625
+ details={"param": "epsilon_default", "value": epsilon_default},
2626
+ )
2627
+
2628
+ eps_by_family: dict[str, float] = {}
2629
+ if epsilon_by_family is not None:
2630
+ if not isinstance(epsilon_by_family, dict):
2631
+ raise ValidationError(
2632
+ code="E501",
2633
+ message="POLICY-PARAM-INVALID",
2634
+ details={"param": "epsilon_by_family", "value": epsilon_by_family},
2635
+ )
2636
+ for family, value in epsilon_by_family.items():
2637
+ try:
2638
+ eps_val = float(value)
2639
+ except (TypeError, ValueError) as exc:
2640
+ raise ValidationError(
2641
+ code="E501",
2642
+ message="POLICY-PARAM-INVALID",
2643
+ details={
2644
+ "param": "epsilon_by_family",
2645
+ "family": str(family),
2646
+ "value": value,
2647
+ },
2648
+ ) from exc
2649
+ if not (math.isfinite(eps_val) and eps_val >= 0.0):
2650
+ raise ValidationError(
2651
+ code="E501",
2652
+ message="POLICY-PARAM-INVALID",
2653
+ details={
2654
+ "param": "epsilon_by_family",
2655
+ "family": str(family),
2656
+ "value": value,
2657
+ },
2658
+ )
2659
+ eps_by_family[str(family)] = eps_val
2660
+
2581
2661
  return RMTPolicyDict(
2582
2662
  q=q,
2583
2663
  deadband=deadband,
2584
2664
  margin=margin,
2585
2665
  correct=correct,
2586
- epsilon=epsilon,
2666
+ epsilon_default=eps_default_val,
2667
+ epsilon_by_family=eps_by_family,
2587
2668
  )