invarlock 0.3.5__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +1 -1
- invarlock/_data/runtime/tiers.yaml +57 -30
- invarlock/adapters/__init__.py +1 -1
- invarlock/calibration/spectral_null.py +15 -10
- invarlock/calibration/variance_ve.py +0 -2
- invarlock/cli/commands/calibrate.py +6 -2
- invarlock/cli/commands/certify.py +58 -39
- invarlock/cli/commands/doctor.py +3 -1
- invarlock/cli/commands/explain_gates.py +57 -8
- invarlock/cli/commands/report.py +1 -1
- invarlock/cli/commands/run.py +159 -61
- invarlock/cli/commands/verify.py +78 -4
- invarlock/cli/config.py +21 -5
- invarlock/core/api.py +45 -5
- invarlock/core/auto_tuning.py +65 -20
- invarlock/core/contracts.py +7 -1
- invarlock/core/registry.py +2 -2
- invarlock/core/runner.py +314 -50
- invarlock/eval/bench.py +0 -13
- invarlock/eval/data.py +14 -28
- invarlock/eval/metrics.py +4 -1
- invarlock/eval/primary_metric.py +23 -0
- invarlock/eval/tail_stats.py +230 -0
- invarlock/guards/_estimators.py +154 -0
- invarlock/guards/policies.py +16 -6
- invarlock/guards/rmt.py +625 -544
- invarlock/guards/spectral.py +348 -110
- invarlock/guards/tier_config.py +32 -30
- invarlock/guards/variance.py +5 -29
- invarlock/guards_ref/rmt_ref.py +23 -23
- invarlock/model_profile.py +42 -15
- invarlock/reporting/certificate.py +225 -46
- invarlock/reporting/certificate_schema.py +2 -1
- invarlock/reporting/dataset_hashing.py +15 -2
- invarlock/reporting/guards_analysis.py +197 -274
- invarlock/reporting/normalizer.py +6 -0
- invarlock/reporting/policy_utils.py +38 -36
- invarlock/reporting/primary_metric_utils.py +71 -17
- invarlock/reporting/render.py +61 -0
- invarlock/reporting/report.py +1 -1
- invarlock/reporting/report_types.py +5 -2
- invarlock/reporting/validate.py +1 -18
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
invarlock/guards/rmt.py
CHANGED
|
@@ -22,7 +22,6 @@ import torch
|
|
|
22
22
|
import torch.linalg as tla
|
|
23
23
|
import torch.nn as nn
|
|
24
24
|
|
|
25
|
-
from invarlock.cli._evidence import maybe_dump_guard_evidence
|
|
26
25
|
from invarlock.core.api import Guard
|
|
27
26
|
|
|
28
27
|
from ._contracts import guard_assert
|
|
@@ -278,7 +277,9 @@ def layer_svd_stats(
|
|
|
278
277
|
return result
|
|
279
278
|
|
|
280
279
|
|
|
281
|
-
def capture_baseline_mp_stats(
|
|
280
|
+
def capture_baseline_mp_stats(
|
|
281
|
+
model: nn.Module, *, allowed_module_names: list[str] | None = None
|
|
282
|
+
) -> dict[str, dict[str, float]]:
|
|
282
283
|
"""
|
|
283
284
|
Capture baseline MP statistics for linear layers only.
|
|
284
285
|
|
|
@@ -321,9 +322,14 @@ def capture_baseline_mp_stats(model: nn.Module) -> dict[str, dict[str, float]]:
|
|
|
321
322
|
|
|
322
323
|
# Define allowlist for RMT analysis - only linear layers where MP makes sense
|
|
323
324
|
allowed_suffixes = [".attn.c_attn", ".attn.c_proj", ".mlp.c_fc", ".mlp.c_proj"]
|
|
325
|
+
allowed_set = None
|
|
326
|
+
if isinstance(allowed_module_names, list) and allowed_module_names:
|
|
327
|
+
allowed_set = {str(name).strip() for name in allowed_module_names if name}
|
|
324
328
|
|
|
325
329
|
for name, module in model.named_modules():
|
|
326
330
|
if isinstance(module, module_types) and hasattr(module, "weight"):
|
|
331
|
+
if allowed_set is not None and name not in allowed_set:
|
|
332
|
+
continue
|
|
327
333
|
# CRITICAL: Restrict to only linear layers where MP analysis is meaningful
|
|
328
334
|
# Skip embeddings, LM head, layer norms - MP heuristics don't apply there
|
|
329
335
|
if any(name.endswith(suffix) for suffix in allowed_suffixes):
|
|
@@ -412,6 +418,7 @@ def rmt_detect(
|
|
|
412
418
|
correction_factor: float | None = None,
|
|
413
419
|
layer_indices: list[int] | None = None,
|
|
414
420
|
target_layers: list[str] | None = None, # Alternative layer specification
|
|
421
|
+
allowed_module_names: list[str] | None = None, # Exact module allowlist
|
|
415
422
|
verbose: bool = False,
|
|
416
423
|
max_iterations: int = 2, # Add iteration guard
|
|
417
424
|
baseline_sigmas: dict[str, float]
|
|
@@ -431,6 +438,7 @@ def rmt_detect(
|
|
|
431
438
|
correction_factor: Factor to apply for correction (if not detect_only)
|
|
432
439
|
layer_indices: Specific layers to analyze by index (None = all)
|
|
433
440
|
target_layers: Specific layers to analyze by name (None = all)
|
|
441
|
+
allowed_module_names: Exact module names to analyze (None = derived default scope)
|
|
434
442
|
verbose: Whether to print warnings and details
|
|
435
443
|
max_iterations: Maximum iterations for correction (default 2)
|
|
436
444
|
baseline_sigmas: Baseline sigmas for baseline-aware checking
|
|
@@ -473,9 +481,14 @@ def rmt_detect(
|
|
|
473
481
|
else:
|
|
474
482
|
# CRITICAL: Only analyze modules where MP analysis makes sense
|
|
475
483
|
# Exclude embeddings, LM head, layer norms - they have different spectral properties
|
|
484
|
+
allowed_set = None
|
|
485
|
+
if isinstance(allowed_module_names, list) and allowed_module_names:
|
|
486
|
+
allowed_set = {str(name).strip() for name in allowed_module_names if name}
|
|
476
487
|
for name, module in model.named_modules():
|
|
477
488
|
# Check if this is an allowed module type with 2D weights
|
|
478
489
|
if any(name.endswith(suffix) for suffix in allowed_suffixes):
|
|
490
|
+
if allowed_set is not None and name not in allowed_set:
|
|
491
|
+
continue
|
|
479
492
|
has_2d_weights = any(
|
|
480
493
|
param.ndim == 2 and "weight" in param_name
|
|
481
494
|
for param_name, param in module.named_parameters(recurse=False)
|
|
@@ -671,7 +684,6 @@ def rmt_detect(
|
|
|
671
684
|
return {
|
|
672
685
|
"has_outliers": has_outliers,
|
|
673
686
|
"n_layers_flagged": n_outliers,
|
|
674
|
-
"outlier_count": n_outliers, # Alias for compatibility
|
|
675
687
|
"max_ratio": max_ratio,
|
|
676
688
|
"threshold": threshold,
|
|
677
689
|
"correction_iterations": correction_iterations,
|
|
@@ -702,8 +714,6 @@ def rmt_detect_report(
|
|
|
702
714
|
"has_outliers": result["has_outliers"],
|
|
703
715
|
"n_layers_flagged": result["n_layers_flagged"],
|
|
704
716
|
"max_ratio": result["max_ratio"],
|
|
705
|
-
"rmt_max_ratio": result["max_ratio"], # Alias for compatibility
|
|
706
|
-
"rmt_has_outliers": result["has_outliers"], # Alias
|
|
707
717
|
}
|
|
708
718
|
|
|
709
719
|
return summary, result["per_layer"]
|
|
@@ -819,7 +829,6 @@ def rmt_detect_with_names(
|
|
|
819
829
|
return {
|
|
820
830
|
"has_outliers": has_outliers,
|
|
821
831
|
"n_layers_flagged": n_outliers,
|
|
822
|
-
"outlier_count": n_outliers,
|
|
823
832
|
"max_ratio": max_ratio,
|
|
824
833
|
"threshold": threshold,
|
|
825
834
|
"per_layer": per_layer,
|
|
@@ -1104,14 +1113,18 @@ class RMTPolicy:
|
|
|
1104
1113
|
correct: bool = True # Enable automatic correction
|
|
1105
1114
|
|
|
1106
1115
|
|
|
1107
|
-
class RMTPolicyDict(TypedDict):
|
|
1108
|
-
"""TypedDict version of
|
|
1116
|
+
class RMTPolicyDict(TypedDict, total=False):
|
|
1117
|
+
"""TypedDict version of the RMT guard policy."""
|
|
1109
1118
|
|
|
1110
1119
|
q: float | Literal["auto"]
|
|
1111
1120
|
deadband: float
|
|
1112
1121
|
margin: float
|
|
1113
1122
|
correct: bool
|
|
1114
|
-
|
|
1123
|
+
epsilon_default: float
|
|
1124
|
+
epsilon_by_family: dict[str, float]
|
|
1125
|
+
activation_required: bool
|
|
1126
|
+
estimator: dict[str, Any]
|
|
1127
|
+
activation: dict[str, Any]
|
|
1115
1128
|
|
|
1116
1129
|
|
|
1117
1130
|
class RMTGuard(Guard):
|
|
@@ -1144,7 +1157,9 @@ class RMTGuard(Guard):
|
|
|
1144
1157
|
deadband: float = 0.10,
|
|
1145
1158
|
margin: float = 1.5,
|
|
1146
1159
|
correct: bool = True,
|
|
1147
|
-
|
|
1160
|
+
*,
|
|
1161
|
+
epsilon_default: float = 0.10,
|
|
1162
|
+
epsilon_by_family: dict[str, float] | None = None,
|
|
1148
1163
|
):
|
|
1149
1164
|
"""
|
|
1150
1165
|
Initialize RMT Guard.
|
|
@@ -1159,13 +1174,23 @@ class RMTGuard(Guard):
|
|
|
1159
1174
|
self.deadband = deadband
|
|
1160
1175
|
self.margin = margin
|
|
1161
1176
|
self.correct = correct
|
|
1162
|
-
self.epsilon_default =
|
|
1177
|
+
self.epsilon_default = float(epsilon_default)
|
|
1163
1178
|
self.epsilon_by_family: dict[str, float] = {}
|
|
1164
|
-
self.
|
|
1179
|
+
self._set_epsilon_by_family(epsilon_by_family)
|
|
1165
1180
|
for family_key in ("attn", "ffn", "embed", "other"):
|
|
1166
1181
|
self.epsilon_by_family.setdefault(family_key, self.epsilon_default)
|
|
1167
1182
|
|
|
1168
|
-
#
|
|
1183
|
+
# Measurement contract knobs (vNext)
|
|
1184
|
+
self.estimator: dict[str, Any] = {
|
|
1185
|
+
"type": "power_iter",
|
|
1186
|
+
"iters": 3,
|
|
1187
|
+
"init": "ones",
|
|
1188
|
+
}
|
|
1189
|
+
self.activation_sampling: dict[str, Any] = {
|
|
1190
|
+
"windows": {"count": 8, "indices_policy": "evenly_spaced"}
|
|
1191
|
+
}
|
|
1192
|
+
|
|
1193
|
+
# Internal state (activation edge-risk scoring)
|
|
1169
1194
|
self._calibration_batches: list[Any] = []
|
|
1170
1195
|
self._activation_ready = False
|
|
1171
1196
|
self._require_activation = False
|
|
@@ -1173,8 +1198,6 @@ class RMTGuard(Guard):
|
|
|
1173
1198
|
self._activation_required_reason: str | None = None
|
|
1174
1199
|
self._run_profile: str | None = None
|
|
1175
1200
|
self._run_tier: str | None = None
|
|
1176
|
-
self.baseline_mp_stats: dict[str, dict[str, float]] | None = None
|
|
1177
|
-
self.baseline_sigmas: dict[str, float] | None = None
|
|
1178
1201
|
self.prepared = False
|
|
1179
1202
|
self.events: list[dict[str, Any]] = []
|
|
1180
1203
|
self._last_result: dict[str, Any] | None = None
|
|
@@ -1187,10 +1210,10 @@ class RMTGuard(Guard):
|
|
|
1187
1210
|
".mlp.c_fc",
|
|
1188
1211
|
".mlp.c_proj",
|
|
1189
1212
|
]
|
|
1190
|
-
self.
|
|
1191
|
-
self.
|
|
1192
|
-
self.
|
|
1193
|
-
self.
|
|
1213
|
+
self.baseline_edge_risk_by_family: dict[str, float] = {}
|
|
1214
|
+
self.baseline_edge_risk_by_module: dict[str, float] = {}
|
|
1215
|
+
self.edge_risk_by_family: dict[str, float] = {}
|
|
1216
|
+
self.edge_risk_by_module: dict[str, float] = {}
|
|
1194
1217
|
self.epsilon_violations: list[dict[str, Any]] = []
|
|
1195
1218
|
|
|
1196
1219
|
def _log_event(
|
|
@@ -1219,45 +1242,47 @@ class RMTGuard(Guard):
|
|
|
1219
1242
|
tier = str(auto.get("tier", tier) or tier).strip().lower()
|
|
1220
1243
|
self._run_profile = profile or None
|
|
1221
1244
|
self._run_tier = tier or None
|
|
1222
|
-
self._require_activation = bool(
|
|
1223
|
-
profile in {"ci", "release"} and tier in {"balanced", "conservative"}
|
|
1224
|
-
)
|
|
1245
|
+
self._require_activation = bool(profile in {"ci", "release"})
|
|
1225
1246
|
|
|
1226
|
-
def
|
|
1227
|
-
"""
|
|
1228
|
-
if
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1247
|
+
def _set_epsilon_default(self, epsilon: Any) -> None:
|
|
1248
|
+
"""Set the default ε used when a family value is missing."""
|
|
1249
|
+
if epsilon is None:
|
|
1250
|
+
return
|
|
1251
|
+
try:
|
|
1252
|
+
eps = float(epsilon)
|
|
1253
|
+
except (TypeError, ValueError):
|
|
1254
|
+
return
|
|
1255
|
+
if eps >= 0.0 and math.isfinite(eps):
|
|
1256
|
+
self.epsilon_default = eps
|
|
1257
|
+
|
|
1258
|
+
def _set_epsilon_by_family(self, epsilon: Any) -> None:
|
|
1259
|
+
"""Set per-family ε values."""
|
|
1260
|
+
if not isinstance(epsilon, dict):
|
|
1261
|
+
return
|
|
1262
|
+
for family, value in epsilon.items():
|
|
1263
|
+
try:
|
|
1264
|
+
eps = float(value)
|
|
1265
|
+
except (TypeError, ValueError):
|
|
1266
|
+
continue
|
|
1267
|
+
if eps >= 0.0 and math.isfinite(eps):
|
|
1268
|
+
self.epsilon_by_family[str(family)] = eps
|
|
1243
1269
|
|
|
1244
1270
|
@staticmethod
|
|
1245
1271
|
def _classify_family(module_name: str) -> str:
|
|
1246
|
-
"""Classify module name into a guard family."""
|
|
1272
|
+
"""Classify module name into a guard family (vNext: {attn, ffn, embed, other})."""
|
|
1247
1273
|
lower = module_name.lower()
|
|
1248
|
-
|
|
1274
|
+
if any(tok in lower for tok in ("attn", "attention", "self_attn")):
|
|
1275
|
+
return "attn"
|
|
1249
1276
|
if any(
|
|
1250
1277
|
tok in lower
|
|
1251
1278
|
for tok in ("router", "routing", "gate", "gating", "dispatch", "switch")
|
|
1252
1279
|
):
|
|
1253
|
-
return "
|
|
1280
|
+
return "ffn"
|
|
1254
1281
|
if any(
|
|
1255
1282
|
tok in lower for tok in ("experts", "expert", "moe", "mixture_of_experts")
|
|
1256
1283
|
):
|
|
1257
|
-
return "
|
|
1258
|
-
if
|
|
1259
|
-
return "attn"
|
|
1260
|
-
if ".mlp." in lower or "ffn" in lower or ".c_fc" in lower:
|
|
1284
|
+
return "ffn"
|
|
1285
|
+
if any(tok in lower for tok in ("mlp", "ffn", "c_fc", "feed_forward")):
|
|
1261
1286
|
return "ffn"
|
|
1262
1287
|
if "embed" in lower or "wte" in lower or "wpe" in lower:
|
|
1263
1288
|
return "embed"
|
|
@@ -1287,24 +1312,28 @@ class RMTGuard(Guard):
|
|
|
1287
1312
|
return counts
|
|
1288
1313
|
|
|
1289
1314
|
def _compute_epsilon_violations(self) -> list[dict[str, Any]]:
|
|
1290
|
-
"""Compute
|
|
1315
|
+
"""Compute ε-band violations per family on activation edge-risk scores."""
|
|
1291
1316
|
violations: list[dict[str, Any]] = []
|
|
1292
|
-
families = set(self.
|
|
1293
|
-
self.
|
|
1317
|
+
families = set(self.edge_risk_by_family) | set(
|
|
1318
|
+
self.baseline_edge_risk_by_family
|
|
1294
1319
|
)
|
|
1295
1320
|
for family in families:
|
|
1296
|
-
|
|
1297
|
-
|
|
1321
|
+
base = float(self.baseline_edge_risk_by_family.get(family, 0.0) or 0.0)
|
|
1322
|
+
cur = float(self.edge_risk_by_family.get(family, 0.0) or 0.0)
|
|
1323
|
+
if base <= 0.0:
|
|
1324
|
+
continue
|
|
1298
1325
|
epsilon_val = float(
|
|
1299
1326
|
self.epsilon_by_family.get(family, self.epsilon_default)
|
|
1300
1327
|
)
|
|
1301
|
-
allowed =
|
|
1302
|
-
if
|
|
1328
|
+
allowed = (1.0 + epsilon_val) * base
|
|
1329
|
+
if cur > allowed:
|
|
1330
|
+
delta = (cur / base) - 1.0 if base > 0 else float("inf")
|
|
1303
1331
|
violations.append(
|
|
1304
1332
|
{
|
|
1305
1333
|
"family": family,
|
|
1306
|
-
"
|
|
1307
|
-
"
|
|
1334
|
+
"edge_base": base,
|
|
1335
|
+
"edge_cur": cur,
|
|
1336
|
+
"delta": float(delta),
|
|
1308
1337
|
"allowed": allowed,
|
|
1309
1338
|
"epsilon": epsilon_val,
|
|
1310
1339
|
}
|
|
@@ -1321,8 +1350,6 @@ class RMTGuard(Guard):
|
|
|
1321
1350
|
Returns:
|
|
1322
1351
|
List of (name, module) tuples for linear layers in scope
|
|
1323
1352
|
"""
|
|
1324
|
-
modules = []
|
|
1325
|
-
|
|
1326
1353
|
# Get module types
|
|
1327
1354
|
try:
|
|
1328
1355
|
from transformers.pytorch_utils import Conv1D
|
|
@@ -1338,20 +1365,61 @@ class RMTGuard(Guard):
|
|
|
1338
1365
|
)
|
|
1339
1366
|
module_types = module_types_without_conv1d_2
|
|
1340
1367
|
|
|
1341
|
-
|
|
1368
|
+
candidates: list[tuple[str, nn.Module]] = []
|
|
1342
1369
|
for name, module in model.named_modules():
|
|
1343
|
-
if isinstance(module, module_types) and hasattr(module, "weight"):
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1370
|
+
if not (isinstance(module, module_types) and hasattr(module, "weight")):
|
|
1371
|
+
continue
|
|
1372
|
+
# Strict scope enforcement - only allowed linear layers
|
|
1373
|
+
if any(name.endswith(suffix) for suffix in self.allowed_suffixes):
|
|
1374
|
+
candidates.append((name, module))
|
|
1375
|
+
candidates.sort(key=lambda t: t[0])
|
|
1376
|
+
return candidates
|
|
1349
1377
|
|
|
1350
1378
|
def _collect_calibration_batches(self, calib: Any, max_windows: int) -> list[Any]:
|
|
1351
1379
|
"""Collect a deterministic slice of calibration batches."""
|
|
1352
1380
|
if calib is None or max_windows <= 0:
|
|
1353
1381
|
return []
|
|
1354
1382
|
source = getattr(calib, "dataloader", None) or calib
|
|
1383
|
+
# Prefer index-based selection when possible so we can support simple
|
|
1384
|
+
# deterministic policies (first/last/evenly_spaced) without consuming
|
|
1385
|
+
# the entire iterator.
|
|
1386
|
+
try:
|
|
1387
|
+
if hasattr(source, "__len__") and hasattr(source, "__getitem__"):
|
|
1388
|
+
n = int(len(source)) # type: ignore[arg-type]
|
|
1389
|
+
if n <= 0:
|
|
1390
|
+
return []
|
|
1391
|
+
count = min(int(max_windows), n)
|
|
1392
|
+
policy = (
|
|
1393
|
+
(self.activation_sampling.get("windows") or {}).get(
|
|
1394
|
+
"indices_policy", "evenly_spaced"
|
|
1395
|
+
)
|
|
1396
|
+
if isinstance(self.activation_sampling, dict)
|
|
1397
|
+
else "evenly_spaced"
|
|
1398
|
+
)
|
|
1399
|
+
policy = str(policy or "evenly_spaced").strip().lower()
|
|
1400
|
+
if policy == "last":
|
|
1401
|
+
idxs = list(range(max(0, n - count), n))
|
|
1402
|
+
elif policy == "evenly_spaced":
|
|
1403
|
+
if count <= 1:
|
|
1404
|
+
idxs = [0]
|
|
1405
|
+
else:
|
|
1406
|
+
idxs = [
|
|
1407
|
+
int(round(i * (n - 1) / float(count - 1)))
|
|
1408
|
+
for i in range(count)
|
|
1409
|
+
]
|
|
1410
|
+
else:
|
|
1411
|
+
idxs = list(range(count))
|
|
1412
|
+
batches: list[Any] = []
|
|
1413
|
+
for idx in idxs:
|
|
1414
|
+
try:
|
|
1415
|
+
batches.append(source[idx]) # type: ignore[index]
|
|
1416
|
+
except Exception:
|
|
1417
|
+
continue
|
|
1418
|
+
return batches
|
|
1419
|
+
except Exception:
|
|
1420
|
+
pass
|
|
1421
|
+
|
|
1422
|
+
# Iterable fallback: first-N only.
|
|
1355
1423
|
try:
|
|
1356
1424
|
iterator = iter(source)
|
|
1357
1425
|
except TypeError:
|
|
@@ -1459,8 +1527,210 @@ class RMTGuard(Guard):
|
|
|
1459
1527
|
):
|
|
1460
1528
|
modules.append((name, module))
|
|
1461
1529
|
|
|
1530
|
+
modules.sort(key=lambda t: t[0])
|
|
1462
1531
|
return modules
|
|
1463
1532
|
|
|
1533
|
+
def _activation_edge_risk(
|
|
1534
|
+
self, activations: Any
|
|
1535
|
+
) -> tuple[float, float, float] | None:
|
|
1536
|
+
"""Compute activation edge-risk score r = σ̂max(A') / σ_MP(m,n).
|
|
1537
|
+
|
|
1538
|
+
A' is a centered + standardised view of the activation matrix. The σ̂max
|
|
1539
|
+
estimator is matvec-based and avoids full SVD.
|
|
1540
|
+
"""
|
|
1541
|
+
if isinstance(activations, tuple | list):
|
|
1542
|
+
activations = activations[0] if activations else None
|
|
1543
|
+
if not isinstance(activations, torch.Tensor):
|
|
1544
|
+
return None
|
|
1545
|
+
if activations.dim() < 2:
|
|
1546
|
+
return None
|
|
1547
|
+
if activations.dim() > 2:
|
|
1548
|
+
activations = activations.reshape(-1, activations.shape[-1])
|
|
1549
|
+
if activations.numel() == 0:
|
|
1550
|
+
return None
|
|
1551
|
+
|
|
1552
|
+
mat = activations.detach()
|
|
1553
|
+
if mat.shape[0] <= 0 or mat.shape[1] <= 0:
|
|
1554
|
+
return None
|
|
1555
|
+
if not torch.isfinite(mat).all():
|
|
1556
|
+
return None
|
|
1557
|
+
|
|
1558
|
+
eps = 1e-12
|
|
1559
|
+
try:
|
|
1560
|
+
mu = mat.mean(dtype=torch.float32)
|
|
1561
|
+
norm = torch.linalg.vector_norm(mat.reshape(-1), ord=2, dtype=torch.float32)
|
|
1562
|
+
mean_sq = (norm * norm) / float(mat.numel())
|
|
1563
|
+
var = mean_sq - (mu * mu)
|
|
1564
|
+
std = torch.sqrt(var.clamp_min(eps))
|
|
1565
|
+
except Exception:
|
|
1566
|
+
return None
|
|
1567
|
+
if not torch.isfinite(mu) or not torch.isfinite(std):
|
|
1568
|
+
return None
|
|
1569
|
+
std_val = float(std.item())
|
|
1570
|
+
if not math.isfinite(std_val) or std_val <= 0.0:
|
|
1571
|
+
return None
|
|
1572
|
+
|
|
1573
|
+
m, n = int(mat.shape[0]), int(mat.shape[1])
|
|
1574
|
+
mp_edge_val = mp_bulk_edge(m, n, whitened=False)
|
|
1575
|
+
if not (math.isfinite(mp_edge_val) and mp_edge_val > 0.0):
|
|
1576
|
+
return None
|
|
1577
|
+
|
|
1578
|
+
try:
|
|
1579
|
+
iters = int((self.estimator or {}).get("iters", 3) or 3)
|
|
1580
|
+
except Exception:
|
|
1581
|
+
iters = 3
|
|
1582
|
+
if iters < 1:
|
|
1583
|
+
iters = 1
|
|
1584
|
+
init = str((self.estimator or {}).get("init", "ones") or "ones").strip().lower()
|
|
1585
|
+
if init not in {"ones", "e0"}:
|
|
1586
|
+
init = "ones"
|
|
1587
|
+
|
|
1588
|
+
device = mat.device
|
|
1589
|
+
dtype = mat.dtype
|
|
1590
|
+
|
|
1591
|
+
with torch.no_grad():
|
|
1592
|
+
if init == "ones":
|
|
1593
|
+
v = torch.ones((n,), device=device, dtype=dtype)
|
|
1594
|
+
else:
|
|
1595
|
+
v = torch.zeros((n,), device=device, dtype=dtype)
|
|
1596
|
+
v[0] = 1
|
|
1597
|
+
v = v / torch.linalg.vector_norm(v.float()).clamp_min(eps).to(dtype)
|
|
1598
|
+
|
|
1599
|
+
mu_d = mu.to(dtype)
|
|
1600
|
+
inv_std_d = (1.0 / std.clamp_min(eps)).to(dtype)
|
|
1601
|
+
ones_n = torch.ones((n,), device=device, dtype=dtype)
|
|
1602
|
+
|
|
1603
|
+
sigma = 0.0
|
|
1604
|
+
for _ in range(iters):
|
|
1605
|
+
v_sum = torch.sum(v.float())
|
|
1606
|
+
u = mat @ v
|
|
1607
|
+
u = (u - (mu_d * v_sum.to(dtype))) * inv_std_d
|
|
1608
|
+
u_norm = torch.linalg.vector_norm(u.float()).clamp_min(eps)
|
|
1609
|
+
sigma_val = float(u_norm.item())
|
|
1610
|
+
if not math.isfinite(sigma_val):
|
|
1611
|
+
return None
|
|
1612
|
+
u = u / u_norm.to(dtype)
|
|
1613
|
+
|
|
1614
|
+
u_sum = torch.sum(u.float())
|
|
1615
|
+
v = mat.T @ u
|
|
1616
|
+
v = (v - (mu_d * u_sum.to(dtype) * ones_n)) * inv_std_d
|
|
1617
|
+
v_norm = torch.linalg.vector_norm(v.float()).clamp_min(eps)
|
|
1618
|
+
v = v / v_norm.to(dtype)
|
|
1619
|
+
sigma = sigma_val
|
|
1620
|
+
|
|
1621
|
+
risk = float(sigma) / max(float(mp_edge_val), eps)
|
|
1622
|
+
return float(risk), float(sigma), float(mp_edge_val)
|
|
1623
|
+
|
|
1624
|
+
def _compute_activation_edge_risk(
|
|
1625
|
+
self, model: nn.Module, batches: list[Any]
|
|
1626
|
+
) -> dict[str, Any] | None:
|
|
1627
|
+
"""Compute token-weighted activation edge-risk scores per module/family."""
|
|
1628
|
+
if not batches:
|
|
1629
|
+
return None
|
|
1630
|
+
|
|
1631
|
+
modules = self._get_activation_modules(model)
|
|
1632
|
+
if not modules:
|
|
1633
|
+
return None
|
|
1634
|
+
|
|
1635
|
+
acc: dict[str, dict[str, float]] = {}
|
|
1636
|
+
for name, _module in modules:
|
|
1637
|
+
acc[name] = {"weighted_sum": 0.0, "weight": 0.0, "max_risk": 0.0}
|
|
1638
|
+
|
|
1639
|
+
batch_weight_holder = {"weight": 1}
|
|
1640
|
+
handles: list[Any] = []
|
|
1641
|
+
|
|
1642
|
+
def _make_hook(name: str):
|
|
1643
|
+
def _hook(_module: nn.Module, _inputs: tuple[Any, ...], output: Any):
|
|
1644
|
+
out = self._activation_edge_risk(output)
|
|
1645
|
+
if out is None:
|
|
1646
|
+
return
|
|
1647
|
+
risk, _sigma, _edge = out
|
|
1648
|
+
try:
|
|
1649
|
+
weight = int(batch_weight_holder.get("weight", 1) or 1)
|
|
1650
|
+
except Exception:
|
|
1651
|
+
weight = 1
|
|
1652
|
+
row = acc.get(name)
|
|
1653
|
+
if row is None:
|
|
1654
|
+
return
|
|
1655
|
+
row["weighted_sum"] = float(row.get("weighted_sum", 0.0)) + float(
|
|
1656
|
+
risk
|
|
1657
|
+
) * float(weight)
|
|
1658
|
+
row["weight"] = float(row.get("weight", 0.0)) + float(weight)
|
|
1659
|
+
row["max_risk"] = max(float(row.get("max_risk", 0.0)), float(risk))
|
|
1660
|
+
|
|
1661
|
+
return _hook
|
|
1662
|
+
|
|
1663
|
+
for name, module in modules:
|
|
1664
|
+
try:
|
|
1665
|
+
handles.append(module.register_forward_hook(_make_hook(name)))
|
|
1666
|
+
except Exception:
|
|
1667
|
+
continue
|
|
1668
|
+
|
|
1669
|
+
model_was_training = model.training
|
|
1670
|
+
model.eval()
|
|
1671
|
+
device = next(model.parameters()).device
|
|
1672
|
+
batches_used = 0
|
|
1673
|
+
token_weight_total = 0
|
|
1674
|
+
|
|
1675
|
+
try:
|
|
1676
|
+
with torch.no_grad():
|
|
1677
|
+
for batch in batches:
|
|
1678
|
+
inputs, attention_mask = self._prepare_activation_inputs(
|
|
1679
|
+
batch, device
|
|
1680
|
+
)
|
|
1681
|
+
if inputs is None:
|
|
1682
|
+
continue
|
|
1683
|
+
batch_weight = self._batch_token_weight(inputs, attention_mask)
|
|
1684
|
+
batch_weight_holder["weight"] = batch_weight
|
|
1685
|
+
try:
|
|
1686
|
+
if attention_mask is not None:
|
|
1687
|
+
model(inputs, attention_mask=attention_mask)
|
|
1688
|
+
else:
|
|
1689
|
+
model(inputs)
|
|
1690
|
+
except TypeError:
|
|
1691
|
+
model(inputs)
|
|
1692
|
+
batches_used += 1
|
|
1693
|
+
token_weight_total += batch_weight
|
|
1694
|
+
finally:
|
|
1695
|
+
for handle in handles:
|
|
1696
|
+
try:
|
|
1697
|
+
handle.remove()
|
|
1698
|
+
except Exception:
|
|
1699
|
+
pass
|
|
1700
|
+
if model_was_training:
|
|
1701
|
+
model.train()
|
|
1702
|
+
|
|
1703
|
+
if batches_used <= 0:
|
|
1704
|
+
return None
|
|
1705
|
+
|
|
1706
|
+
edge_risk_by_module: dict[str, float] = {}
|
|
1707
|
+
for name, row in acc.items():
|
|
1708
|
+
w = float(row.get("weight", 0.0) or 0.0)
|
|
1709
|
+
if w <= 0.0:
|
|
1710
|
+
continue
|
|
1711
|
+
edge_risk_by_module[name] = float(row.get("weighted_sum", 0.0) or 0.0) / w
|
|
1712
|
+
|
|
1713
|
+
if not edge_risk_by_module:
|
|
1714
|
+
return None
|
|
1715
|
+
|
|
1716
|
+
edge_risk_by_family: dict[str, float] = {}
|
|
1717
|
+
for name, risk in edge_risk_by_module.items():
|
|
1718
|
+
family = self._classify_family(name)
|
|
1719
|
+
edge_risk_by_family[family] = max(
|
|
1720
|
+
float(edge_risk_by_family.get(family, 0.0)), float(risk)
|
|
1721
|
+
)
|
|
1722
|
+
|
|
1723
|
+
for family_key in ("attn", "ffn", "embed", "other"):
|
|
1724
|
+
edge_risk_by_family.setdefault(family_key, 0.0)
|
|
1725
|
+
|
|
1726
|
+
return {
|
|
1727
|
+
"analysis_source": "activations_edge_risk",
|
|
1728
|
+
"edge_risk_by_module": edge_risk_by_module,
|
|
1729
|
+
"edge_risk_by_family": edge_risk_by_family,
|
|
1730
|
+
"token_weight_total": int(token_weight_total),
|
|
1731
|
+
"batches_used": int(batches_used),
|
|
1732
|
+
}
|
|
1733
|
+
|
|
1464
1734
|
def _activation_svd_outliers(
|
|
1465
1735
|
self, activations: Any, margin: float, deadband: float
|
|
1466
1736
|
) -> tuple[int, float, float]:
|
|
@@ -1780,150 +2050,175 @@ class RMTGuard(Guard):
|
|
|
1780
2050
|
calib=None,
|
|
1781
2051
|
policy: dict[str, Any] | None = None,
|
|
1782
2052
|
) -> dict[str, Any]:
|
|
1783
|
-
"""
|
|
1784
|
-
Prepare RMT guard by capturing baseline MP statistics.
|
|
1785
|
-
|
|
1786
|
-
Args:
|
|
1787
|
-
model: The model that will be edited
|
|
1788
|
-
adapter: ModelAdapter (optional, for tying map access)
|
|
1789
|
-
calib: Calibration data for activation-based outlier counting
|
|
1790
|
-
policy: Guard policy parameters (optional)
|
|
1791
|
-
|
|
1792
|
-
Returns:
|
|
1793
|
-
Dictionary with preparation results and baseline metrics
|
|
1794
|
-
"""
|
|
2053
|
+
"""Prepare RMT guard by capturing baseline activation edge-risk scores."""
|
|
1795
2054
|
import time
|
|
1796
2055
|
|
|
1797
2056
|
start_time = time.time()
|
|
1798
2057
|
self._activation_required_failed = False
|
|
1799
2058
|
self._activation_required_reason = None
|
|
1800
2059
|
|
|
1801
|
-
# Store adapter for tying map access
|
|
2060
|
+
# Store adapter for tying map access (if used by downstream code)
|
|
1802
2061
|
self.adapter = adapter
|
|
1803
2062
|
|
|
1804
|
-
#
|
|
1805
|
-
if policy:
|
|
1806
|
-
self.q = policy.get("q", self.q)
|
|
1807
|
-
self.deadband = policy.get("deadband", self.deadband)
|
|
1808
|
-
self.margin = policy.get("margin", self.margin)
|
|
1809
|
-
self.correct = policy.get("correct", self.correct)
|
|
2063
|
+
# Policy overrides (vNext contract)
|
|
2064
|
+
if isinstance(policy, dict) and policy:
|
|
1810
2065
|
if "epsilon" in policy:
|
|
1811
|
-
|
|
2066
|
+
from invarlock.core.exceptions import ValidationError
|
|
2067
|
+
|
|
2068
|
+
raise ValidationError(
|
|
2069
|
+
code="E501",
|
|
2070
|
+
message="POLICY-PARAM-INVALID",
|
|
2071
|
+
details={
|
|
2072
|
+
"param": "epsilon",
|
|
2073
|
+
"hint": "Use rmt.epsilon_default and rmt.epsilon_by_family instead.",
|
|
2074
|
+
},
|
|
2075
|
+
)
|
|
2076
|
+
if "q" in policy:
|
|
2077
|
+
q_val = policy.get("q")
|
|
2078
|
+
if q_val == "auto":
|
|
2079
|
+
self.q = "auto"
|
|
2080
|
+
else:
|
|
2081
|
+
try:
|
|
2082
|
+
self.q = float(q_val)
|
|
2083
|
+
except (TypeError, ValueError):
|
|
2084
|
+
self.q = "auto"
|
|
2085
|
+
if "deadband" in policy:
|
|
2086
|
+
self.deadband = float(policy.get("deadband", self.deadband))
|
|
2087
|
+
if "margin" in policy:
|
|
2088
|
+
try:
|
|
2089
|
+
self.margin = float(policy.get("margin", self.margin))
|
|
2090
|
+
except (TypeError, ValueError):
|
|
2091
|
+
pass
|
|
2092
|
+
if "correct" in policy:
|
|
2093
|
+
self.correct = bool(policy.get("correct"))
|
|
1812
2094
|
if "epsilon_by_family" in policy:
|
|
1813
|
-
self.
|
|
2095
|
+
self._set_epsilon_by_family(policy["epsilon_by_family"])
|
|
2096
|
+
if "epsilon_default" in policy:
|
|
2097
|
+
self._set_epsilon_default(policy["epsilon_default"])
|
|
2098
|
+
for family_key in ("attn", "ffn", "embed", "other"):
|
|
2099
|
+
self.epsilon_by_family.setdefault(family_key, self.epsilon_default)
|
|
1814
2100
|
if "activation_required" in policy:
|
|
1815
2101
|
self._require_activation = bool(policy.get("activation_required"))
|
|
1816
2102
|
|
|
2103
|
+
estimator_policy = policy.get("estimator")
|
|
2104
|
+
if isinstance(estimator_policy, dict):
|
|
2105
|
+
try:
|
|
2106
|
+
iters = int(estimator_policy.get("iters", 3) or 3)
|
|
2107
|
+
except Exception:
|
|
2108
|
+
iters = 3
|
|
2109
|
+
if iters < 1:
|
|
2110
|
+
iters = 1
|
|
2111
|
+
init = (
|
|
2112
|
+
str(estimator_policy.get("init", "ones") or "ones").strip().lower()
|
|
2113
|
+
)
|
|
2114
|
+
if init not in {"ones", "e0"}:
|
|
2115
|
+
init = "ones"
|
|
2116
|
+
self.estimator = {"type": "power_iter", "iters": iters, "init": init}
|
|
2117
|
+
|
|
2118
|
+
activation_policy = policy.get("activation")
|
|
2119
|
+
if isinstance(activation_policy, dict):
|
|
2120
|
+
sampling = activation_policy.get("sampling")
|
|
2121
|
+
if isinstance(sampling, dict):
|
|
2122
|
+
windows = sampling.get("windows")
|
|
2123
|
+
if isinstance(windows, dict):
|
|
2124
|
+
cfg = dict(self.activation_sampling.get("windows") or {})
|
|
2125
|
+
if windows.get("count") is not None:
|
|
2126
|
+
try:
|
|
2127
|
+
cfg["count"] = int(windows.get("count") or 0)
|
|
2128
|
+
except Exception:
|
|
2129
|
+
pass
|
|
2130
|
+
if windows.get("indices_policy") is not None:
|
|
2131
|
+
cfg["indices_policy"] = str(
|
|
2132
|
+
windows.get("indices_policy")
|
|
2133
|
+
or cfg.get("indices_policy")
|
|
2134
|
+
)
|
|
2135
|
+
self.activation_sampling["windows"] = cfg
|
|
2136
|
+
|
|
1817
2137
|
self._log_event(
|
|
1818
2138
|
"prepare",
|
|
1819
|
-
message=
|
|
2139
|
+
message="Preparing RMT guard baseline activation edge-risk metrics",
|
|
1820
2140
|
)
|
|
1821
2141
|
|
|
1822
2142
|
try:
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
|
|
1828
|
-
|
|
1829
|
-
self.
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
|
|
1833
|
-
self._activation_ready = False
|
|
1834
|
-
self._calibration_batches = []
|
|
1835
|
-
max_windows = 0
|
|
1836
|
-
if calib is not None:
|
|
1837
|
-
source = getattr(calib, "dataloader", None) or calib
|
|
1838
|
-
if hasattr(source, "__len__"):
|
|
1839
|
-
try:
|
|
1840
|
-
max_windows = int(len(source))
|
|
1841
|
-
except Exception:
|
|
1842
|
-
max_windows = 0
|
|
1843
|
-
self._calibration_batches = self._collect_calibration_batches(
|
|
1844
|
-
calib, max_windows
|
|
1845
|
-
)
|
|
2143
|
+
windows_cfg = self.activation_sampling.get("windows") or {}
|
|
2144
|
+
try:
|
|
2145
|
+
window_count = int(windows_cfg.get("count", 0) or 0)
|
|
2146
|
+
except Exception:
|
|
2147
|
+
window_count = 0
|
|
2148
|
+
self._calibration_batches = (
|
|
2149
|
+
self._collect_calibration_batches(calib, window_count)
|
|
2150
|
+
if calib is not None and window_count > 0
|
|
2151
|
+
else []
|
|
2152
|
+
)
|
|
1846
2153
|
|
|
1847
|
-
|
|
1848
|
-
|
|
1849
|
-
|
|
1850
|
-
|
|
1851
|
-
|
|
2154
|
+
self.baseline_edge_risk_by_family = {}
|
|
2155
|
+
self.baseline_edge_risk_by_module = {}
|
|
2156
|
+
self.edge_risk_by_family = {}
|
|
2157
|
+
self.edge_risk_by_module = {}
|
|
2158
|
+
self.epsilon_violations = []
|
|
1852
2159
|
|
|
1853
|
-
if
|
|
1854
|
-
self.
|
|
1855
|
-
self.
|
|
1856
|
-
|
|
1857
|
-
|
|
1858
|
-
|
|
1859
|
-
|
|
1860
|
-
|
|
1861
|
-
|
|
2160
|
+
if self._require_activation and not self._calibration_batches:
|
|
2161
|
+
self._activation_required_failed = True
|
|
2162
|
+
self._activation_required_reason = "activation_required"
|
|
2163
|
+
self._activation_ready = False
|
|
2164
|
+
self.prepared = False
|
|
2165
|
+
return {
|
|
2166
|
+
"ready": False,
|
|
2167
|
+
"baseline_metrics": {},
|
|
2168
|
+
"policy_applied": policy or {},
|
|
2169
|
+
"preparation_time": time.time() - start_time,
|
|
2170
|
+
"error": "Activation batches required but unavailable",
|
|
2171
|
+
}
|
|
2172
|
+
|
|
2173
|
+
baseline = (
|
|
2174
|
+
self._compute_activation_edge_risk(model, self._calibration_batches)
|
|
2175
|
+
if self._calibration_batches
|
|
2176
|
+
else None
|
|
2177
|
+
)
|
|
2178
|
+
if baseline is None:
|
|
1862
2179
|
if self._require_activation:
|
|
1863
2180
|
self._activation_required_failed = True
|
|
1864
2181
|
self._activation_required_reason = "activation_baseline_unavailable"
|
|
1865
|
-
self.
|
|
1866
|
-
|
|
1867
|
-
|
|
1868
|
-
|
|
1869
|
-
|
|
1870
|
-
|
|
1871
|
-
|
|
1872
|
-
|
|
1873
|
-
|
|
1874
|
-
|
|
1875
|
-
|
|
1876
|
-
|
|
1877
|
-
|
|
1878
|
-
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
"
|
|
1882
|
-
|
|
1883
|
-
self.baseline_outliers_per_family = self._count_outliers_per_family(
|
|
1884
|
-
baseline_detection.get("per_layer", [])
|
|
1885
|
-
)
|
|
1886
|
-
for family_key in ("attn", "ffn", "embed", "other"):
|
|
1887
|
-
self.baseline_outliers_per_family.setdefault(family_key, 0)
|
|
1888
|
-
self.outliers_per_family = {}
|
|
1889
|
-
self.outliers_total = 0
|
|
1890
|
-
self.epsilon_violations = []
|
|
1891
|
-
|
|
1892
|
-
self.prepared = True
|
|
1893
|
-
preparation_time = time.time() - start_time
|
|
2182
|
+
self._activation_ready = False
|
|
2183
|
+
self.prepared = False
|
|
2184
|
+
return {
|
|
2185
|
+
"ready": False,
|
|
2186
|
+
"baseline_metrics": {},
|
|
2187
|
+
"policy_applied": policy or {},
|
|
2188
|
+
"preparation_time": time.time() - start_time,
|
|
2189
|
+
"error": "Activation baseline unavailable",
|
|
2190
|
+
}
|
|
2191
|
+
# Non-required: treat as not ready and allow pipeline to continue.
|
|
2192
|
+
self._activation_ready = False
|
|
2193
|
+
self.prepared = True
|
|
2194
|
+
return {
|
|
2195
|
+
"ready": True,
|
|
2196
|
+
"baseline_metrics": {},
|
|
2197
|
+
"policy_applied": policy or {},
|
|
2198
|
+
"preparation_time": time.time() - start_time,
|
|
2199
|
+
}
|
|
1894
2200
|
|
|
1895
|
-
self.
|
|
1896
|
-
"
|
|
1897
|
-
message=f"Captured {len(self.baseline_mp_stats)} baseline MP statistics",
|
|
1898
|
-
baseline_count=len(self.baseline_mp_stats),
|
|
1899
|
-
linear_modules_count=len(linear_modules),
|
|
1900
|
-
preparation_time=preparation_time,
|
|
2201
|
+
self.baseline_edge_risk_by_module = dict(
|
|
2202
|
+
baseline.get("edge_risk_by_module") or {}
|
|
1901
2203
|
)
|
|
2204
|
+
self.baseline_edge_risk_by_family = dict(
|
|
2205
|
+
baseline.get("edge_risk_by_family") or {}
|
|
2206
|
+
)
|
|
2207
|
+
self._activation_ready = True
|
|
2208
|
+
self.prepared = True
|
|
1902
2209
|
|
|
2210
|
+
preparation_time = time.time() - start_time
|
|
1903
2211
|
return {
|
|
2212
|
+
"ready": True,
|
|
1904
2213
|
"baseline_metrics": {
|
|
1905
|
-
"
|
|
1906
|
-
"
|
|
1907
|
-
|
|
1908
|
-
|
|
1909
|
-
|
|
1910
|
-
list(self.baseline_sigmas.values())
|
|
1911
|
-
),
|
|
1912
|
-
"max_baseline_sigma": max(self.baseline_sigmas.values())
|
|
1913
|
-
if self.baseline_sigmas
|
|
1914
|
-
else 0.0,
|
|
1915
|
-
"min_baseline_sigma": min(self.baseline_sigmas.values())
|
|
1916
|
-
if self.baseline_sigmas
|
|
1917
|
-
else 0.0,
|
|
1918
|
-
},
|
|
1919
|
-
"policy_applied": {
|
|
1920
|
-
"q": self.q,
|
|
1921
|
-
"deadband": self.deadband,
|
|
1922
|
-
"margin": self.margin,
|
|
1923
|
-
"correct": self.correct,
|
|
2214
|
+
"edge_risk_by_family": dict(self.baseline_edge_risk_by_family),
|
|
2215
|
+
"measurement_contract": {
|
|
2216
|
+
"estimator": self.estimator,
|
|
2217
|
+
"activation_sampling": self.activation_sampling,
|
|
2218
|
+
},
|
|
1924
2219
|
},
|
|
2220
|
+
"policy_applied": policy or {},
|
|
1925
2221
|
"preparation_time": preparation_time,
|
|
1926
|
-
"ready": True,
|
|
1927
2222
|
}
|
|
1928
2223
|
|
|
1929
2224
|
except Exception as e:
|
|
@@ -1957,12 +2252,7 @@ class RMTGuard(Guard):
|
|
|
1957
2252
|
)
|
|
1958
2253
|
|
|
1959
2254
|
def after_edit(self, model: nn.Module) -> None:
|
|
1960
|
-
"""
|
|
1961
|
-
Execute after edit - perform RMT detection and optional correction.
|
|
1962
|
-
|
|
1963
|
-
Args:
|
|
1964
|
-
model: The model that was just edited
|
|
1965
|
-
"""
|
|
2255
|
+
"""Execute after edit: compute activation edge-risk on sampled batches."""
|
|
1966
2256
|
if not self.prepared:
|
|
1967
2257
|
self._log_event(
|
|
1968
2258
|
"after_edit_skipped",
|
|
@@ -1971,137 +2261,40 @@ class RMTGuard(Guard):
|
|
|
1971
2261
|
)
|
|
1972
2262
|
return
|
|
1973
2263
|
|
|
1974
|
-
self._log_event("after_edit", message="Applying RMT detection and correction")
|
|
1975
|
-
|
|
1976
2264
|
try:
|
|
1977
|
-
|
|
1978
|
-
corrected_layers = 0
|
|
1979
|
-
correction_iterations = 0
|
|
1980
|
-
use_activation = bool(self._activation_ready and self._calibration_batches)
|
|
1981
|
-
|
|
1982
|
-
if self._require_activation and not use_activation:
|
|
2265
|
+
if self._require_activation and not self._calibration_batches:
|
|
1983
2266
|
self._activation_required_failed = True
|
|
1984
2267
|
self._activation_required_reason = "activation_unavailable"
|
|
1985
2268
|
self._last_result = {
|
|
1986
|
-
"
|
|
1987
|
-
"
|
|
1988
|
-
"
|
|
1989
|
-
"max_ratio": 0.0,
|
|
1990
|
-
"analysis_source": "activations",
|
|
2269
|
+
"analysis_source": "activations_edge_risk",
|
|
2270
|
+
"edge_risk_by_module": {},
|
|
2271
|
+
"edge_risk_by_family": {},
|
|
1991
2272
|
}
|
|
1992
|
-
self._log_event(
|
|
1993
|
-
"activation_required_missing",
|
|
1994
|
-
level="ERROR",
|
|
1995
|
-
message="Activation outlier analysis required but unavailable",
|
|
1996
|
-
profile=self._run_profile,
|
|
1997
|
-
tier=self._run_tier,
|
|
1998
|
-
)
|
|
1999
2273
|
return
|
|
2000
2274
|
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2004
|
-
|
|
2005
|
-
|
|
2006
|
-
|
|
2007
|
-
|
|
2008
|
-
|
|
2009
|
-
|
|
2010
|
-
|
|
2011
|
-
model, self._calibration_batches
|
|
2012
|
-
)
|
|
2013
|
-
if not detection_result:
|
|
2014
|
-
if self._require_activation:
|
|
2015
|
-
self._activation_required_failed = True
|
|
2016
|
-
self._activation_required_reason = (
|
|
2017
|
-
"activation_outliers_unavailable"
|
|
2018
|
-
)
|
|
2019
|
-
self._last_result = {
|
|
2020
|
-
"has_outliers": False,
|
|
2021
|
-
"n_layers_flagged": 0,
|
|
2022
|
-
"per_layer": [],
|
|
2023
|
-
"max_ratio": 0.0,
|
|
2024
|
-
"analysis_source": "activations",
|
|
2025
|
-
}
|
|
2026
|
-
self._log_event(
|
|
2027
|
-
"activation_required_missing",
|
|
2028
|
-
level="ERROR",
|
|
2029
|
-
message="Activation outlier analysis failed",
|
|
2030
|
-
profile=self._run_profile,
|
|
2031
|
-
tier=self._run_tier,
|
|
2032
|
-
)
|
|
2033
|
-
return
|
|
2034
|
-
use_activation = False
|
|
2035
|
-
|
|
2036
|
-
if not use_activation:
|
|
2037
|
-
if self.correct:
|
|
2038
|
-
# Apply correction using enhanced logic with adapter support
|
|
2039
|
-
detection_result = self._apply_rmt_detection_and_correction(model)
|
|
2040
|
-
else:
|
|
2041
|
-
# Detection only
|
|
2042
|
-
detection_result = rmt_detect(
|
|
2043
|
-
model=model,
|
|
2044
|
-
threshold=self.margin, # Use margin as threshold
|
|
2045
|
-
detect_only=True,
|
|
2046
|
-
verbose=False,
|
|
2047
|
-
baseline_sigmas=self.baseline_sigmas,
|
|
2048
|
-
baseline_mp_stats=self.baseline_mp_stats,
|
|
2049
|
-
deadband=self.deadband,
|
|
2275
|
+
current = (
|
|
2276
|
+
self._compute_activation_edge_risk(model, self._calibration_batches)
|
|
2277
|
+
if self._calibration_batches
|
|
2278
|
+
else None
|
|
2279
|
+
)
|
|
2280
|
+
if current is None:
|
|
2281
|
+
if self._require_activation:
|
|
2282
|
+
self._activation_required_failed = True
|
|
2283
|
+
self._activation_required_reason = (
|
|
2284
|
+
"activation_edge_risk_unavailable"
|
|
2050
2285
|
)
|
|
2286
|
+
self._last_result = {
|
|
2287
|
+
"analysis_source": "activations_edge_risk",
|
|
2288
|
+
"edge_risk_by_module": {},
|
|
2289
|
+
"edge_risk_by_family": {},
|
|
2290
|
+
}
|
|
2291
|
+
return
|
|
2051
2292
|
|
|
2052
|
-
|
|
2053
|
-
|
|
2054
|
-
|
|
2055
|
-
if use_activation:
|
|
2056
|
-
detection_result["correction_iterations"] = correction_iterations
|
|
2057
|
-
detection_result["corrected_layers"] = corrected_layers
|
|
2058
|
-
|
|
2059
|
-
# Store results
|
|
2060
|
-
self._last_result = detection_result
|
|
2061
|
-
self.outliers_per_family = self._count_outliers_per_family(
|
|
2062
|
-
detection_result.get("per_layer", [])
|
|
2063
|
-
)
|
|
2064
|
-
for family_key in ("attn", "ffn", "embed", "other"):
|
|
2065
|
-
self.outliers_per_family.setdefault(family_key, 0)
|
|
2066
|
-
outlier_total = detection_result.get("outlier_count")
|
|
2067
|
-
if outlier_total is None:
|
|
2068
|
-
outlier_total = detection_result.get(
|
|
2069
|
-
"n_layers_flagged", len(self.outliers_per_family)
|
|
2070
|
-
)
|
|
2071
|
-
self.outliers_total = int(outlier_total or 0)
|
|
2293
|
+
self.edge_risk_by_module = dict(current.get("edge_risk_by_module") or {})
|
|
2294
|
+
self.edge_risk_by_family = dict(current.get("edge_risk_by_family") or {})
|
|
2295
|
+
self._last_result = dict(current)
|
|
2072
2296
|
self.epsilon_violations = self._compute_epsilon_violations()
|
|
2073
2297
|
|
|
2074
|
-
flagged_layers = detection_result.get("n_layers_flagged", 0)
|
|
2075
|
-
corrected_layers = detection_result.get("correction_iterations", 0)
|
|
2076
|
-
|
|
2077
|
-
self._log_event(
|
|
2078
|
-
"rmt_detection_complete",
|
|
2079
|
-
message=f"Detected {flagged_layers} outlier layers, correction enabled: {self.correct}",
|
|
2080
|
-
layers_flagged=flagged_layers,
|
|
2081
|
-
correction_iterations=corrected_layers,
|
|
2082
|
-
has_outliers=detection_result.get("has_outliers", False),
|
|
2083
|
-
max_ratio=detection_result.get("max_ratio", 0.0),
|
|
2084
|
-
)
|
|
2085
|
-
|
|
2086
|
-
# Log individual layer results
|
|
2087
|
-
for layer_info in detection_result.get("per_layer", []):
|
|
2088
|
-
if layer_info.get("has_outlier", False):
|
|
2089
|
-
self._log_event(
|
|
2090
|
-
"outlier_detected",
|
|
2091
|
-
message=f"Outlier detected in {layer_info.get('module_name', 'unknown')}",
|
|
2092
|
-
layer_name=layer_info.get("module_name"),
|
|
2093
|
-
ratio=layer_info.get("worst_ratio", 0.0),
|
|
2094
|
-
sigma_max=layer_info.get("sigma_max", 0.0),
|
|
2095
|
-
corrected=self.correct,
|
|
2096
|
-
)
|
|
2097
|
-
elif layer_info.get("skip_reason"):
|
|
2098
|
-
self._log_event(
|
|
2099
|
-
"layer_skipped",
|
|
2100
|
-
message=f"Layer {layer_info.get('module_name', 'unknown')} skipped: {layer_info.get('skip_reason')}",
|
|
2101
|
-
layer_name=layer_info.get("module_name"),
|
|
2102
|
-
skip_reason=layer_info.get("skip_reason"),
|
|
2103
|
-
)
|
|
2104
|
-
|
|
2105
2298
|
except Exception as e:
|
|
2106
2299
|
self._log_event(
|
|
2107
2300
|
"after_edit_failed",
|
|
@@ -2109,15 +2302,11 @@ class RMTGuard(Guard):
|
|
|
2109
2302
|
message=f"RMT detection failed: {str(e)}",
|
|
2110
2303
|
error=str(e),
|
|
2111
2304
|
)
|
|
2112
|
-
# Store empty result for finalize
|
|
2113
2305
|
self._last_result = {
|
|
2114
|
-
"
|
|
2115
|
-
"
|
|
2116
|
-
"
|
|
2117
|
-
"max_ratio": 0.0,
|
|
2306
|
+
"analysis_source": "activations_edge_risk",
|
|
2307
|
+
"edge_risk_by_module": {},
|
|
2308
|
+
"edge_risk_by_family": {},
|
|
2118
2309
|
}
|
|
2119
|
-
self.outliers_per_family = {}
|
|
2120
|
-
self.outliers_total = 0
|
|
2121
2310
|
self.epsilon_violations = []
|
|
2122
2311
|
|
|
2123
2312
|
def validate(
|
|
@@ -2163,27 +2352,13 @@ class RMTGuard(Guard):
|
|
|
2163
2352
|
}
|
|
2164
2353
|
|
|
2165
2354
|
def finalize(self, model: nn.Module, adapter=None) -> GuardOutcome | dict[str, Any]:
|
|
2166
|
-
"""
|
|
2167
|
-
Finalize RMT guard and return comprehensive results.
|
|
2168
|
-
|
|
2169
|
-
Args:
|
|
2170
|
-
model: The final edited model
|
|
2171
|
-
adapter: Optional adapter for tying map access
|
|
2172
|
-
|
|
2173
|
-
Returns:
|
|
2174
|
-
GuardOutcome or dict with RMT detection and correction results
|
|
2175
|
-
"""
|
|
2355
|
+
"""Finalize RMT guard and return activation edge-risk ε-band outcome."""
|
|
2176
2356
|
import time
|
|
2177
2357
|
|
|
2178
2358
|
start_time = time.time()
|
|
2359
|
+
_ = adapter
|
|
2179
2360
|
|
|
2180
2361
|
if not self.prepared:
|
|
2181
|
-
self._log_event(
|
|
2182
|
-
"finalize_failed",
|
|
2183
|
-
level="ERROR",
|
|
2184
|
-
message="RMT guard not properly prepared",
|
|
2185
|
-
)
|
|
2186
|
-
|
|
2187
2362
|
if HAS_GUARD_OUTCOME:
|
|
2188
2363
|
return GuardOutcome(
|
|
2189
2364
|
name=self.name,
|
|
@@ -2202,35 +2377,28 @@ class RMTGuard(Guard):
|
|
|
2202
2377
|
"finalize_time": time.time() - start_time,
|
|
2203
2378
|
},
|
|
2204
2379
|
)
|
|
2205
|
-
|
|
2206
|
-
|
|
2207
|
-
|
|
2208
|
-
"
|
|
2209
|
-
|
|
2210
|
-
|
|
2211
|
-
|
|
2212
|
-
|
|
2213
|
-
"errors": ["Preparation failed or baseline MP stats not captured"],
|
|
2214
|
-
"events": self.events,
|
|
2215
|
-
}
|
|
2380
|
+
return {
|
|
2381
|
+
"passed": False,
|
|
2382
|
+
"metrics": {
|
|
2383
|
+
"prepared": False,
|
|
2384
|
+
"finalize_time": time.time() - start_time,
|
|
2385
|
+
},
|
|
2386
|
+
"errors": ["RMT guard not properly prepared"],
|
|
2387
|
+
}
|
|
2216
2388
|
|
|
2217
|
-
|
|
2218
|
-
if self._require_activation and (
|
|
2219
|
-
self._activation_required_failed or not self._activation_ready
|
|
2220
|
-
):
|
|
2389
|
+
if self._require_activation and self._activation_required_failed:
|
|
2221
2390
|
reason = self._activation_required_reason or "activation_required"
|
|
2222
|
-
message = "Activation outlier analysis required but unavailable"
|
|
2223
2391
|
finalize_time = time.time() - start_time
|
|
2224
2392
|
if HAS_GUARD_OUTCOME:
|
|
2225
2393
|
return GuardOutcome(
|
|
2226
2394
|
name=self.name,
|
|
2227
2395
|
passed=False,
|
|
2228
|
-
action="
|
|
2396
|
+
action="abort",
|
|
2229
2397
|
violations=[
|
|
2230
2398
|
{
|
|
2231
2399
|
"type": "activation_required",
|
|
2232
2400
|
"severity": "error",
|
|
2233
|
-
"message":
|
|
2401
|
+
"message": "Activation edge-risk analysis required but unavailable",
|
|
2234
2402
|
"module_name": None,
|
|
2235
2403
|
"reason": reason,
|
|
2236
2404
|
}
|
|
@@ -2252,218 +2420,74 @@ class RMTGuard(Guard):
|
|
|
2252
2420
|
"activation_reason": reason,
|
|
2253
2421
|
"finalize_time": finalize_time,
|
|
2254
2422
|
},
|
|
2255
|
-
"
|
|
2256
|
-
"errors": [message],
|
|
2257
|
-
"violations": [],
|
|
2258
|
-
"events": self.events,
|
|
2423
|
+
"errors": ["Activation edge-risk analysis required but unavailable"],
|
|
2259
2424
|
}
|
|
2260
2425
|
|
|
2261
|
-
|
|
2262
|
-
|
|
2263
|
-
|
|
2264
|
-
"per_layer": [],
|
|
2265
|
-
"max_ratio": 0.0,
|
|
2266
|
-
}
|
|
2267
|
-
|
|
2268
|
-
if result and not self.outliers_per_family:
|
|
2269
|
-
self.outliers_per_family = self._count_outliers_per_family(
|
|
2270
|
-
result.get("per_layer", [])
|
|
2426
|
+
if not self.edge_risk_by_family and self._calibration_batches:
|
|
2427
|
+
current = self._compute_activation_edge_risk(
|
|
2428
|
+
model, self._calibration_batches
|
|
2271
2429
|
)
|
|
2272
|
-
|
|
2273
|
-
|
|
2274
|
-
|
|
2275
|
-
|
|
2276
|
-
|
|
2277
|
-
|
|
2278
|
-
|
|
2430
|
+
if current is not None:
|
|
2431
|
+
self.edge_risk_by_family = dict(
|
|
2432
|
+
current.get("edge_risk_by_family") or {}
|
|
2433
|
+
)
|
|
2434
|
+
self.edge_risk_by_module = dict(
|
|
2435
|
+
current.get("edge_risk_by_module") or {}
|
|
2436
|
+
)
|
|
2437
|
+
self._last_result = dict(current)
|
|
2438
|
+
|
|
2279
2439
|
self.epsilon_violations = self._compute_epsilon_violations()
|
|
2280
|
-
# Contracts: epsilon non-negative, counts non-negative
|
|
2281
2440
|
for fam, eps in self.epsilon_by_family.items():
|
|
2282
2441
|
guard_assert(eps >= 0.0, f"rmt.epsilon[{fam}] must be >= 0")
|
|
2283
|
-
for fam in set(self.outliers_per_family) | set(
|
|
2284
|
-
self.baseline_outliers_per_family
|
|
2285
|
-
):
|
|
2286
|
-
guard_assert(
|
|
2287
|
-
self.outliers_per_family.get(fam, 0) >= 0,
|
|
2288
|
-
"rmt.outliers_per_family negative",
|
|
2289
|
-
)
|
|
2290
|
-
guard_assert(
|
|
2291
|
-
self.baseline_outliers_per_family.get(fam, 0) >= 0,
|
|
2292
|
-
"rmt.baseline_outliers negative",
|
|
2293
|
-
)
|
|
2294
|
-
|
|
2295
|
-
# Calculate metrics
|
|
2296
|
-
flagged_layers = result.get("n_layers_flagged", 0)
|
|
2297
|
-
total_layers = (
|
|
2298
|
-
len(result.get("per_layer", []))
|
|
2299
|
-
if result.get("per_layer")
|
|
2300
|
-
else len(self.baseline_mp_stats)
|
|
2301
|
-
if self.baseline_mp_stats
|
|
2302
|
-
else 0
|
|
2303
|
-
)
|
|
2304
|
-
flagged_rate = flagged_layers / total_layers if total_layers > 0 else 0.0
|
|
2305
|
-
|
|
2306
|
-
# Acceptance gate: pass when epsilon-rule holds per family.
|
|
2307
|
-
passed = not bool(self.epsilon_violations)
|
|
2308
|
-
|
|
2309
|
-
# Generate violations for GuardOutcome
|
|
2310
|
-
violations = []
|
|
2311
|
-
warnings = []
|
|
2312
|
-
errors = []
|
|
2313
|
-
|
|
2314
|
-
# Create violations for each flagged layer
|
|
2315
|
-
for layer_info in result.get("per_layer", []):
|
|
2316
|
-
if layer_info.get("has_outlier", False):
|
|
2317
|
-
violations.append(
|
|
2318
|
-
{
|
|
2319
|
-
"type": "rmt_outlier",
|
|
2320
|
-
"severity": "warning" if self.correct else "error",
|
|
2321
|
-
"message": f"RMT outlier detected: ratio={layer_info.get('worst_ratio', 0.0):.2f}",
|
|
2322
|
-
"module_name": layer_info.get("module_name"),
|
|
2323
|
-
"ratio": layer_info.get("worst_ratio", 0.0),
|
|
2324
|
-
"threshold": (1.0 + self.deadband) * self.margin,
|
|
2325
|
-
"corrected": self.correct,
|
|
2326
|
-
}
|
|
2327
|
-
)
|
|
2328
|
-
|
|
2329
|
-
if flagged_rate > 0.3: # Warning threshold at 30%
|
|
2330
|
-
warnings.append(
|
|
2331
|
-
f"High RMT outlier rate: {flagged_layers}/{total_layers} layers flagged ({flagged_rate:.1%})"
|
|
2332
|
-
)
|
|
2333
|
-
|
|
2334
|
-
if flagged_rate > 0.7: # Escalate to warning for unusually high rates
|
|
2335
|
-
warnings.append(
|
|
2336
|
-
f"Excessive RMT outliers: {flagged_layers}/{total_layers} layers flagged"
|
|
2337
|
-
)
|
|
2338
|
-
|
|
2339
|
-
if self.epsilon_violations:
|
|
2340
|
-
passed = False
|
|
2341
|
-
for failure in self.epsilon_violations:
|
|
2342
|
-
errors.append(
|
|
2343
|
-
"RMT ε-rule violation: "
|
|
2344
|
-
f"{failure['family']} bare={failure['bare']} "
|
|
2345
|
-
f"guarded={failure['guarded']} allowed={failure['allowed']} "
|
|
2346
|
-
f"(ε={failure['epsilon']:.3f})"
|
|
2347
|
-
)
|
|
2348
2442
|
|
|
2443
|
+
stable = not self.epsilon_violations
|
|
2444
|
+
action = "continue" if stable else "abort"
|
|
2349
2445
|
finalize_time = time.time() - start_time
|
|
2350
2446
|
|
|
2351
|
-
|
|
2352
|
-
|
|
2353
|
-
"
|
|
2354
|
-
"
|
|
2355
|
-
"
|
|
2356
|
-
"
|
|
2357
|
-
"
|
|
2358
|
-
"
|
|
2359
|
-
|
|
2360
|
-
|
|
2361
|
-
"deadband_used": self.deadband,
|
|
2362
|
-
"margin_used": self.margin,
|
|
2363
|
-
"detection_threshold": (1.0 + self.deadband) * self.margin,
|
|
2364
|
-
"baseline_layers_captured": len(self.baseline_mp_stats)
|
|
2365
|
-
if self.baseline_mp_stats
|
|
2366
|
-
else 0,
|
|
2367
|
-
"finalize_time": finalize_time,
|
|
2368
|
-
"activation_required": bool(self._require_activation),
|
|
2369
|
-
"activation_ready": bool(self._activation_ready),
|
|
2370
|
-
"analysis_source": result.get("analysis_source"),
|
|
2371
|
-
"token_weight_total": result.get("token_weight_total"),
|
|
2372
|
-
"baseline_outliers_per_family": {
|
|
2373
|
-
k: int(v) for k, v in self.baseline_outliers_per_family.items()
|
|
2447
|
+
metrics: dict[str, Any] = {
|
|
2448
|
+
"prepared": True,
|
|
2449
|
+
"stable": stable,
|
|
2450
|
+
"edge_risk_by_family_base": dict(self.baseline_edge_risk_by_family),
|
|
2451
|
+
"edge_risk_by_family": dict(self.edge_risk_by_family),
|
|
2452
|
+
"epsilon_by_family": dict(self.epsilon_by_family),
|
|
2453
|
+
"epsilon_violations": list(self.epsilon_violations),
|
|
2454
|
+
"measurement_contract": {
|
|
2455
|
+
"estimator": self.estimator,
|
|
2456
|
+
"activation_sampling": self.activation_sampling,
|
|
2374
2457
|
},
|
|
2375
|
-
"
|
|
2376
|
-
k: int(v) for k, v in self.outliers_per_family.items()
|
|
2377
|
-
},
|
|
2378
|
-
"baseline_outliers_total": int(self.baseline_total_outliers),
|
|
2379
|
-
"outliers_total": int(self.outliers_total),
|
|
2380
|
-
"epsilon_by_family": {
|
|
2381
|
-
k: float(v) for k, v in self.epsilon_by_family.items()
|
|
2382
|
-
},
|
|
2383
|
-
"epsilon_default": float(self.epsilon_default),
|
|
2384
|
-
"epsilon_violations": self.epsilon_violations,
|
|
2458
|
+
"finalize_time": finalize_time,
|
|
2385
2459
|
}
|
|
2386
2460
|
|
|
2387
|
-
|
|
2388
|
-
|
|
2389
|
-
|
|
2390
|
-
passed=passed,
|
|
2391
|
-
flagged_rate=flagged_rate,
|
|
2392
|
-
finalize_time=finalize_time,
|
|
2393
|
-
)
|
|
2394
|
-
|
|
2395
|
-
# Return GuardOutcome if available, otherwise legacy dict
|
|
2396
|
-
# Env-gated tiny evidence dump for auditors
|
|
2397
|
-
try:
|
|
2398
|
-
payload = {
|
|
2399
|
-
"rmt": {
|
|
2400
|
-
"epsilon_by_family": {
|
|
2401
|
-
k: float(v) for k, v in self.epsilon_by_family.items()
|
|
2402
|
-
},
|
|
2403
|
-
"deadband": float(self.deadband),
|
|
2404
|
-
"margin": float(self.margin),
|
|
2405
|
-
"evaluated": True,
|
|
2406
|
-
}
|
|
2407
|
-
}
|
|
2408
|
-
maybe_dump_guard_evidence(".", payload)
|
|
2409
|
-
except Exception:
|
|
2410
|
-
pass
|
|
2411
|
-
|
|
2412
|
-
if HAS_GUARD_OUTCOME:
|
|
2413
|
-
# Add details to metrics since GuardOutcome doesn't have a details field
|
|
2414
|
-
final_metrics.update(
|
|
2461
|
+
violations: list[dict[str, Any]] = []
|
|
2462
|
+
for v in self.epsilon_violations:
|
|
2463
|
+
violations.append(
|
|
2415
2464
|
{
|
|
2416
|
-
"
|
|
2417
|
-
"
|
|
2418
|
-
"
|
|
2419
|
-
|
|
2420
|
-
|
|
2421
|
-
"
|
|
2422
|
-
"
|
|
2423
|
-
"
|
|
2424
|
-
|
|
2425
|
-
"deadband": self.deadband,
|
|
2426
|
-
"margin": self.margin,
|
|
2427
|
-
"correct": self.correct,
|
|
2428
|
-
"epsilon": self.epsilon_by_family.copy(),
|
|
2429
|
-
},
|
|
2430
|
-
"scope_suffixes": self.allowed_suffixes,
|
|
2465
|
+
"type": "epsilon_band",
|
|
2466
|
+
"severity": "error",
|
|
2467
|
+
"family": v.get("family"),
|
|
2468
|
+
"edge_base": v.get("edge_base"),
|
|
2469
|
+
"edge_cur": v.get("edge_cur"),
|
|
2470
|
+
"allowed": v.get("allowed"),
|
|
2471
|
+
"epsilon": v.get("epsilon"),
|
|
2472
|
+
"delta": v.get("delta"),
|
|
2473
|
+
"message": f"ε-band violation in {v.get('family')}",
|
|
2431
2474
|
}
|
|
2432
2475
|
)
|
|
2433
2476
|
|
|
2477
|
+
if HAS_GUARD_OUTCOME:
|
|
2434
2478
|
return GuardOutcome(
|
|
2435
2479
|
name=self.name,
|
|
2436
|
-
passed=
|
|
2437
|
-
action=
|
|
2480
|
+
passed=stable,
|
|
2481
|
+
action=action,
|
|
2438
2482
|
violations=violations,
|
|
2439
|
-
metrics=
|
|
2483
|
+
metrics=metrics,
|
|
2440
2484
|
)
|
|
2441
|
-
|
|
2442
|
-
|
|
2443
|
-
|
|
2444
|
-
|
|
2445
|
-
|
|
2446
|
-
|
|
2447
|
-
"violations": violations,
|
|
2448
|
-
"events": self.events,
|
|
2449
|
-
"details": {
|
|
2450
|
-
"guard_type": "rmt",
|
|
2451
|
-
"baseline_captured": self.baseline_mp_stats is not None,
|
|
2452
|
-
"baseline_count": len(self.baseline_mp_stats)
|
|
2453
|
-
if self.baseline_mp_stats
|
|
2454
|
-
else 0,
|
|
2455
|
-
"flagged_layer_names": [v["module_name"] for v in violations],
|
|
2456
|
-
"per_layer_results": result.get("per_layer", []),
|
|
2457
|
-
"policy": {
|
|
2458
|
-
"q": self.q,
|
|
2459
|
-
"deadband": self.deadband,
|
|
2460
|
-
"margin": self.margin,
|
|
2461
|
-
"correct": self.correct,
|
|
2462
|
-
"epsilon": self.epsilon_by_family.copy(),
|
|
2463
|
-
},
|
|
2464
|
-
"scope_suffixes": self.allowed_suffixes,
|
|
2465
|
-
},
|
|
2466
|
-
}
|
|
2485
|
+
return {
|
|
2486
|
+
"passed": stable,
|
|
2487
|
+
"action": action,
|
|
2488
|
+
"metrics": metrics,
|
|
2489
|
+
"violations": violations,
|
|
2490
|
+
}
|
|
2467
2491
|
|
|
2468
2492
|
def policy(self) -> RMTPolicyDict:
|
|
2469
2493
|
"""
|
|
@@ -2477,7 +2501,8 @@ class RMTGuard(Guard):
|
|
|
2477
2501
|
deadband=self.deadband,
|
|
2478
2502
|
margin=self.margin,
|
|
2479
2503
|
correct=self.correct,
|
|
2480
|
-
|
|
2504
|
+
epsilon_default=float(self.epsilon_default),
|
|
2505
|
+
epsilon_by_family=self.epsilon_by_family.copy(),
|
|
2481
2506
|
)
|
|
2482
2507
|
|
|
2483
2508
|
|
|
@@ -2494,28 +2519,31 @@ def get_rmt_policy(name: str = "balanced") -> RMTPolicyDict:
|
|
|
2494
2519
|
Returns:
|
|
2495
2520
|
RMTPolicyDict configuration
|
|
2496
2521
|
"""
|
|
2497
|
-
# Per-family ε values match tiers.yaml
|
|
2522
|
+
# Per-family ε values match runtime tiers.yaml.
|
|
2498
2523
|
policies = {
|
|
2499
2524
|
"conservative": RMTPolicyDict(
|
|
2500
2525
|
q="auto",
|
|
2501
2526
|
deadband=0.05,
|
|
2502
2527
|
margin=1.3,
|
|
2503
2528
|
correct=True,
|
|
2504
|
-
|
|
2529
|
+
epsilon_default=0.06,
|
|
2530
|
+
epsilon_by_family={"ffn": 0.06, "attn": 0.05, "embed": 0.07, "other": 0.07},
|
|
2505
2531
|
),
|
|
2506
2532
|
"balanced": RMTPolicyDict(
|
|
2507
2533
|
q="auto",
|
|
2508
2534
|
deadband=0.10,
|
|
2509
2535
|
margin=1.5,
|
|
2510
2536
|
correct=True,
|
|
2511
|
-
|
|
2537
|
+
epsilon_default=0.10,
|
|
2538
|
+
epsilon_by_family={"ffn": 0.10, "attn": 0.08, "embed": 0.12, "other": 0.12},
|
|
2512
2539
|
),
|
|
2513
2540
|
"aggressive": RMTPolicyDict(
|
|
2514
2541
|
q="auto",
|
|
2515
2542
|
deadband=0.15,
|
|
2516
2543
|
margin=1.8,
|
|
2517
2544
|
correct=True,
|
|
2518
|
-
|
|
2545
|
+
epsilon_default=0.15,
|
|
2546
|
+
epsilon_by_family={"ffn": 0.15, "attn": 0.15, "embed": 0.15, "other": 0.15},
|
|
2519
2547
|
),
|
|
2520
2548
|
}
|
|
2521
2549
|
|
|
@@ -2537,7 +2565,9 @@ def create_custom_rmt_policy(
|
|
|
2537
2565
|
deadband: float = 0.10,
|
|
2538
2566
|
margin: float = 1.5,
|
|
2539
2567
|
correct: bool = True,
|
|
2540
|
-
|
|
2568
|
+
*,
|
|
2569
|
+
epsilon_default: float = 0.1,
|
|
2570
|
+
epsilon_by_family: dict[str, float] | None = None,
|
|
2541
2571
|
) -> RMTPolicyDict:
|
|
2542
2572
|
"""
|
|
2543
2573
|
Create a custom RMT policy.
|
|
@@ -2578,10 +2608,61 @@ def create_custom_rmt_policy(
|
|
|
2578
2608
|
details={"param": "margin", "value": margin},
|
|
2579
2609
|
)
|
|
2580
2610
|
|
|
2611
|
+
from invarlock.core.exceptions import ValidationError
|
|
2612
|
+
|
|
2613
|
+
try:
|
|
2614
|
+
eps_default_val = float(epsilon_default)
|
|
2615
|
+
except (TypeError, ValueError) as exc:
|
|
2616
|
+
raise ValidationError(
|
|
2617
|
+
code="E501",
|
|
2618
|
+
message="POLICY-PARAM-INVALID",
|
|
2619
|
+
details={"param": "epsilon_default", "value": epsilon_default},
|
|
2620
|
+
) from exc
|
|
2621
|
+
if not (math.isfinite(eps_default_val) and eps_default_val >= 0.0):
|
|
2622
|
+
raise ValidationError(
|
|
2623
|
+
code="E501",
|
|
2624
|
+
message="POLICY-PARAM-INVALID",
|
|
2625
|
+
details={"param": "epsilon_default", "value": epsilon_default},
|
|
2626
|
+
)
|
|
2627
|
+
|
|
2628
|
+
eps_by_family: dict[str, float] = {}
|
|
2629
|
+
if epsilon_by_family is not None:
|
|
2630
|
+
if not isinstance(epsilon_by_family, dict):
|
|
2631
|
+
raise ValidationError(
|
|
2632
|
+
code="E501",
|
|
2633
|
+
message="POLICY-PARAM-INVALID",
|
|
2634
|
+
details={"param": "epsilon_by_family", "value": epsilon_by_family},
|
|
2635
|
+
)
|
|
2636
|
+
for family, value in epsilon_by_family.items():
|
|
2637
|
+
try:
|
|
2638
|
+
eps_val = float(value)
|
|
2639
|
+
except (TypeError, ValueError) as exc:
|
|
2640
|
+
raise ValidationError(
|
|
2641
|
+
code="E501",
|
|
2642
|
+
message="POLICY-PARAM-INVALID",
|
|
2643
|
+
details={
|
|
2644
|
+
"param": "epsilon_by_family",
|
|
2645
|
+
"family": str(family),
|
|
2646
|
+
"value": value,
|
|
2647
|
+
},
|
|
2648
|
+
) from exc
|
|
2649
|
+
if not (math.isfinite(eps_val) and eps_val >= 0.0):
|
|
2650
|
+
raise ValidationError(
|
|
2651
|
+
code="E501",
|
|
2652
|
+
message="POLICY-PARAM-INVALID",
|
|
2653
|
+
details={
|
|
2654
|
+
"param": "epsilon_by_family",
|
|
2655
|
+
"family": str(family),
|
|
2656
|
+
"value": value,
|
|
2657
|
+
},
|
|
2658
|
+
)
|
|
2659
|
+
eps_by_family[str(family)] = eps_val
|
|
2660
|
+
|
|
2581
2661
|
return RMTPolicyDict(
|
|
2582
2662
|
q=q,
|
|
2583
2663
|
deadband=deadband,
|
|
2584
2664
|
margin=margin,
|
|
2585
2665
|
correct=correct,
|
|
2586
|
-
|
|
2666
|
+
epsilon_default=eps_default_val,
|
|
2667
|
+
epsilon_by_family=eps_by_family,
|
|
2587
2668
|
)
|