invarlock-0.3.2-py3-none-any.whl → invarlock-0.3.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +1 -1
- invarlock/adapters/auto.py +4 -4
- invarlock/adapters/hf_bert.py +6 -5
- invarlock/adapters/hf_gpt2.py +5 -4
- invarlock/adapters/hf_llama.py +4 -2
- invarlock/adapters/hf_mixin.py +88 -9
- invarlock/adapters/hf_t5.py +5 -3
- invarlock/cli/commands/run.py +566 -141
- invarlock/cli/commands/verify.py +12 -0
- invarlock/cli/config.py +11 -1
- invarlock/cli/determinism.py +16 -1
- invarlock/core/bootstrap.py +137 -5
- invarlock/core/runner.py +345 -50
- invarlock/eval/bench_regression.py +1 -1
- invarlock/eval/bootstrap.py +3 -1
- invarlock/eval/data.py +11 -0
- invarlock/eval/primary_metric.py +20 -5
- invarlock/guards/rmt.py +536 -46
- invarlock/guards/spectral.py +1 -1
- invarlock/guards/variance.py +122 -43
- invarlock/reporting/certificate.py +258 -12
- invarlock/reporting/normalizer.py +3 -0
- invarlock/reporting/policy_utils.py +1 -3
- invarlock/reporting/primary_metric_utils.py +17 -0
- invarlock/reporting/validate.py +10 -10
- {invarlock-0.3.2.dist-info → invarlock-0.3.4.dist-info}/METADATA +2 -2
- {invarlock-0.3.2.dist-info → invarlock-0.3.4.dist-info}/RECORD +31 -31
- {invarlock-0.3.2.dist-info → invarlock-0.3.4.dist-info}/WHEEL +0 -0
- {invarlock-0.3.2.dist-info → invarlock-0.3.4.dist-info}/entry_points.txt +0 -0
- {invarlock-0.3.2.dist-info → invarlock-0.3.4.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.2.dist-info → invarlock-0.3.4.dist-info}/top_level.txt +0 -0
invarlock/guards/rmt.py
CHANGED
@@ -11,6 +11,7 @@ to training instability.
 
 from __future__ import annotations
 
+import itertools
 import math
 from dataclasses import dataclass
 from datetime import datetime
@@ -1117,8 +1118,9 @@ class RMTGuard(Guard):
     """
     Standalone RMT Guard for baseline-aware outlier detection and correction.
 
-    Implements Marchenko-Pastur theory-based
-    -
+    Implements Marchenko-Pastur theory-based outlier tracking with:
+    - Activation-based outlier counts from calibration batches
+    - Baseline capture of MP bulk edges for linear layers (correction/fallback)
     - Conservative outlier detection with deadband support
     - Optional in-place correction preserving weight tying
     - Comprehensive event logging and metrics
@@ -1129,7 +1131,7 @@ class RMTGuard(Guard):
     - margin: RMT threshold ratio (default 1.5)
     - correct: Enable automatic correction (default True)
 
-    Linear Layer Scope (
+    Linear Layer Scope (correction/fallback):
     - attn.c_attn, attn.c_proj, mlp.c_fc, mlp.c_proj
     - Excludes: embeddings, LM head, layer norms, biases
     """
@@ -1164,6 +1166,13 @@ class RMTGuard(Guard):
             self.epsilon_by_family.setdefault(family_key, self.epsilon_default)
 
         # Internal state
+        self._calibration_batches: list[Any] = []
+        self._activation_ready = False
+        self._require_activation = False
+        self._activation_required_failed = False
+        self._activation_required_reason: str | None = None
+        self._run_profile: str | None = None
+        self._run_tier: str | None = None
         self.baseline_mp_stats: dict[str, dict[str, float]] | None = None
         self.baseline_sigmas: dict[str, float] | None = None
         self.prepared = False
@@ -1198,6 +1207,22 @@ class RMTGuard(Guard):
         }
         self.events.append(event)
 
+    def set_run_context(self, report: Any) -> None:
+        """Capture tier/profile context for activation requirements."""
+        ctx = getattr(report, "context", {}) or {}
+        profile = ""
+        tier = "balanced"
+        if isinstance(ctx, dict):
+            profile = str(ctx.get("profile", "") or "").strip().lower()
+            auto = ctx.get("auto")
+            if isinstance(auto, dict):
+                tier = str(auto.get("tier", tier) or tier).strip().lower()
+        self._run_profile = profile or None
+        self._run_tier = tier or None
+        self._require_activation = bool(
+            profile in {"ci", "release"} and tier in {"balanced", "conservative"}
+        )
+
     def _set_epsilon(self, epsilon: float | dict[str, float] | None) -> None:
         """Configure epsilon defaults and per-family overrides."""
         if isinstance(epsilon, dict):
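The new `set_run_context` hook derives the activation requirement from the run's profile and tier before any policy override. A minimal sketch of the triggering shape, assuming `RMTGuard()` is constructible with defaults and that the report object only needs a `context` attribute:

```python
from types import SimpleNamespace

# Hypothetical report object; set_run_context only reads `.context`.
report = SimpleNamespace(context={"profile": "CI ", "auto": {"tier": "balanced"}})

guard = RMTGuard()  # assumed default construction
guard.set_run_context(report)
# Profile and tier are normalized ("CI " -> "ci"), so a ci/release profile
# combined with a balanced/conservative tier makes activation analysis mandatory:
assert guard._require_activation is True
```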
@@ -1244,11 +1269,21 @@ class RMTGuard(Guard):
         """Count outliers grouped by family."""
         counts: dict[str, int] = {}
         for layer_info in per_layer:
-
-
+            outlier_count = layer_info.get("outlier_count")
+            if outlier_count is None:
+                if not layer_info.get("has_outlier"):
+                    continue
+                increment = 1
+            else:
+                try:
+                    increment = int(outlier_count)
+                except (TypeError, ValueError):
+                    continue
+            if increment <= 0:
+                continue
             module_name = layer_info.get("module_name", "")
             family = self._classify_family(module_name)
-            counts[family] = counts.get(family, 0) +
+            counts[family] = counts.get(family, 0) + increment
         return counts
 
     def _compute_epsilon_violations(self) -> list[dict[str, Any]]:
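Where the old loop added at most one outlier per flagged layer, the reworked loop prefers an explicit per-layer `outlier_count` and falls back to the legacy `has_outlier` flag. An illustrative input/output pair (the module names and the attn/ffn family mapping are assumptions about `_classify_family`):

```python
per_layer = [
    {"module_name": "h.0.attn.c_attn", "outlier_count": 3},      # counted as 3
    {"module_name": "h.0.mlp.c_fc", "has_outlier": True},        # legacy flag: counts as 1
    {"module_name": "h.1.mlp.c_proj", "outlier_count": 0},       # zero: skipped
    {"module_name": "h.2.attn.c_proj", "outlier_count": "bad"},  # unparsable: skipped
]
# Expected result from _count_outliers_per_family(per_layer):
# {"attn": 3, "ffn": 1}
```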
@@ -1312,6 +1347,295 @@ class RMTGuard(Guard):
 
         return modules
 
+    def _collect_calibration_batches(self, calib: Any, max_windows: int) -> list[Any]:
+        """Collect a deterministic slice of calibration batches."""
+        if calib is None or max_windows <= 0:
+            return []
+        source = getattr(calib, "dataloader", None) or calib
+        try:
+            iterator = iter(source)
+        except TypeError:
+            return []
+        return list(itertools.islice(iterator, max_windows))
+
+    def _prepare_activation_inputs(
+        self, batch: Any, device: torch.device
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+        """Normalize batch inputs to tensors on the target device."""
+        if isinstance(batch, dict):
+            input_ids = batch.get("input_ids", batch.get("inputs"))
+            attention_mask = batch.get("attention_mask")
+        elif isinstance(batch, tuple | list) and batch:
+            input_ids = batch[0]
+            attention_mask = batch[1] if len(batch) > 1 else None
+        else:
+            input_ids = batch
+            attention_mask = None
+
+        if input_ids is None:
+            return None, None
+
+        if not isinstance(input_ids, torch.Tensor):
+            input_ids = torch.as_tensor(input_ids)
+        if input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
+        try:
+            input_ids = input_ids.to(device)
+        except Exception:
+            input_ids = input_ids.clone()
+
+        if attention_mask is not None:
+            if not isinstance(attention_mask, torch.Tensor):
+                attention_mask = torch.as_tensor(attention_mask)
+            if attention_mask.dim() == 1:
+                attention_mask = attention_mask.unsqueeze(0)
+            try:
+                attention_mask = attention_mask.to(device)
+            except Exception:
+                attention_mask = attention_mask.clone()
+
+        return input_ids, attention_mask
+
+    @staticmethod
+    def _batch_token_weight(
+        input_ids: torch.Tensor | None, attention_mask: torch.Tensor | None
+    ) -> int:
+        """Compute token-weight for a batch (used for activation outlier weighting)."""
+        weight = 0
+        if isinstance(attention_mask, torch.Tensor):
+            try:
+                weight = int(attention_mask.sum().item())
+            except Exception:
+                weight = 0
+        if weight <= 0 and isinstance(input_ids, torch.Tensor):
+            try:
+                weight = int(input_ids.numel())
+            except Exception:
+                weight = 0
+        return max(weight, 1)
+
+    def _get_activation_modules(self, model: nn.Module) -> list[tuple[str, nn.Module]]:
+        """Return modules to analyze for activation-based RMT."""
+        modules: list[tuple[str, nn.Module]] = []
+        try:
+            from transformers.pytorch_utils import Conv1D
+
+            module_types_with_conv1d: tuple[
+                type[nn.Linear], type[nn.Conv1d], type[Conv1D]
+            ] = (nn.Linear, nn.Conv1d, Conv1D)
+            module_types = module_types_with_conv1d
+        except ImportError:
+            module_types_without_conv1d: tuple[type[nn.Linear], type[nn.Conv1d]] = (
+                nn.Linear,
+                nn.Conv1d,
+            )
+            module_types = module_types_without_conv1d
+
+        for name, module in model.named_modules():
+            if isinstance(module, nn.Embedding):
+                modules.append((name, module))
+                continue
+            if isinstance(module, nn.LayerNorm):
+                modules.append((name, module))
+                continue
+            if isinstance(module, module_types) and hasattr(module, "weight"):
+                name_lower = name.lower()
+                if any(
+                    name.endswith(suffix) for suffix in self.allowed_suffixes
+                ) or any(
+                    tok in name_lower
+                    for tok in (
+                        "attn",
+                        "attention",
+                        "mlp",
+                        "ffn",
+                        "router",
+                        "expert",
+                        "moe",
+                        "gate",
+                        "gating",
+                        "switch",
+                    )
+                ):
+                    modules.append((name, module))
+
+        return modules
+
+    def _activation_svd_outliers(
+        self, activations: Any, margin: float, deadband: float
+    ) -> tuple[int, float, float]:
+        """Count activation singular values beyond the MP edge."""
+        if isinstance(activations, tuple | list):
+            activations = activations[0] if activations else None
+        if not isinstance(activations, torch.Tensor):
+            return 0, 0.0, 0.0
+
+        if activations.dim() < 2:
+            return 0, 0.0, 0.0
+
+        if activations.dim() > 2:
+            activations = activations.reshape(-1, activations.shape[-1])
+
+        if activations.numel() == 0:
+            return 0, 0.0, 0.0
+
+        try:
+            mat = activations.detach().float().cpu()
+        except Exception:
+            return 0, 0.0, 0.0
+
+        if not torch.isfinite(mat).all():
+            return 0, 0.0, 0.0
+
+        mat = mat - mat.mean()
+        std = float(mat.std().item())
+        if not math.isfinite(std) or std <= 0.0:
+            return 0, 0.0, 0.0
+
+        mat = mat / std
+        m, n = mat.shape
+        mp_edge_val = mp_bulk_edge(m, n, whitened=False)
+        threshold = mp_edge_val * (1.0 + deadband) * margin
+
+        try:
+            s_vals = torch.linalg.svdvals(mat)
+        except (RuntimeError, torch.linalg.LinAlgError):
+            return 0, 0.0, 0.0
+
+        if s_vals.numel() == 0:
+            return 0, 0.0, 0.0
+
+        sigma_max = float(s_vals.max().item())
+        max_ratio = sigma_max / max(mp_edge_val, 1e-12)
+        outlier_count = int((s_vals > threshold).sum().item())
+        return outlier_count, float(max_ratio), sigma_max
+
+    def _compute_activation_outliers(
+        self, model: nn.Module, batches: list[Any]
+    ) -> dict[str, Any] | None:
+        """Compute activation-based RMT outlier counts."""
+        if not batches:
+            return None
+
+        modules = self._get_activation_modules(model)
+        if not modules:
+            return None
+
+        per_layer_map: dict[str, dict[str, Any]] = {}
+        batch_weight_holder = {"weight": 1}
+        for idx, (module_name, _module) in enumerate(modules):
+            per_layer_map[module_name] = {
+                "layer": idx,
+                "module_name": module_name,
+                "sigma_max": 0.0,
+                "worst_ratio": 0.0,
+                "outlier_count": 0,
+                "has_outlier": False,
+            }
+
+        handles: list[Any] = []
+
+        def _make_hook(name: str):
+            def _hook(_module: nn.Module, _inputs: tuple[Any, ...], output: Any):
+                try:
+                    outliers, max_ratio, sigma_max = self._activation_svd_outliers(
+                        output, self.margin, self.deadband
+                    )
+                except Exception:
+                    return
+                stats = per_layer_map.get(name)
+                if stats is None:
+                    return
+                weight = int(batch_weight_holder.get("weight", 1) or 1)
+                if outliers > 0:
+                    increment = int(outliers) * weight
+                    stats["outlier_count"] = (
+                        int(stats.get("outlier_count", 0)) + increment
+                    )
+                    stats["has_outlier"] = True
+                stats["worst_ratio"] = max(
+                    float(stats.get("worst_ratio", 0.0)), float(max_ratio)
+                )
+                stats["sigma_max"] = max(
+                    float(stats.get("sigma_max", 0.0)), float(sigma_max)
+                )
+
+            return _hook
+
+        for name, module in modules:
+            try:
+                handles.append(module.register_forward_hook(_make_hook(name)))
+            except Exception:
+                continue
+
+        model_was_training = model.training
+        model.eval()
+        device = next(model.parameters()).device
+        batches_used = 0
+        token_weight_total = 0
+
+        try:
+            with torch.no_grad():
+                for batch in batches:
+                    inputs, attention_mask = self._prepare_activation_inputs(
+                        batch, device
+                    )
+                    if inputs is None:
+                        continue
+                    batch_weight = self._batch_token_weight(inputs, attention_mask)
+                    batch_weight_holder["weight"] = batch_weight
+                    try:
+                        if attention_mask is not None:
+                            model(inputs, attention_mask=attention_mask)
+                        else:
+                            model(inputs)
+                        batches_used += 1
+                        token_weight_total += batch_weight
+                    except TypeError:
+                        try:
+                            model(inputs)
+                            batches_used += 1
+                            token_weight_total += batch_weight
+                        except Exception:
+                            continue
+                    except Exception:
+                        continue
+        finally:
+            for handle in handles:
+                try:
+                    handle.remove()
+                except Exception:
+                    pass
+            if model_was_training:
+                model.train()
+
+        if batches_used == 0:
+            return None
+
+        per_layer = [per_layer_map[name] for name, _module in modules]
+        flagged_layers = [
+            info["layer"] for info in per_layer if info.get("has_outlier")
+        ]
+        outlier_total = sum(
+            int(info.get("outlier_count", 0) or 0) for info in per_layer
+        )
+        max_ratio = max(
+            (float(info.get("worst_ratio", 0.0)) for info in per_layer), default=0.0
+        )
+
+        return {
+            "has_outliers": bool(flagged_layers),
+            "n_layers_flagged": len(flagged_layers),
+            "outlier_count": outlier_total,
+            "max_ratio": max_ratio,
+            "threshold": (1.0 + self.deadband) * self.margin,
+            "per_layer": per_layer,
+            "flagged_layers": flagged_layers,
+            "analysis_source": "activations",
+            "token_weight_total": int(token_weight_total),
+            "token_weighted": True,
+        }
+
     def _apply_rmt_detection_and_correction(self, model: nn.Module) -> dict[str, Any]:
         """
         Apply Step 5 RMT detection and correction with adapter support.
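The core of the new activation path standardizes a (tokens × features) activation matrix and counts singular values sitting a margin above the Marchenko-Pastur bulk edge. A self-contained sketch of that idea; the `mp_bulk_edge` helper below is a hypothetical stand-in for the package's own function (for an m × n matrix with i.i.d. unit-variance entries, the largest singular value concentrates near √m + √n):

```python
import math
import torch

def mp_bulk_edge(m: int, n: int) -> float:
    # Bulk edge for the largest singular value of an m x n matrix with
    # i.i.d. unit-variance entries (Bai-Yin / Marchenko-Pastur edge).
    return math.sqrt(m) + math.sqrt(n)

def activation_svd_outliers(acts: torch.Tensor, margin: float = 1.5,
                            deadband: float = 0.0) -> tuple[int, float]:
    mat = acts.reshape(-1, acts.shape[-1]).float()
    mat = (mat - mat.mean()) / mat.std()           # standardize, as the guard does
    edge = mp_bulk_edge(*mat.shape)
    threshold = edge * (1.0 + deadband) * margin   # same threshold composition
    s = torch.linalg.svdvals(mat)
    return int((s > threshold).sum().item()), float(s.max() / edge)

torch.manual_seed(0)
noise = torch.randn(512, 64)
spiked = noise + 5.0 * torch.outer(torch.randn(512), torch.randn(64)) / 64**0.5
print(activation_svd_outliers(noise))   # roughly (0, 1.0): pure noise stays in the bulk
print(activation_svd_outliers(spiked))  # roughly (1, 3): the planted spike clears the edge
```

The deadband widens the edge before the margin multiplies it, so small drifts just past the theoretical bulk do not register as outliers.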
@@ -1462,7 +1786,7 @@ class RMTGuard(Guard):
         Args:
             model: The model that will be edited
             adapter: ModelAdapter (optional, for tying map access)
-            calib: Calibration data
+            calib: Calibration data for activation-based outlier counting
             policy: Guard policy parameters (optional)
 
         Returns:
@@ -1471,6 +1795,8 @@ class RMTGuard(Guard):
         import time
 
         start_time = time.time()
+        self._activation_required_failed = False
+        self._activation_required_reason = None
 
         # Store adapter for tying map access during correction
         self.adapter = adapter
@@ -1485,6 +1811,8 @@ class RMTGuard(Guard):
             self._set_epsilon(policy["epsilon"])
         if "epsilon_by_family" in policy:
             self._set_epsilon(policy["epsilon_by_family"])
+        if "activation_required" in policy:
+            self._require_activation = bool(policy.get("activation_required"))
 
         self._log_event(
             "prepare",
@@ -1502,19 +1830,59 @@ class RMTGuard(Guard):
 
         # Get linear modules in scope
         linear_modules = self._get_linear_modules(model)
+        self._activation_ready = False
+        self._calibration_batches = []
+        max_windows = 0
+        if calib is not None:
+            source = getattr(calib, "dataloader", None) or calib
+            if hasattr(source, "__len__"):
+                try:
+                    max_windows = int(len(source))
+                except Exception:
+                    max_windows = 0
+            self._calibration_batches = self._collect_calibration_batches(
+                calib, max_windows
+            )
 
-
-
-
-
-
-
-
-
-
-
-
-
+        activation_baseline = None
+        if self._calibration_batches:
+            activation_baseline = self._compute_activation_outliers(
+                model, self._calibration_batches
+            )
+
+        if activation_baseline:
+            self._activation_ready = True
+            self.baseline_total_outliers = int(
+                activation_baseline.get("outlier_count", 0) or 0
+            )
+            self.baseline_outliers_per_family = self._count_outliers_per_family(
+                activation_baseline.get("per_layer", [])
+            )
+        else:
+            if self._require_activation:
+                self._activation_required_failed = True
+                self._activation_required_reason = "activation_baseline_unavailable"
+                self._log_event(
+                    "activation_required_missing",
+                    level="ERROR",
+                    message="Activation baseline unavailable for RMT",
+                    profile=self._run_profile,
+                    tier=self._run_tier,
+                )
+            baseline_detection = rmt_detect(
+                model=model,
+                threshold=self.margin,
+                detect_only=True,
+                baseline_sigmas=self.baseline_sigmas,
+                baseline_mp_stats=self.baseline_mp_stats,
+                deadband=self.deadband,
+            )
+            self.baseline_total_outliers = baseline_detection.get(
+                "n_layers_flagged", 0
+            )
+            self.baseline_outliers_per_family = self._count_outliers_per_family(
+                baseline_detection.get("per_layer", [])
+            )
         for family_key in ("attn", "ffn", "embed", "other"):
            self.baseline_outliers_per_family.setdefault(family_key, 0)
         self.outliers_per_family = {}
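At prepare time the baseline source is now chosen in a fixed order. A condensed sketch of that decision (an assumption-level simplification; the real code also logs `activation_required_missing` and records a failure reason):

```python
def baseline_source(activation_baseline: dict | None, require_activation: bool) -> str:
    # Activation-derived counts win when available; otherwise the guard
    # falls back to weight-based rmt_detect, optionally recording a hard
    # activation_required failure first.
    if activation_baseline:
        return "activations"
    if require_activation:
        return "weights (after recording activation_required failure)"
    return "weights"

assert baseline_source({"outlier_count": 4}, True) == "activations"
assert baseline_source(None, False) == "weights"
```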
@@ -1595,7 +1963,7 @@ class RMTGuard(Guard):
         Args:
             model: The model that was just edited
         """
-        if not self.prepared
+        if not self.prepared:
             self._log_event(
                 "after_edit_skipped",
                 level="WARN",
@@ -1606,22 +1974,87 @@ class RMTGuard(Guard):
         self._log_event("after_edit", message="Applying RMT detection and correction")
 
         try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            detection_result: dict[str, Any] | None = None
+            corrected_layers = 0
+            correction_iterations = 0
+            use_activation = bool(self._activation_ready and self._calibration_batches)
+
+            if self._require_activation and not use_activation:
+                self._activation_required_failed = True
+                self._activation_required_reason = "activation_unavailable"
+                self._last_result = {
+                    "has_outliers": False,
+                    "n_layers_flagged": 0,
+                    "per_layer": [],
+                    "max_ratio": 0.0,
+                    "analysis_source": "activations",
+                }
+                self._log_event(
+                    "activation_required_missing",
+                    level="ERROR",
+                    message="Activation outlier analysis required but unavailable",
+                    profile=self._run_profile,
+                    tier=self._run_tier,
                 )
+                return
+
+            if use_activation:
+                if self.correct:
+                    correction_result = self._apply_rmt_detection_and_correction(model)
+                    correction_iterations = int(
+                        correction_result.get("correction_iterations", 0) or 0
+                    )
+                    corrected_layers = int(
+                        correction_result.get("corrected_layers", 0) or 0
+                    )
+                detection_result = self._compute_activation_outliers(
+                    model, self._calibration_batches
+                )
+                if not detection_result:
+                    if self._require_activation:
+                        self._activation_required_failed = True
+                        self._activation_required_reason = (
+                            "activation_outliers_unavailable"
+                        )
+                        self._last_result = {
+                            "has_outliers": False,
+                            "n_layers_flagged": 0,
+                            "per_layer": [],
+                            "max_ratio": 0.0,
+                            "analysis_source": "activations",
+                        }
+                        self._log_event(
+                            "activation_required_missing",
+                            level="ERROR",
+                            message="Activation outlier analysis failed",
+                            profile=self._run_profile,
+                            tier=self._run_tier,
+                        )
+                        return
+                    use_activation = False
+
+            if not use_activation:
+                if self.correct:
+                    # Apply correction using enhanced logic with adapter support
+                    detection_result = self._apply_rmt_detection_and_correction(model)
+                else:
+                    # Detection only
+                    detection_result = rmt_detect(
+                        model=model,
+                        threshold=self.margin,  # Use margin as threshold
+                        detect_only=True,
+                        verbose=False,
+                        baseline_sigmas=self.baseline_sigmas,
+                        baseline_mp_stats=self.baseline_mp_stats,
+                        deadband=self.deadband,
+                    )
+
+            if detection_result is None:
+                raise RuntimeError("RMT detection failed to produce results")
+
+            if use_activation:
+                detection_result["correction_iterations"] = correction_iterations
+                detection_result["corrected_layers"] = corrected_layers
 
             # Store results
             self._last_result = detection_result
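The reworked `after_edit` picks its analysis path up front. A compact sketch of the ordering (a simplification for illustration; the real method also rebuilds `_last_result` and logs on each failure path):

```python
def after_edit_plan(activation_ready: bool, required: bool, correct: bool) -> str:
    if required and not activation_ready:
        return "abort: activation analysis required but unavailable"
    if activation_ready:
        # Correct on weights first (when enabled), then recount outliers
        # from fresh activation passes over the cached calibration batches.
        return "correct-then-recount (activations)"
    return "weight-based detection" + (" + correction" if correct else "")

assert after_edit_plan(False, True, True).startswith("abort")
assert after_edit_plan(True, True, True) == "correct-then-recount (activations)"
```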
@@ -1630,9 +2063,12 @@ class RMTGuard(Guard):
             )
             for family_key in ("attn", "ffn", "embed", "other"):
                 self.outliers_per_family.setdefault(family_key, 0)
-
-
-
+            outlier_total = detection_result.get("outlier_count")
+            if outlier_total is None:
+                outlier_total = detection_result.get(
+                    "n_layers_flagged", len(self.outliers_per_family)
+                )
+            self.outliers_total = int(outlier_total or 0)
             self.epsilon_violations = self._compute_epsilon_violations()
 
             flagged_layers = detection_result.get("n_layers_flagged", 0)
@@ -1779,6 +2215,49 @@ class RMTGuard(Guard):
             }
 
         # Get results from after_edit
+        if self._require_activation and (
+            self._activation_required_failed or not self._activation_ready
+        ):
+            reason = self._activation_required_reason or "activation_required"
+            message = "Activation outlier analysis required but unavailable"
+            finalize_time = time.time() - start_time
+            if HAS_GUARD_OUTCOME:
+                return GuardOutcome(
+                    name=self.name,
+                    passed=False,
+                    action="rollback",
+                    violations=[
+                        {
+                            "type": "activation_required",
+                            "severity": "error",
+                            "message": message,
+                            "module_name": None,
+                            "reason": reason,
+                        }
+                    ],
+                    metrics={
+                        "prepared": True,
+                        "activation_required": True,
+                        "activation_ready": False,
+                        "activation_reason": reason,
+                        "finalize_time": finalize_time,
+                    },
+                )
+            return {
+                "passed": False,
+                "metrics": {
+                    "prepared": True,
+                    "activation_required": True,
+                    "activation_ready": False,
+                    "activation_reason": reason,
+                    "finalize_time": finalize_time,
+                },
+                "warnings": [],
+                "errors": [message],
+                "violations": [],
+                "events": self.events,
+            }
+
         result = self._last_result or {
             "has_outliers": False,
             "n_layers_flagged": 0,
@@ -1793,7 +2272,10 @@ class RMTGuard(Guard):
         for family_key in ("attn", "ffn", "embed", "other"):
             self.outliers_per_family.setdefault(family_key, 0)
             self.baseline_outliers_per_family.setdefault(family_key, 0)
-
+        outlier_total = result.get("outlier_count")
+        if outlier_total is None:
+            outlier_total = result.get("n_layers_flagged", self.outliers_total or 0)
+        self.outliers_total = int(outlier_total or 0)
         self.epsilon_violations = self._compute_epsilon_violations()
         # Contracts: epsilon non-negative, counts non-negative
         for fam, eps in self.epsilon_by_family.items():
@@ -1812,12 +2294,17 @@ class RMTGuard(Guard):
 
         # Calculate metrics
         flagged_layers = result.get("n_layers_flagged", 0)
-        total_layers =
+        total_layers = (
+            len(result.get("per_layer", []))
+            if result.get("per_layer")
+            else len(self.baseline_mp_stats)
+            if self.baseline_mp_stats
+            else 0
+        )
         flagged_rate = flagged_layers / total_layers if total_layers > 0 else 0.0
 
-        #
-
-        passed = flagged_rate <= 0.5  # Allow up to 50% flagged for conservative gate
+        # Acceptance gate: pass when epsilon-rule holds per family.
+        passed = not bool(self.epsilon_violations)
 
         # Generate violations for GuardOutcome
         violations = []
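The pass/fail gate changes from the coarse `flagged_rate <= 0.5` heuristic to the per-family epsilon rule. `_compute_epsilon_violations` itself is not shown in this diff; a plausible sketch of such a rule, with the exact tolerance semantics an assumption:

```python
def epsilon_violations(observed: dict[str, int], baseline: dict[str, int],
                       epsilon: dict[str, float]) -> list[dict]:
    # Assumed rule: flag a family when its post-edit outlier count exceeds
    # the baseline count by more than the family's epsilon allowance.
    out = []
    for fam, count in observed.items():
        allowed = baseline.get(fam, 0) + epsilon.get(fam, 0)
        if count > allowed:
            out.append({"family": fam, "count": count, "allowed": allowed})
    return out

# The guard now passes iff this list is empty:
passed = not bool(epsilon_violations({"attn": 3}, {"attn": 2}, {"attn": 2}))
assert passed
```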
@@ -1844,11 +2331,10 @@ class RMTGuard(Guard):
                 f"High RMT outlier rate: {flagged_layers}/{total_layers} layers flagged ({flagged_rate:.1%})"
             )
 
-            if flagged_rate > 0.7:  #
-
+            if flagged_rate > 0.7:  # Escalate to warning for unusually high rates
+                warnings.append(
                     f"Excessive RMT outliers: {flagged_layers}/{total_layers} layers flagged"
                 )
-                passed = False
 
         if self.epsilon_violations:
             passed = False
@@ -1879,6 +2365,10 @@ class RMTGuard(Guard):
             if self.baseline_mp_stats
             else 0,
             "finalize_time": finalize_time,
+            "activation_required": bool(self._require_activation),
+            "activation_ready": bool(self._activation_ready),
+            "analysis_source": result.get("analysis_source"),
+            "token_weight_total": result.get("token_weight_total"),
             "baseline_outliers_per_family": {
                 k: int(v) for k, v in self.baseline_outliers_per_family.items()
             },