invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132):
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,805 @@
1
+ """
2
+ InvarLock Guards - Default Policy Presets
3
+ ====================================
4
+
5
+ Default policy configurations for various guard types and use cases.
6
+ Provides sensible defaults for different model architectures and safety requirements.
7
+
8
+ Policy values are loaded from tiers.yaml (the calibrated source of truth) with
9
+ hardcoded fallbacks for robustness. Use check_policy_drift() to verify that
10
+ code and config are synchronized.
11
+ """
12
+
13
+ import math
14
+ from typing import Any, Literal
15
+
16
+ try: # Python 3.12+
17
+ from typing import NotRequired, TypedDict
18
+ except ImportError: # Legacy fallback
19
+ from typing import NotRequired
20
+
21
+ from typing_extensions import TypedDict
22
+
23
+ from invarlock.core.exceptions import (
24
+ GuardError,
25
+ PolicyViolationError,
26
+ ValidationError,
27
+ )
28
+
29
+ from .rmt import RMTPolicyDict
30
+ from .spectral import SpectralPolicy
31
+ from .tier_config import check_drift as check_tier_drift
32
+ from .tier_config import get_tier_guard_config
33
+
34
# === Spectral Guard Policies ===
#
# Hardcoded fallback presets for the spectral guard. get_spectral_policy()
# starts from a copy of one of these; for "conservative", "balanced" and
# "aggressive" it may then overlay sigma_quantile, deadband, scope, max_caps,
# family_caps and multiple_testing from tiers.yaml. SPECTRAL_ATTN_AWARE is
# never overlaid. Keep these values in sync with tiers.yaml (see
# check_policy_drift()).
# NOTE(review): the precise semantics of sigma_quantile and of the
# multiple_testing fields live in the spectral guard (spectral.py);
# create_custom_spectral_policy() only requires sigma_quantile in [0.0, 1.0].

# Conservative policy - tight control for production use
SPECTRAL_CONSERVATIVE: SpectralPolicy = {
    "sigma_quantile": 0.90,  # Allow only 90% of baseline spectral norm
    "deadband": 0.05,  # 5% deadband - strict threshold
    "scope": "ffn",  # FFN layers only (safest)
    "correction_enabled": True,
    "max_caps": 3,  # Lowest max_caps of the presets
    "multiple_testing": {"method": "bonferroni", "alpha": 0.02, "m": 4},
}

# Balanced policy - good for most use cases
SPECTRAL_BALANCED: SpectralPolicy = {
    "sigma_quantile": 0.95,  # Allow 95% of baseline spectral norm
    "deadband": 0.10,  # 10% deadband - reasonable tolerance
    "scope": "ffn",  # FFN layers only
    "correction_enabled": False,
    "max_caps": 5,
    "multiple_testing": {"method": "bh", "alpha": 0.05, "m": 4},
}

# Aggressive policy - for research/experimental use
SPECTRAL_AGGRESSIVE: SpectralPolicy = {
    "sigma_quantile": 0.98,  # Allow 98% of baseline spectral norm
    "deadband": 0.15,  # 15% deadband - more permissive
    "scope": "all",  # All layers including attention
    "correction_enabled": True,
    "max_caps": 8,  # Highest max_caps of the presets
    "multiple_testing": {"method": "bh", "alpha": 0.1, "m": 4},
}

# Attention-aware policy - includes attention projections.
# Same numeric settings as SPECTRAL_BALANCED, but scoped to attention only.
SPECTRAL_ATTN_AWARE: SpectralPolicy = {
    "sigma_quantile": 0.95,  # Standard scaling factor
    "deadband": 0.10,  # Standard deadband
    "scope": "attn",  # Attention layers only
    "correction_enabled": False,
    "max_caps": 5,
    "multiple_testing": {"method": "bh", "alpha": 0.05, "m": 4},
}
75
+
76
# === RMT Guard Policies ===
#
# Hardcoded fallback presets for the RMT guard. get_rmt_policy() starts from a
# copy of one of these and may overlay "deadband", "margin" and "epsilon" from
# tiers.yaml. The "epsilon" value is read from that section's
# "epsilon_by_family" key. "q" and "correct" are never overlaid from tiers.yaml.
# "epsilon" holds one value per module family: attn, ffn, embed and other.
# NOTE(review): how each per-family epsilon is applied is defined in rmt.py.

# Conservative RMT policy - tight control for production use
RMT_CONSERVATIVE: RMTPolicyDict = {
    "q": "auto",  # Auto-derive MP aspect ratio from weight shapes
    "deadband": 0.05,  # 5% deadband - strict threshold
    "margin": 1.3,  # Lower margin for conservative detection
    "correct": True,  # Enable automatic correction
    "epsilon": {"attn": 0.05, "ffn": 0.06, "embed": 0.07, "other": 0.07},
}

# Balanced RMT policy - good for most use cases
RMT_BALANCED: RMTPolicyDict = {
    "q": "auto",  # Auto-derive MP aspect ratio from weight shapes
    "deadband": 0.10,  # 10% deadband - reasonable tolerance
    "margin": 1.5,  # Standard margin for outlier detection
    "correct": False,  # Monitor-only by default
    "epsilon": {"attn": 0.08, "ffn": 0.10, "embed": 0.12, "other": 0.12},
}

# Aggressive RMT policy - for research/experimental use
RMT_AGGRESSIVE: RMTPolicyDict = {
    "q": "auto",  # Auto-derive MP aspect ratio from weight shapes
    "deadband": 0.15,  # 15% deadband - more permissive
    "margin": 1.8,  # Higher margin allows more deviation
    "correct": True,  # Enable automatic correction
    "epsilon": {"attn": 0.15, "ffn": 0.15, "embed": 0.15, "other": 0.15},
}
104
+
105
# === Variance Guard Policies ===


class VariancePolicyRequired(TypedDict):
    """TypedDict for variance guard policy configuration.

    Declares the keys every variance policy must carry (a TypedDict is
    ``total=True`` by default). Field descriptions mirror the argument
    documentation of create_custom_variance_policy(); the variance guard
    (variance.py) is the authority on how each value is consumed.
    """

    # Minimum primary-metric improvement to enable VE
    min_gain: float
    # Maximum calibration samples
    max_calib: int
    # Module scope the guard acts on
    scope: Literal["ffn", "attn", "both"]
    # Scaling factor limits as (min, max)
    clamp: tuple[float, float]
    # Tolerance margin
    deadband: float
    # Random seed for deterministic evaluation
    seed: int
    # Gate mode
    mode: Literal["delta", "ci"]
    # Minimum relative gain required under CI mode
    min_rel_gain: float
    # Confidence interval significance level
    alpha: float
120
+
121
+
122
class VariancePolicyDict(VariancePolicyRequired, total=False):
    """Extended variance policy allowing optional calibration overrides.

    ``total=False`` applies only to the keys declared in this class: every key
    inherited from VariancePolicyRequired stays required, while all keys below
    are optional. The explicit ``NotRequired[...]`` wrappers are therefore
    redundant but harmless.
    """

    # Calibration settings; the presets use "windows", "min_coverage" and "seed"
    calibration: dict[str, Any]
    # Optional tuning knobs consumed by the variance guard. get_variance_policy()
    # may overlay min_effect_lognll, min_abs_adjust, max_scale_step and
    # topk_backstop from tiers.yaml.
    tie_breaker_deadband: NotRequired[float]
    min_effect_lognll: NotRequired[float]
    min_abs_adjust: NotRequired[float]
    max_scale_step: NotRequired[float]
    topk_backstop: NotRequired[int]
    predictive_gate: NotRequired[bool]
    monitor_only: NotRequired[bool]
    target_modules: NotRequired[list[str]]
    # Module-path pattern(s); the presets use a glob such as
    # "transformer.h.*.mlp.c_proj" (presumably matched against module names -
    # confirm in variance.py)
    tap: NotRequired[str | list[str]]
135
+
136
+
137
# Hardcoded fallback presets for the variance guard. get_variance_policy()
# starts from a copy of one of these and may overlay deadband,
# min_effect_lognll, min_abs_adjust, max_scale_step and topk_backstop from
# tiers.yaml. Keys omitted from a preset (see VARIANCE_AGGRESSIVE) are left
# to whatever defaults the variance guard applies - confirm in variance.py.

# Conservative variance policy - strict A/B gate for production use
VARIANCE_CONSERVATIVE: VariancePolicyDict = {
    "min_gain": 0.02,
    "max_calib": 160,
    "scope": "ffn",
    "clamp": (0.85, 1.12),
    "deadband": 0.03,
    "seed": 42,
    "mode": "ci",
    "min_rel_gain": 0.002,
    "alpha": 0.05,
    "tie_breaker_deadband": 0.005,
    "min_effect_lognll": 0.0018,
    "min_abs_adjust": 0.02,
    "max_scale_step": 0.015,
    "topk_backstop": 0,
    "predictive_gate": True,
    # Looks like a GPT-2-style module path glob - TODO confirm matching rules
    "tap": "transformer.h.*.mlp.c_proj",
    "calibration": {
        "windows": 16,
        "min_coverage": 12,
        "seed": 42,
    },
}

# Balanced variance policy - good for most use cases
VARIANCE_BALANCED: VariancePolicyDict = {
    "min_gain": 0.0,
    "max_calib": 160,
    "scope": "ffn",
    "clamp": (0.85, 1.12),
    "deadband": 0.02,
    "seed": 123,
    "mode": "ci",
    "min_rel_gain": 0.001,
    "alpha": 0.05,
    "tie_breaker_deadband": 0.001,
    "min_effect_lognll": 0.0009,
    "min_abs_adjust": 0.012,
    "max_scale_step": 0.03,
    "topk_backstop": 1,
    "predictive_gate": True,
    "tap": "transformer.h.*.mlp.c_proj",
    "calibration": {
        "windows": 12,
        "min_coverage": 10,
        "seed": 123,
    },
}

# Aggressive variance policy - for research/experimental use.
# Omits min_abs_adjust, max_scale_step, topk_backstop, predictive_gate and tap.
VARIANCE_AGGRESSIVE: VariancePolicyDict = {
    "min_gain": 0.0,
    "max_calib": 240,
    "scope": "both",
    "clamp": (0.3, 3.0),  # Widest clamp range of the presets
    "deadband": 0.12,
    "seed": 456,
    "mode": "ci",
    # NOTE(review): higher than the conservative preset's 0.002 - confirm intended
    "min_rel_gain": 0.0025,
    "alpha": 0.05,
    "tie_breaker_deadband": 0.0005,
    "min_effect_lognll": 0.0005,
    "calibration": {
        "windows": 6,
        "min_coverage": 4,
        "seed": 456,
    },
}
206
+
207
# === Policy Collections ===
#
# Name -> preset lookup tables. The get_*_policy() helpers validate names
# against these tables and report their keys as "available" in the E502 error.
# The tables hold references to the preset dicts above, not copies.

DEFAULT_SPECTRAL_POLICIES: dict[str, SpectralPolicy] = {
    "conservative": SPECTRAL_CONSERVATIVE,
    "balanced": SPECTRAL_BALANCED,
    "aggressive": SPECTRAL_AGGRESSIVE,
    "attn_aware": SPECTRAL_ATTN_AWARE,
}

# === RMT Policy Collections ===

DEFAULT_RMT_POLICIES: dict[str, RMTPolicyDict] = {
    "conservative": RMT_CONSERVATIVE,
    "balanced": RMT_BALANCED,
    "aggressive": RMT_AGGRESSIVE,
}

# === Variance Policy Collections ===

DEFAULT_VARIANCE_POLICIES: dict[str, VariancePolicyDict] = {
    "conservative": VARIANCE_CONSERVATIVE,
    "balanced": VARIANCE_BALANCED,
    "aggressive": VARIANCE_AGGRESSIVE,
}
231
+
232
+ # === Utility Functions ===
233
+
234
+
235
def get_spectral_policy(
    name: str = "balanced", *, use_yaml: bool = True
) -> SpectralPolicy:
    """
    Get a spectral policy by name.

    Loads values from tiers.yaml (calibrated source of truth) when available,
    falling back to hardcoded defaults for robustness.

    The returned dict is an independent deep copy. Mutating it, including
    nested values such as ``policy["multiple_testing"]``, never alters the
    module-level presets or the tier configuration it was overlaid from.

    Args:
        name: Policy name ("conservative", "balanced", "aggressive", "attn_aware")
        use_yaml: If True, attempt to load calibrated values from tiers.yaml.
            Only "conservative", "balanced" and "aggressive" are overlaid;
            "attn_aware" always uses its hardcoded values.

    Returns:
        SpectralPolicy configuration with calibrated thresholds

    Raises:
        GuardError(E502): If policy name not found
    """
    import copy

    if name not in DEFAULT_SPECTRAL_POLICIES:
        available = list(DEFAULT_SPECTRAL_POLICIES.keys())
        raise GuardError(
            code="E502",
            message="POLICY-NOT-FOUND",
            details={"name": name, "available": available},
        )

    # Deep copy: a shallow .copy() would share the nested "multiple_testing"
    # dict with the module-level preset, so a caller tweaking it would silently
    # change the defaults seen by every later call.
    policy = copy.deepcopy(DEFAULT_SPECTRAL_POLICIES[name])

    # Overlay calibrated values from tiers.yaml if available
    if use_yaml and name in ("balanced", "conservative", "aggressive"):
        try:
            tier_config = get_tier_guard_config(name, "spectral_guard")  # type: ignore[arg-type]
            if tier_config:
                # Build the overlay on a scratch copy and commit it only when
                # complete. A failure part-way then yields pure defaults, not a
                # mix of YAML and hardcoded values. Values are deep-copied so
                # the result never aliases the (possibly cached) tier config.
                staged = policy.copy()
                if "sigma_quantile" in tier_config:
                    staged["sigma_quantile"] = copy.deepcopy(
                        tier_config["sigma_quantile"]
                    )
                if "deadband" in tier_config:
                    staged["deadband"] = copy.deepcopy(tier_config["deadband"])
                if "scope" in tier_config:
                    staged["scope"] = copy.deepcopy(tier_config["scope"])
                if "max_caps" in tier_config:
                    staged["max_caps"] = copy.deepcopy(tier_config["max_caps"])
                if "family_caps" in tier_config:
                    staged["family_caps"] = copy.deepcopy(tier_config["family_caps"])
                if "multiple_testing" in tier_config:
                    staged["multiple_testing"] = copy.deepcopy(
                        tier_config["multiple_testing"]
                    )
                policy = staged
        except Exception:
            # Deliberate best-effort: a missing/unreadable/malformed tiers.yaml
            # keeps the hardcoded defaults.
            pass

    return policy
288
+
289
+
290
def create_custom_spectral_policy(
    sigma_quantile: float = 0.95,
    deadband: float = 0.10,
    scope: str = "ffn",
) -> SpectralPolicy:
    """
    Build a spectral policy from explicit parameters after range-checking them.

    Args:
        sigma_quantile: Baseline spectral percentile; must lie in [0.0, 1.0]
        deadband: Tolerance margin; must lie in [0.0, 0.5]
        scope: Module scope; one of "ffn", "attn" or "all"

    Returns:
        SpectralPolicy holding only sigma_quantile, deadband and scope

    Raises:
        ValidationError(E501): For the first parameter, checked in the order
            listed above, that falls outside its allowed range
    """
    # Each predicate runs lazily and in order, so only the first offending
    # parameter is evaluated and reported.
    validators = (
        ("sigma_quantile", sigma_quantile, lambda v: 0.0 <= v <= 1.0),
        ("deadband", deadband, lambda v: 0.0 <= v <= 0.5),
        ("scope", scope, lambda v: v in ("ffn", "attn", "all")),
    )
    for param, value, is_valid in validators:
        if not is_valid(value):
            raise ValidationError(
                code="E501",
                message="POLICY-PARAM-INVALID",
                details={"param": param, "value": value},
            )

    return SpectralPolicy(
        sigma_quantile=sigma_quantile,
        deadband=deadband,
        scope=scope,
    )
335
+
336
+
337
def get_policy_for_model_size(param_count: int) -> SpectralPolicy:
    """
    Recommend a spectral policy preset from a model's parameter count.

    Fewer than 100M parameters maps to "aggressive", fewer than 1B to
    "balanced", and anything larger to "conservative".

    Args:
        param_count: Number of model parameters

    Returns:
        Recommended SpectralPolicy, as produced by get_spectral_policy()
    """
    # (exclusive upper bound, preset name) pairs, checked smallest first.
    size_bands = (
        (100_000_000, "aggressive"),
        (1_000_000_000, "balanced"),
    )
    for upper_bound, preset in size_bands:
        if param_count < upper_bound:
            return get_spectral_policy(preset)
    return get_spectral_policy("conservative")
353
+
354
+
355
def get_rmt_policy(name: str = "balanced", *, use_yaml: bool = True) -> RMTPolicyDict:
    """
    Get a RMT policy by name.

    Loads values from tiers.yaml (calibrated source of truth) when available,
    falling back to hardcoded defaults for robustness.

    The returned dict is an independent deep copy. Mutating it, including the
    nested ``epsilon`` dict, never alters the module-level presets or the tier
    configuration it was overlaid from.

    Args:
        name: Policy name ("conservative", "balanced", "aggressive")
        use_yaml: If True, attempt to overlay "deadband", "margin" and
            "epsilon_by_family" (stored as "epsilon") from tiers.yaml

    Returns:
        RMTPolicyDict configuration with calibrated epsilon values

    Raises:
        GuardError(E502): If policy name not found
    """
    import copy

    if name not in DEFAULT_RMT_POLICIES:
        available = list(DEFAULT_RMT_POLICIES.keys())
        raise GuardError(
            code="E502",
            message="POLICY-NOT-FOUND",
            details={"name": name, "available": available},
        )

    # Deep copy: a shallow .copy() would share the nested "epsilon" dict with
    # the module-level preset, letting callers corrupt the defaults.
    policy = copy.deepcopy(DEFAULT_RMT_POLICIES[name])

    # Overlay calibrated values from tiers.yaml if available
    if use_yaml and name in ("balanced", "conservative", "aggressive"):
        try:
            tier_config = get_tier_guard_config(name, "rmt_guard")  # type: ignore[arg-type]
            if tier_config:
                # Commit the overlay only once it is complete, so a failure
                # part-way never leaves a half-YAML policy behind.
                staged = policy.copy()
                if "deadband" in tier_config:
                    staged["deadband"] = copy.deepcopy(tier_config["deadband"])
                if "margin" in tier_config:
                    staged["margin"] = copy.deepcopy(tier_config["margin"])
                # tiers.yaml names the per-family thresholds "epsilon_by_family"
                if "epsilon_by_family" in tier_config:
                    staged["epsilon"] = copy.deepcopy(
                        tier_config["epsilon_by_family"]
                    )
                policy = staged
        except Exception:
            # Deliberate best-effort: fall back to hardcoded values on any error
            pass

    return policy
401
+
402
+
403
def create_custom_rmt_policy(
    q: float | Literal["auto"] = "auto",
    deadband: float = 0.10,
    margin: float = 1.5,
    correct: bool = True,
) -> RMTPolicyDict:
    """
    Create a custom RMT policy.

    Args:
        q: MP aspect ratio. Either "auto" (auto-derived) or a number in
            [0.1, 10.0]. Ints are range-checked as well as floats.
        deadband: Tolerance margin (0.0-0.5)
        margin: RMT threshold ratio (>= 1.0)
        correct: Enable automatic correction

    Returns:
        Custom RMTPolicyDict configuration

    Raises:
        ValidationError(E501): If parameters are out of valid ranges. This
            includes a q that is a string other than "auto", or that cannot
            be interpreted as a number.
    """
    # Range-check every numeric q, not only `float` instances: an int such as
    # q=20 is a valid "float" for typing purposes but still out of range, and
    # a typo such as q="atuo" must not slip through as a custom value.
    if isinstance(q, str):
        q_valid = q == "auto"
    else:
        try:
            q_valid = 0.1 <= float(q) <= 10.0  # NaN compares False -> rejected
        except (TypeError, ValueError, OverflowError):
            q_valid = False
    if not q_valid:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "q", "value": q},
        )

    if not 0.0 <= deadband <= 0.5:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "deadband", "value": deadband},
        )

    if not margin >= 1.0:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "margin", "value": margin},
        )

    return RMTPolicyDict(q=q, deadband=deadband, margin=margin, correct=correct)
446
+
447
+
448
def get_rmt_policy_for_model_size(param_count: int) -> RMTPolicyDict:
    """
    Recommend an RMT policy preset from a model's parameter count.

    Models under 100M parameters get "aggressive", models under 1B get
    "balanced", and larger models get "conservative".

    Args:
        param_count: Number of model parameters

    Returns:
        Recommended RMTPolicyDict, as produced by get_rmt_policy()
    """
    if param_count < 100_000_000:  # < 100M params
        preset = "aggressive"
    elif param_count < 1_000_000_000:  # < 1B params
        preset = "balanced"
    else:  # >= 1B params
        preset = "conservative"
    return get_rmt_policy(preset)
464
+
465
+
466
def get_variance_policy(
    name: str = "balanced", *, use_yaml: bool = True
) -> VariancePolicyDict:
    """
    Get a variance policy by name.

    Loads values from tiers.yaml (calibrated source of truth) when available,
    falling back to hardcoded defaults for robustness.

    The returned dict is an independent deep copy. Mutating it, including the
    nested ``calibration`` dict or a list-valued ``tap``, never alters the
    module-level presets or the tier configuration it was overlaid from.

    Args:
        name: Policy name ("conservative", "balanced", "aggressive")
        use_yaml: If True, attempt to overlay deadband, min_effect_lognll,
            min_abs_adjust, max_scale_step and topk_backstop from tiers.yaml

    Returns:
        VariancePolicyDict configuration with calibrated thresholds

    Raises:
        GuardError(E502): If policy name not found
    """
    import copy

    if name not in DEFAULT_VARIANCE_POLICIES:
        available = list(DEFAULT_VARIANCE_POLICIES.keys())
        raise GuardError(
            code="E502",
            message="POLICY-NOT-FOUND",
            details={"name": name, "available": available},
        )

    # Deep copy: a shallow .copy() would share the nested "calibration" dict
    # with the module-level preset, letting callers corrupt the defaults.
    policy = copy.deepcopy(DEFAULT_VARIANCE_POLICIES[name])

    # Overlay calibrated values from tiers.yaml if available
    if use_yaml and name in ("balanced", "conservative", "aggressive"):
        try:
            tier_config = get_tier_guard_config(name, "variance_guard")  # type: ignore[arg-type]
            if tier_config:
                # Commit the overlay only once it is complete, so a failure
                # part-way never leaves a half-YAML policy behind.
                staged = policy.copy()
                if "deadband" in tier_config:
                    staged["deadband"] = copy.deepcopy(tier_config["deadband"])
                if "min_effect_lognll" in tier_config:
                    staged["min_effect_lognll"] = copy.deepcopy(
                        tier_config["min_effect_lognll"]
                    )
                if "min_abs_adjust" in tier_config:
                    staged["min_abs_adjust"] = copy.deepcopy(
                        tier_config["min_abs_adjust"]
                    )
                if "max_scale_step" in tier_config:
                    staged["max_scale_step"] = copy.deepcopy(
                        tier_config["max_scale_step"]
                    )
                if "topk_backstop" in tier_config:
                    staged["topk_backstop"] = copy.deepcopy(
                        tier_config["topk_backstop"]
                    )
                # NOTE: "predictive_one_sided" in tiers.yaml is intentionally not
                # mapped onto this policy; it is reportedly handled inside the
                # variance guard itself - confirm in variance.py.
                policy = staged
        except Exception:
            # Deliberate best-effort: fall back to hardcoded values on any error
            pass

    return policy
520
+
521
+
522
def create_custom_variance_policy(
    min_gain: float = 0.30,
    max_calib: int = 200,
    scope: Literal["ffn", "attn", "both"] = "both",
    clamp: tuple[float, float] = (0.5, 2.0),
    deadband: float = 0.10,
    seed: int = 123,
    mode: Literal["delta", "ci"] = "ci",
    min_rel_gain: float = 0.005,
    alpha: float = 0.05,
) -> VariancePolicyDict:
    """
    Create a custom variance policy.

    Args:
        min_gain: Minimum primary-metric improvement to enable VE (0.0-1.0)
        max_calib: Maximum calibration samples (50-1000)
        scope: Module scope ("ffn", "attn", "both")
        clamp: Scaling factor limits (min, max) where 0.1 <= min < max <= 5.0.
            Any 2-item sequence is accepted; it is stored as a tuple.
        deadband: Tolerance margin (0.0-0.5)
        seed: Random seed for deterministic evaluation
        mode: Gate mode ("ci" or "delta")
        min_rel_gain: Minimum relative gain required under CI mode (0.0 <= x < 1.0)
        alpha: Confidence interval significance level (0.0 < x < 1.0)

    Returns:
        Custom VariancePolicyDict configuration containing the required keys only

    Raises:
        ValidationError(E501): If parameters are out of valid ranges, including
            a clamp that is not a pair of comparable numbers
    """
    if not 0.0 <= min_gain <= 1.0:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "min_gain", "value": min_gain},
        )

    if not 50 <= max_calib <= 1000:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "max_calib", "value": max_calib},
        )

    if scope not in ["ffn", "attn", "both"]:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "scope", "value": scope},
        )

    # A clamp that is not a 2-item sequence, or has non-comparable bounds,
    # previously escaped as a raw ValueError/TypeError from the unpack or the
    # comparison. Report it as the documented validation error instead.
    try:
        clamp_min, clamp_max = clamp
        clamp_ok = bool(0.1 <= clamp_min < clamp_max <= 5.0)
    except (TypeError, ValueError):
        clamp_ok = False
    if not clamp_ok:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "clamp", "value": clamp},
        )

    if not 0.0 <= deadband <= 0.5:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "deadband", "value": deadband},
        )

    if mode not in {"delta", "ci"}:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "mode", "value": mode},
        )

    if not 0.0 <= min_rel_gain < 1.0:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "min_rel_gain", "value": min_rel_gain},
        )

    if not 0.0 < alpha < 1.0:
        raise ValidationError(
            code="E501",
            message="POLICY-PARAM-INVALID",
            details={"param": "alpha", "value": alpha},
        )

    return VariancePolicyDict(
        min_gain=min_gain,
        max_calib=max_calib,
        scope=scope,
        # Normalize to the declared tuple type (callers may pass a list).
        clamp=(clamp_min, clamp_max),
        deadband=deadband,
        seed=seed,
        mode=mode,
        min_rel_gain=min_rel_gain,
        alpha=alpha,
    )
621
+
622
+
623
def get_variance_policy_for_model_size(param_count: int) -> VariancePolicyDict:
    """
    Recommend a variance policy preset from a model's parameter count.

    Under 100M parameters -> "aggressive"; under 1B -> "balanced";
    1B or more -> "conservative".

    Args:
        param_count: Number of model parameters

    Returns:
        Recommended VariancePolicyDict, as produced by get_variance_policy()
    """
    if not param_count < 100_000_000:
        preset = "balanced" if param_count < 1_000_000_000 else "conservative"
        return get_variance_policy(preset)
    return get_variance_policy("aggressive")
639
+
640
+
641
# === Validation Gate Presets ===
#
# Threshold sets consumed by enforce_validation_gate(), which reads:
#   - "max_capping_rate": compared with caps_applied / total_layers
#   - "max_ppl_degradation": compared with primary_metric_ratio - 1.0
#   - "require_branch_balance": only a literal boolean True enables the check
# get_validation_gate() returns a shallow copy. Every value here is a scalar,
# so that copy is fully independent of these presets.

VALIDATION_GATE_STRICT: dict[str, Any] = {
    "max_capping_rate": 0.3,  # Max 30% of layers can be capped
    "max_ppl_degradation": 0.01,  # Max 1% primary-metric degradation (ppl-like)
    "require_branch_balance": True,
}

VALIDATION_GATE_STANDARD: dict[str, Any] = {
    "max_capping_rate": 0.5,  # Max 50% of layers can be capped
    "max_ppl_degradation": 0.02,  # Max 2% primary-metric degradation (ppl-like)
    "require_branch_balance": True,
}

VALIDATION_GATE_PERMISSIVE: dict[str, Any] = {
    "max_capping_rate": 0.7,  # Max 70% of layers can be capped
    "max_ppl_degradation": 0.05,  # Max 5% primary-metric degradation (ppl-like)
    "require_branch_balance": False,
}

# Name -> preset lookup used by get_validation_gate()
DEFAULT_VALIDATION_GATES: dict[str, dict[str, Any]] = {
    "strict": VALIDATION_GATE_STRICT,
    "standard": VALIDATION_GATE_STANDARD,
    "permissive": VALIDATION_GATE_PERMISSIVE,
}
666
+
667
+
668
def get_validation_gate(name: str = "standard") -> dict[str, Any]:
    """
    Return a fresh copy of a named validation-gate preset.

    Args:
        name: Preset name ("strict", "standard" or "permissive")

    Returns:
        A new dict with the preset's thresholds. It is a shallow copy, but the
        preset values are scalars, so mutating it never touches the preset.

    Raises:
        GuardError(E502): If no preset with that name exists
    """
    preset = DEFAULT_VALIDATION_GATES.get(name)
    if preset is None:
        raise GuardError(
            code="E502",
            message="POLICY-NOT-FOUND",
            details={"name": name, "available": list(DEFAULT_VALIDATION_GATES)},
        )
    return dict(preset)
687
+
688
+
689
+ def enforce_validation_gate(metrics: dict[str, Any], gate: dict[str, Any]) -> None:
690
+ """Enforce validation gate thresholds.
691
+
692
+ Raises PolicyViolationError(E503) with a 'violations' list in details
693
+ when one or more constraints are exceeded.
694
+ """
695
+ violations: list[dict[str, Any]] = []
696
+
697
+ try:
698
+ caps = float(metrics.get("caps_applied", 0))
699
+ total = float(metrics.get("total_layers", 0))
700
+ if total > 0:
701
+ rate = caps / total
702
+ limit = float(gate.get("max_capping_rate", 1.0))
703
+ if rate > limit:
704
+ violations.append(
705
+ {
706
+ "type": "capping_rate",
707
+ "actual": rate,
708
+ "limit": limit,
709
+ }
710
+ )
711
+ except Exception:
712
+ # Ignore malformed metrics here; gating purely best-effort
713
+ pass
714
+
715
+ try:
716
+ ratio = metrics.get("primary_metric_ratio")
717
+ if isinstance(ratio, int | float) and math.isfinite(float(ratio)):
718
+ limit = float(gate.get("max_ppl_degradation", 1.0))
719
+ # ppl-like ratio: degradation ~ ratio-1; gate on allowed extra
720
+ if ratio - 1.0 > limit:
721
+ violations.append(
722
+ {
723
+ "type": "primary_metric_degradation",
724
+ "actual": float(ratio - 1.0),
725
+ "limit": limit,
726
+ }
727
+ )
728
+ except Exception:
729
+ pass
730
+
731
+ if isinstance(gate.get("require_branch_balance"), bool) and gate.get(
732
+ "require_branch_balance"
733
+ ):
734
+ if metrics.get("branch_balance_ok") is False:
735
+ violations.append(
736
+ {"type": "branch_balance", "actual": False, "limit": True}
737
+ )
738
+
739
+ if violations:
740
+ raise PolicyViolationError(
741
+ code="E503",
742
+ message="VALIDATION-GATE-FAILED",
743
+ details={"violations": violations, "metrics": metrics, "gate": gate},
744
+ )
745
+
746
+
747
def check_policy_drift(*, silent: bool = False) -> dict[str, list[str]]:
    """
    Report where tiers.yaml and this module's hardcoded fallbacks disagree.

    A thin delegate to ``tier_config.check_drift``. Use it to notice when
    tiers.yaml has been recalibrated but the hardcoded presets in this module
    have not been updated to match.

    Args:
        silent: If True, don't emit warnings (just return drift info)

    Returns:
        Mapping of tier -> list of drift descriptions. An empty dict means no
        drift was detected.

    Example:
        >>> drift = check_policy_drift()
        >>> if drift:
        ...     print("Policy drift detected:", drift)
        ...     print("Consider updating hardcoded defaults in policies.py")
    """
    drift_report = check_tier_drift(silent=silent)
    return drift_report
768
+
769
+
770
# Public API of this module. enforce_validation_gate is a public (non-underscore)
# function and is exported alongside get_validation_gate so that
# `from invarlock.guards.policies import *` exposes the full gate API.
__all__ = [
    # Spectral policy constants
    "SPECTRAL_CONSERVATIVE",
    "SPECTRAL_BALANCED",
    "SPECTRAL_AGGRESSIVE",
    "SPECTRAL_ATTN_AWARE",
    "DEFAULT_SPECTRAL_POLICIES",
    # RMT policy constants
    "RMT_CONSERVATIVE",
    "RMT_BALANCED",
    "RMT_AGGRESSIVE",
    "DEFAULT_RMT_POLICIES",
    # Variance policy constants
    "VariancePolicyDict",
    "VARIANCE_CONSERVATIVE",
    "VARIANCE_BALANCED",
    "VARIANCE_AGGRESSIVE",
    "DEFAULT_VARIANCE_POLICIES",
    # Validation gate constants
    "VALIDATION_GATE_STRICT",
    "VALIDATION_GATE_STANDARD",
    "VALIDATION_GATE_PERMISSIVE",
    "DEFAULT_VALIDATION_GATES",
    # Utility functions
    "get_spectral_policy",
    "create_custom_spectral_policy",
    "get_policy_for_model_size",
    "get_rmt_policy",
    "create_custom_rmt_policy",
    "get_rmt_policy_for_model_size",
    "get_variance_policy",
    "create_custom_variance_policy",
    "get_variance_policy_for_model_size",
    "get_validation_gate",
    "enforce_validation_gate",
    "check_policy_drift",
]