invarlock 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invarlock/guards/rmt.py CHANGED
@@ -11,6 +11,7 @@ to training instability.
 
 from __future__ import annotations
 
+import itertools
 import math
 from dataclasses import dataclass
 from datetime import datetime
@@ -1117,8 +1118,9 @@ class RMTGuard(Guard):
     """
     Standalone RMT Guard for baseline-aware outlier detection and correction.
 
-    Implements Marchenko-Pastur theory-based spectral health checking with:
-    - Baseline capture of MP bulk edges for linear layers
+    Implements Marchenko-Pastur theory-based outlier tracking with:
+    - Activation-based outlier counts from calibration batches
+    - Baseline capture of MP bulk edges for linear layers (correction/fallback)
     - Conservative outlier detection with deadband support
     - Optional in-place correction preserving weight tying
     - Comprehensive event logging and metrics
@@ -1129,7 +1131,7 @@ class RMTGuard(Guard):
     - margin: RMT threshold ratio (default 1.5)
     - correct: Enable automatic correction (default True)
 
-    Linear Layer Scope (enforced):
+    Linear Layer Scope (correction/fallback):
     - attn.c_attn, attn.c_proj, mlp.c_fc, mlp.c_proj
     - Excludes: embeddings, LM head, layer norms, biases
     """
@@ -1164,6 +1166,13 @@ class RMTGuard(Guard):
             self.epsilon_by_family.setdefault(family_key, self.epsilon_default)
 
         # Internal state
+        self._calibration_batches: list[Any] = []
+        self._activation_ready = False
+        self._require_activation = False
+        self._activation_required_failed = False
+        self._activation_required_reason: str | None = None
+        self._run_profile: str | None = None
+        self._run_tier: str | None = None
         self.baseline_mp_stats: dict[str, dict[str, float]] | None = None
         self.baseline_sigmas: dict[str, float] | None = None
         self.prepared = False
@@ -1198,6 +1207,22 @@ class RMTGuard(Guard):
         }
         self.events.append(event)
 
+    def set_run_context(self, report: Any) -> None:
+        """Capture tier/profile context for activation requirements."""
+        ctx = getattr(report, "context", {}) or {}
+        profile = ""
+        tier = "balanced"
+        if isinstance(ctx, dict):
+            profile = str(ctx.get("profile", "") or "").strip().lower()
+            auto = ctx.get("auto")
+            if isinstance(auto, dict):
+                tier = str(auto.get("tier", tier) or tier).strip().lower()
+        self._run_profile = profile or None
+        self._run_tier = tier or None
+        self._require_activation = bool(
+            profile in {"ci", "release"} and tier in {"balanced", "conservative"}
+        )
+
     def _set_epsilon(self, epsilon: float | dict[str, float] | None) -> None:
         """Configure epsilon defaults and per-family overrides."""
         if isinstance(epsilon, dict):
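The new `set_run_context` hook derives the activation requirement purely from the run report's `context` mapping: activation-based analysis becomes mandatory only when the profile is `ci` or `release` and the auto tier is `balanced` or `conservative`, with all values normalized to lowercase first. A minimal sketch of that behavior (the report object here is a hypothetical stand-in; only its `context` attribute is read, and `RMTGuard()` is assumed to be constructible with defaults):

```python
from types import SimpleNamespace

# Hypothetical stand-in for the run report; set_run_context() only reads .context.
report = SimpleNamespace(context={"profile": "CI", "auto": {"tier": "Conservative"}})

guard = RMTGuard()  # assumed default construction
guard.set_run_context(report)
assert guard._run_profile == "ci"          # values are normalized to lowercase
assert guard._run_tier == "conservative"
assert guard._require_activation is True   # ci/release + balanced/conservative

# Any other combination leaves the requirement off.
guard.set_run_context(SimpleNamespace(context={"profile": "dev"}))
assert guard._require_activation is False
```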
@@ -1244,11 +1269,21 @@ class RMTGuard(Guard):
         """Count outliers grouped by family."""
         counts: dict[str, int] = {}
         for layer_info in per_layer:
-            if not layer_info.get("has_outlier"):
-                continue
+            outlier_count = layer_info.get("outlier_count")
+            if outlier_count is None:
+                if not layer_info.get("has_outlier"):
+                    continue
+                increment = 1
+            else:
+                try:
+                    increment = int(outlier_count)
+                except (TypeError, ValueError):
+                    continue
+            if increment <= 0:
+                continue
             module_name = layer_info.get("module_name", "")
             family = self._classify_family(module_name)
-            counts[family] = counts.get(family, 0) + 1
+            counts[family] = counts.get(family, 0) + increment
         return counts
 
     def _compute_epsilon_violations(self) -> list[dict[str, Any]]:
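`_count_outliers_per_family` now prefers a numeric `outlier_count` per layer and only falls back to the legacy boolean `has_outlier` (counted as 1) when that field is absent, so weighted activation records and older boolean records can be aggregated together. A small worked example of the intended arithmetic (the module names and the attn/ffn family mapping are illustrative):

```python
per_layer = [
    # Activation-style record: contributes its full (token-weighted) count.
    {"module_name": "h.0.attn.c_proj", "outlier_count": 5, "has_outlier": True},
    # Legacy record without outlier_count: falls back to has_outlier, counts as 1.
    {"module_name": "h.0.mlp.c_fc", "has_outlier": True},
    # Zero and unparsable counts are skipped outright.
    {"module_name": "h.1.mlp.c_proj", "outlier_count": 0},
    {"module_name": "h.1.attn.c_attn", "outlier_count": "n/a"},
]
# Assuming _classify_family maps attn -> "attn" and mlp -> "ffn":
# guard._count_outliers_per_family(per_layer) == {"attn": 5, "ffn": 1}
```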
@@ -1312,6 +1347,295 @@ class RMTGuard(Guard):
 
         return modules
 
+    def _collect_calibration_batches(self, calib: Any, max_windows: int) -> list[Any]:
+        """Collect a deterministic slice of calibration batches."""
+        if calib is None or max_windows <= 0:
+            return []
+        source = getattr(calib, "dataloader", None) or calib
+        try:
+            iterator = iter(source)
+        except TypeError:
+            return []
+        return list(itertools.islice(iterator, max_windows))
+
+    def _prepare_activation_inputs(
+        self, batch: Any, device: torch.device
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+        """Normalize batch inputs to tensors on the target device."""
+        if isinstance(batch, dict):
+            input_ids = batch.get("input_ids", batch.get("inputs"))
+            attention_mask = batch.get("attention_mask")
+        elif isinstance(batch, tuple | list) and batch:
+            input_ids = batch[0]
+            attention_mask = batch[1] if len(batch) > 1 else None
+        else:
+            input_ids = batch
+            attention_mask = None
+
+        if input_ids is None:
+            return None, None
+
+        if not isinstance(input_ids, torch.Tensor):
+            input_ids = torch.as_tensor(input_ids)
+        if input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
+        try:
+            input_ids = input_ids.to(device)
+        except Exception:
+            input_ids = input_ids.clone()
+
+        if attention_mask is not None:
+            if not isinstance(attention_mask, torch.Tensor):
+                attention_mask = torch.as_tensor(attention_mask)
+            if attention_mask.dim() == 1:
+                attention_mask = attention_mask.unsqueeze(0)
+            try:
+                attention_mask = attention_mask.to(device)
+            except Exception:
+                attention_mask = attention_mask.clone()
+
+        return input_ids, attention_mask
+
+    @staticmethod
+    def _batch_token_weight(
+        input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None
+    ) -> int:
+        """Compute token-weight for a batch (used for activation outlier weighting)."""
+        weight = 0
+        if isinstance(attention_mask, torch.Tensor):
+            try:
+                weight = int(attention_mask.sum().item())
+            except Exception:
+                weight = 0
+        if weight <= 0 and isinstance(input_ids, torch.Tensor):
+            try:
+                weight = int(input_ids.numel())
+            except Exception:
+                weight = 0
+        return max(weight, 1)
+
+    def _get_activation_modules(self, model: nn.Module) -> list[tuple[str, nn.Module]]:
+        """Return modules to analyze for activation-based RMT."""
+        modules: list[tuple[str, nn.Module]] = []
+        try:
+            from transformers.pytorch_utils import Conv1D
+
+            module_types_with_conv1d: tuple[
+                type[nn.Linear], type[nn.Conv1d], type[Conv1D]
+            ] = (nn.Linear, nn.Conv1d, Conv1D)
+            module_types = module_types_with_conv1d
+        except ImportError:
+            module_types_without_conv1d: tuple[type[nn.Linear], type[nn.Conv1d]] = (
+                nn.Linear,
+                nn.Conv1d,
+            )
+            module_types = module_types_without_conv1d
+
+        for name, module in model.named_modules():
+            if isinstance(module, nn.Embedding):
+                modules.append((name, module))
+                continue
+            if isinstance(module, nn.LayerNorm):
+                modules.append((name, module))
+                continue
+            if isinstance(module, module_types) and hasattr(module, "weight"):
+                name_lower = name.lower()
+                if any(
+                    name.endswith(suffix) for suffix in self.allowed_suffixes
+                ) or any(
+                    tok in name_lower
+                    for tok in (
+                        "attn",
+                        "attention",
+                        "mlp",
+                        "ffn",
+                        "router",
+                        "expert",
+                        "moe",
+                        "gate",
+                        "gating",
+                        "switch",
+                    )
+                ):
+                    modules.append((name, module))
+
+        return modules
+
+    def _activation_svd_outliers(
+        self, activations: Any, margin: float, deadband: float
+    ) -> tuple[int, float, float]:
+        """Count activation singular values beyond the MP edge."""
+        if isinstance(activations, tuple | list):
+            activations = activations[0] if activations else None
+        if not isinstance(activations, torch.Tensor):
+            return 0, 0.0, 0.0
+
+        if activations.dim() < 2:
+            return 0, 0.0, 0.0
+
+        if activations.dim() > 2:
+            activations = activations.reshape(-1, activations.shape[-1])
+
+        if activations.numel() == 0:
+            return 0, 0.0, 0.0
+
+        try:
+            mat = activations.detach().float().cpu()
+        except Exception:
+            return 0, 0.0, 0.0
+
+        if not torch.isfinite(mat).all():
+            return 0, 0.0, 0.0
+
+        mat = mat - mat.mean()
+        std = float(mat.std().item())
+        if not math.isfinite(std) or std <= 0.0:
+            return 0, 0.0, 0.0
+
+        mat = mat / std
+        m, n = mat.shape
+        mp_edge_val = mp_bulk_edge(m, n, whitened=False)
+        threshold = mp_edge_val * (1.0 + deadband) * margin
+
+        try:
+            s_vals = torch.linalg.svdvals(mat)
+        except (RuntimeError, torch.linalg.LinAlgError):
+            return 0, 0.0, 0.0
+
+        if s_vals.numel() == 0:
+            return 0, 0.0, 0.0
+
+        sigma_max = float(s_vals.max().item())
+        max_ratio = sigma_max / max(mp_edge_val, 1e-12)
+        outlier_count = int((s_vals > threshold).sum().item())
+        return outlier_count, float(max_ratio), sigma_max
+
+    def _compute_activation_outliers(
+        self, model: nn.Module, batches: list[Any]
+    ) -> dict[str, Any] | None:
+        """Compute activation-based RMT outlier counts."""
+        if not batches:
+            return None
+
+        modules = self._get_activation_modules(model)
+        if not modules:
+            return None
+
+        per_layer_map: dict[str, dict[str, Any]] = {}
+        batch_weight_holder = {"weight": 1}
+        for idx, (module_name, _module) in enumerate(modules):
+            per_layer_map[module_name] = {
+                "layer": idx,
+                "module_name": module_name,
+                "sigma_max": 0.0,
+                "worst_ratio": 0.0,
+                "outlier_count": 0,
+                "has_outlier": False,
+            }
+
+        handles: list[Any] = []
+
+        def _make_hook(name: str):
+            def _hook(_module: nn.Module, _inputs: tuple[Any, ...], output: Any):
+                try:
+                    outliers, max_ratio, sigma_max = self._activation_svd_outliers(
+                        output, self.margin, self.deadband
+                    )
+                except Exception:
+                    return
+                stats = per_layer_map.get(name)
+                if stats is None:
+                    return
+                weight = int(batch_weight_holder.get("weight", 1) or 1)
+                if outliers > 0:
+                    increment = int(outliers) * weight
+                    stats["outlier_count"] = (
+                        int(stats.get("outlier_count", 0)) + increment
+                    )
+                    stats["has_outlier"] = True
+                stats["worst_ratio"] = max(
+                    float(stats.get("worst_ratio", 0.0)), float(max_ratio)
+                )
+                stats["sigma_max"] = max(
+                    float(stats.get("sigma_max", 0.0)), float(sigma_max)
+                )
+
+            return _hook
+
+        for name, module in modules:
+            try:
+                handles.append(module.register_forward_hook(_make_hook(name)))
+            except Exception:
+                continue
+
+        model_was_training = model.training
+        model.eval()
+        device = next(model.parameters()).device
+        batches_used = 0
+        token_weight_total = 0
+
+        try:
+            with torch.no_grad():
+                for batch in batches:
+                    inputs, attention_mask = self._prepare_activation_inputs(
+                        batch, device
+                    )
+                    if inputs is None:
+                        continue
+                    batch_weight = self._batch_token_weight(inputs, attention_mask)
+                    batch_weight_holder["weight"] = batch_weight
+                    try:
+                        if attention_mask is not None:
+                            model(inputs, attention_mask=attention_mask)
+                        else:
+                            model(inputs)
+                        batches_used += 1
+                        token_weight_total += batch_weight
+                    except TypeError:
+                        try:
+                            model(inputs)
+                            batches_used += 1
+                            token_weight_total += batch_weight
+                        except Exception:
+                            continue
+                    except Exception:
+                        continue
+        finally:
+            for handle in handles:
+                try:
+                    handle.remove()
+                except Exception:
+                    pass
+            if model_was_training:
+                model.train()
+
+        if batches_used == 0:
+            return None
+
+        per_layer = [per_layer_map[name] for name, _module in modules]
+        flagged_layers = [
+            info["layer"] for info in per_layer if info.get("has_outlier")
+        ]
+        outlier_total = sum(
+            int(info.get("outlier_count", 0) or 0) for info in per_layer
+        )
+        max_ratio = max(
+            (float(info.get("worst_ratio", 0.0)) for info in per_layer), default=0.0
+        )
+
+        return {
+            "has_outliers": bool(flagged_layers),
+            "n_layers_flagged": len(flagged_layers),
+            "outlier_count": outlier_total,
+            "max_ratio": max_ratio,
+            "threshold": (1.0 + self.deadband) * self.margin,
+            "per_layer": per_layer,
+            "flagged_layers": flagged_layers,
+            "analysis_source": "activations",
+            "token_weight_total": int(token_weight_total),
+            "token_weighted": True,
+        }
+
     def _apply_rmt_detection_and_correction(self, model: nn.Module) -> dict[str, Any]:
         """
         Apply Step 5 RMT detection and correction with adapter support.
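`_activation_svd_outliers` normalizes activations to roughly zero mean and unit variance, then counts singular values above the Marchenko-Pastur bulk edge scaled by `(1 + deadband) * margin`. For an m x n matrix of i.i.d. unit-variance entries, the largest singular value concentrates near sqrt(m) + sqrt(n) (the Bai-Yin / unwhitened MP edge); the package's `mp_bulk_edge` is assumed to compute something equivalent. A self-contained sketch of the check under that assumption:

```python
import math
import torch

def mp_edge(m: int, n: int) -> float:
    # Bai-Yin largest-singular-value edge for an m x n i.i.d. unit-variance matrix.
    return math.sqrt(m) + math.sqrt(n)

def count_edge_outliers(acts: torch.Tensor, margin: float = 1.5, deadband: float = 0.0) -> int:
    mat = acts.reshape(-1, acts.shape[-1]).float()
    mat = (mat - mat.mean()) / mat.std()  # normalize as the guard does
    m, n = mat.shape
    threshold = mp_edge(m, n) * (1.0 + deadband) * margin
    return int((torch.linalg.svdvals(mat) > threshold).sum().item())

torch.manual_seed(0)
noise = torch.randn(256, 64)             # pure bulk: sigma_max ~ 24, threshold = 36
print(count_edge_outliers(noise))        # expect 0

spiked = noise + 2.0 * torch.randn(256, 1) @ torch.randn(1, 64)  # rank-1 structure
print(count_edge_outliers(spiked))       # expect at least 1
```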
@@ -1462,7 +1786,7 @@ class RMTGuard(Guard):
         Args:
             model: The model that will be edited
             adapter: ModelAdapter (optional, for tying map access)
-            calib: Calibration data (unused for RMT)
+            calib: Calibration data for activation-based outlier counting
             policy: Guard policy parameters (optional)
 
         Returns:
@@ -1471,6 +1795,8 @@ class RMTGuard(Guard):
         import time
 
         start_time = time.time()
+        self._activation_required_failed = False
+        self._activation_required_reason = None
 
         # Store adapter for tying map access during correction
         self.adapter = adapter
@@ -1485,6 +1811,8 @@ class RMTGuard(Guard):
                 self._set_epsilon(policy["epsilon"])
             if "epsilon_by_family" in policy:
                 self._set_epsilon(policy["epsilon_by_family"])
+            if "activation_required" in policy:
+                self._require_activation = bool(policy.get("activation_required"))
 
         self._log_event(
             "prepare",
@@ -1502,19 +1830,59 @@ class RMTGuard(Guard):
 
         # Get linear modules in scope
         linear_modules = self._get_linear_modules(model)
+        self._activation_ready = False
+        self._calibration_batches = []
+        max_windows = 0
+        if calib is not None:
+            source = getattr(calib, "dataloader", None) or calib
+            if hasattr(source, "__len__"):
+                try:
+                    max_windows = int(len(source))
+                except Exception:
+                    max_windows = 0
+            self._calibration_batches = self._collect_calibration_batches(
+                calib, max_windows
+            )
 
-        baseline_detection = rmt_detect(
-            model=model,
-            threshold=self.margin,
-            detect_only=True,
-            baseline_sigmas=self.baseline_sigmas,
-            baseline_mp_stats=self.baseline_mp_stats,
-            deadband=self.deadband,
-        )
-        self.baseline_total_outliers = baseline_detection.get("n_layers_flagged", 0)
-        self.baseline_outliers_per_family = self._count_outliers_per_family(
-            baseline_detection.get("per_layer", [])
-        )
+        activation_baseline = None
+        if self._calibration_batches:
+            activation_baseline = self._compute_activation_outliers(
+                model, self._calibration_batches
+            )
+
+        if activation_baseline:
+            self._activation_ready = True
+            self.baseline_total_outliers = int(
+                activation_baseline.get("outlier_count", 0) or 0
+            )
+            self.baseline_outliers_per_family = self._count_outliers_per_family(
+                activation_baseline.get("per_layer", [])
+            )
+        else:
+            if self._require_activation:
+                self._activation_required_failed = True
+                self._activation_required_reason = "activation_baseline_unavailable"
+                self._log_event(
+                    "activation_required_missing",
+                    level="ERROR",
+                    message="Activation baseline unavailable for RMT",
+                    profile=self._run_profile,
+                    tier=self._run_tier,
+                )
+            baseline_detection = rmt_detect(
+                model=model,
+                threshold=self.margin,
+                detect_only=True,
+                baseline_sigmas=self.baseline_sigmas,
+                baseline_mp_stats=self.baseline_mp_stats,
+                deadband=self.deadband,
+            )
+            self.baseline_total_outliers = baseline_detection.get(
+                "n_layers_flagged", 0
+            )
+            self.baseline_outliers_per_family = self._count_outliers_per_family(
+                baseline_detection.get("per_layer", [])
+            )
         for family_key in ("attn", "ffn", "embed", "other"):
            self.baseline_outliers_per_family.setdefault(family_key, 0)
        self.outliers_per_family = {}
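`prepare` now sizes the activation pass from the calibration object itself: it prefers a `dataloader` attribute, reads `len(source)` as the window budget when available, and materializes exactly that many batches with `itertools.islice`, so repeated runs consume the same deterministic slice. A standalone sketch of the collection rule (a plain list stands in for the calibration loader):

```python
import itertools

def collect(calib, max_windows):
    # Mirrors _collect_calibration_batches: prefer .dataloader, else calib itself.
    if calib is None or max_windows <= 0:
        return []
    source = getattr(calib, "dataloader", None) or calib
    try:
        iterator = iter(source)
    except TypeError:
        return []
    return list(itertools.islice(iterator, max_windows))

batches = [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}, {"input_ids": [5]}]
print(collect(batches, 2))   # first two batches, same slice every run
print(collect(batches, 0))   # [] -- no window budget
print(collect(object(), 3))  # [] -- source is not iterable
```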
@@ -1595,7 +1963,7 @@ class RMTGuard(Guard):
         Args:
             model: The model that was just edited
         """
-        if not self.prepared or not self.baseline_mp_stats:
+        if not self.prepared:
             self._log_event(
                 "after_edit_skipped",
                 level="WARN",
@@ -1606,22 +1974,87 @@ class RMTGuard(Guard):
         self._log_event("after_edit", message="Applying RMT detection and correction")
 
         try:
-            # Perform RMT detection with baseline awareness
-            # Create custom detection with proper adapter support
-            if self.correct:
-                # Apply correction using enhanced logic with adapter support
-                detection_result = self._apply_rmt_detection_and_correction(model)
-            else:
-                # Detection only
-                detection_result = rmt_detect(
-                    model=model,
-                    threshold=self.margin,  # Use margin as threshold
-                    detect_only=True,
-                    verbose=False,
-                    baseline_sigmas=self.baseline_sigmas,
-                    baseline_mp_stats=self.baseline_mp_stats,
-                    deadband=self.deadband,
+            detection_result: dict[str, Any] | None = None
+            corrected_layers = 0
+            correction_iterations = 0
+            use_activation = bool(self._activation_ready and self._calibration_batches)
+
+            if self._require_activation and not use_activation:
+                self._activation_required_failed = True
+                self._activation_required_reason = "activation_unavailable"
+                self._last_result = {
+                    "has_outliers": False,
+                    "n_layers_flagged": 0,
+                    "per_layer": [],
+                    "max_ratio": 0.0,
+                    "analysis_source": "activations",
+                }
+                self._log_event(
+                    "activation_required_missing",
+                    level="ERROR",
+                    message="Activation outlier analysis required but unavailable",
+                    profile=self._run_profile,
+                    tier=self._run_tier,
                 )
+                return
+
+            if use_activation:
+                if self.correct:
+                    correction_result = self._apply_rmt_detection_and_correction(model)
+                    correction_iterations = int(
+                        correction_result.get("correction_iterations", 0) or 0
+                    )
+                    corrected_layers = int(
+                        correction_result.get("corrected_layers", 0) or 0
+                    )
+                detection_result = self._compute_activation_outliers(
+                    model, self._calibration_batches
+                )
+                if not detection_result:
+                    if self._require_activation:
+                        self._activation_required_failed = True
+                        self._activation_required_reason = (
+                            "activation_outliers_unavailable"
+                        )
+                        self._last_result = {
+                            "has_outliers": False,
+                            "n_layers_flagged": 0,
+                            "per_layer": [],
+                            "max_ratio": 0.0,
+                            "analysis_source": "activations",
+                        }
+                        self._log_event(
+                            "activation_required_missing",
+                            level="ERROR",
+                            message="Activation outlier analysis failed",
+                            profile=self._run_profile,
+                            tier=self._run_tier,
+                        )
+                        return
+                    use_activation = False
+
+            if not use_activation:
+                if self.correct:
+                    # Apply correction using enhanced logic with adapter support
+                    detection_result = self._apply_rmt_detection_and_correction(model)
+                else:
+                    # Detection only
+                    detection_result = rmt_detect(
+                        model=model,
+                        threshold=self.margin,  # Use margin as threshold
+                        detect_only=True,
+                        verbose=False,
+                        baseline_sigmas=self.baseline_sigmas,
+                        baseline_mp_stats=self.baseline_mp_stats,
+                        deadband=self.deadband,
+                    )
+
+            if detection_result is None:
+                raise RuntimeError("RMT detection failed to produce results")
+
+            if use_activation:
+                detection_result["correction_iterations"] = correction_iterations
+                detection_result["corrected_layers"] = corrected_layers
 
             # Store results
             self._last_result = detection_result
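Note that the forward hook multiplies each batch's per-layer outlier count by that batch's token weight (attention-mask sum, else `numel()`), so `outlier_count` and the reported `token_weight_total` are token-weighted rather than per-batch tallies. The accumulation in isolation (illustrative numbers):

```python
# Token-weighted accumulation as performed inside the forward hook.
layer = {"outlier_count": 0, "has_outlier": False}
for outliers, token_weight in [(2, 128), (1, 64), (0, 256)]:
    if outliers > 0:
        layer["outlier_count"] += outliers * token_weight
        layer["has_outlier"] = True

assert layer["outlier_count"] == 2 * 128 + 1 * 64  # 320, not 3
```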
@@ -1630,9 +2063,12 @@ class RMTGuard(Guard):
            )
            for family_key in ("attn", "ffn", "embed", "other"):
                self.outliers_per_family.setdefault(family_key, 0)
-           self.outliers_total = detection_result.get(
-               "n_layers_flagged", len(self.outliers_per_family)
-           )
+           outlier_total = detection_result.get("outlier_count")
+           if outlier_total is None:
+               outlier_total = detection_result.get(
+                   "n_layers_flagged", len(self.outliers_per_family)
+               )
+           self.outliers_total = int(outlier_total or 0)
            self.epsilon_violations = self._compute_epsilon_violations()
 
            flagged_layers = detection_result.get("n_layers_flagged", 0)
@@ -1779,6 +2215,49 @@ class RMTGuard(Guard):
             }
 
         # Get results from after_edit
+        if self._require_activation and (
+            self._activation_required_failed or not self._activation_ready
+        ):
+            reason = self._activation_required_reason or "activation_required"
+            message = "Activation outlier analysis required but unavailable"
+            finalize_time = time.time() - start_time
+            if HAS_GUARD_OUTCOME:
+                return GuardOutcome(
+                    name=self.name,
+                    passed=False,
+                    action="rollback",
+                    violations=[
+                        {
+                            "type": "activation_required",
+                            "severity": "error",
+                            "message": message,
+                            "module_name": None,
+                            "reason": reason,
+                        }
+                    ],
+                    metrics={
+                        "prepared": True,
+                        "activation_required": True,
+                        "activation_ready": False,
+                        "activation_reason": reason,
+                        "finalize_time": finalize_time,
+                    },
+                )
+            return {
+                "passed": False,
+                "metrics": {
+                    "prepared": True,
+                    "activation_required": True,
+                    "activation_ready": False,
+                    "activation_reason": reason,
+                    "finalize_time": finalize_time,
+                },
+                "warnings": [],
+                "errors": [message],
+                "violations": [],
+                "events": self.events,
+            }
+
         result = self._last_result or {
             "has_outliers": False,
             "n_layers_flagged": 0,
@@ -1793,7 +2272,10 @@ class RMTGuard(Guard):
         for family_key in ("attn", "ffn", "embed", "other"):
             self.outliers_per_family.setdefault(family_key, 0)
             self.baseline_outliers_per_family.setdefault(family_key, 0)
-        self.outliers_total = result.get("n_layers_flagged", self.outliers_total or 0)
+        outlier_total = result.get("outlier_count")
+        if outlier_total is None:
+            outlier_total = result.get("n_layers_flagged", self.outliers_total or 0)
+        self.outliers_total = int(outlier_total or 0)
         self.epsilon_violations = self._compute_epsilon_violations()
         # Contracts: epsilon non-negative, counts non-negative
         for fam, eps in self.epsilon_by_family.items():
@@ -1812,12 +2294,17 @@ class RMTGuard(Guard):
 
         # Calculate metrics
         flagged_layers = result.get("n_layers_flagged", 0)
-        total_layers = len(self.baseline_mp_stats) if self.baseline_mp_stats else 0
+        total_layers = (
+            len(result.get("per_layer", []))
+            if result.get("per_layer")
+            else len(self.baseline_mp_stats)
+            if self.baseline_mp_stats
+            else 0
+        )
         flagged_rate = flagged_layers / total_layers if total_layers > 0 else 0.0
 
-        # Step 5 validation gate: no increase in outliers vs bare edit, ≤1% primary-metric cost
-        # For now, use flagged rate as proxy (will be enhanced with PM checking)
-        passed = flagged_rate <= 0.5  # Allow up to 50% flagged for conservative gate
+        # Acceptance gate: pass when epsilon-rule holds per family.
+        passed = not bool(self.epsilon_violations)
 
         # Generate violations for GuardOutcome
         violations = []
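The finalize gate no longer uses the 50% flagged-rate proxy; `passed` is now simply the absence of epsilon violations. `_compute_epsilon_violations` itself is not shown in this diff, so the following is only a plausible reconstruction, assuming the rule compares post-edit per-family counts against baseline counts inflated by that family's epsilon tolerance:

```python
def epsilon_violations(observed, baseline, epsilon_by_family, epsilon_default=0.1):
    # Hypothetical reconstruction -- NOT the package's actual implementation.
    violations = []
    for family, count in observed.items():
        eps = epsilon_by_family.get(family, epsilon_default)
        allowed = baseline.get(family, 0) * (1.0 + eps)
        if count > allowed:
            violations.append({"family": family, "observed": count, "allowed": allowed})
    return violations

print(epsilon_violations({"attn": 12, "ffn": 3}, {"attn": 10, "ffn": 3}, {"attn": 0.1}))
# attn: 12 > 10 * 1.1 = 11.0 -> one violation; ffn: 3 <= 3.3 -> passes
```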
@@ -1844,11 +2331,10 @@ class RMTGuard(Guard):
                 f"High RMT outlier rate: {flagged_layers}/{total_layers} layers flagged ({flagged_rate:.1%})"
             )
 
-        if flagged_rate > 0.7:  # Error threshold at 70%
-            errors.append(
+        if flagged_rate > 0.7:  # Escalate to warning for unusually high rates
+            warnings.append(
                 f"Excessive RMT outliers: {flagged_layers}/{total_layers} layers flagged"
             )
-            passed = False
 
         if self.epsilon_violations:
             passed = False
@@ -1879,6 +2365,10 @@ class RMTGuard(Guard):
             if self.baseline_mp_stats
             else 0,
             "finalize_time": finalize_time,
+            "activation_required": bool(self._require_activation),
+            "activation_ready": bool(self._activation_ready),
+            "analysis_source": result.get("analysis_source"),
+            "token_weight_total": result.get("token_weight_total"),
             "baseline_outliers_per_family": {
                 k: int(v) for k, v in self.baseline_outliers_per_family.items()
             },