invarlock-0.2.0-py3-none-any.whl

This diff represents the contents of publicly available package versions released to supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/guards/invariants.py
@@ -0,0 +1,640 @@
+"""
+InvarLock Guards - Invariants
+=========================
+
+Invariant checking for model edits to ensure structural integrity.
+"""
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from invarlock.core.api import Guard
+from invarlock.core.types import GuardOutcome
+
+
+class InvariantsGuard(Guard):
+    """
+    Guard for checking model invariants and structural integrity.
+    """
+
+    name = "invariants"
+
+    def __init__(self, strict_mode: bool = False, on_fail: str = "warn"):
+        """
+        Initialize invariants guard.
+
+        Args:
+            strict_mode: Whether to use strict validation
+            on_fail: Action to take on failure ("warn", "rollback", "abort")
+        """
+        self.strict_mode = strict_mode
+        self.on_fail = on_fail
+        self.prepared = False
+        self.baseline_checks: dict[str, Any] = {}
+        self.profile_checks: tuple[str, ...] = ()
+
+    def prepare(
+        self, model: Any, adapter: Any, calib: Any, policy: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Prepare invariants guard by capturing baseline state.
+
+        Args:
+            model: Model to prepare for
+            adapter: ModelAdapter instance
+            calib: Calibration data (unused)
+            policy: Policy configuration
+
+        Returns:
+            Preparation results
+        """
+        self.prepared = True
+
+        profile_checks = (
+            policy.get("profile_checks") if isinstance(policy, dict) else None
+        )
+        if isinstance(profile_checks, list | tuple | set):
+            self.profile_checks = tuple(str(check) for check in profile_checks)
+        else:
+            self.profile_checks = ()
+
+        # Capture baseline invariants
+        self.baseline_checks = self._capture_invariants(model, adapter)
+
+        return {
+            "ready": True,
+            "baseline_checks": len(self.baseline_checks),
+            "strict_mode": self.strict_mode,
+        }
+
+    def before_edit(self, model: Any) -> None:
+        """Execute before edit (no action needed for invariants)."""
+        pass
+
+    def after_edit(self, model: Any) -> None:
+        """Execute after edit (no action needed for invariants)."""
+        pass
+
+    def validate(
+        self, model: Any, adapter: Any, context: dict[str, Any]
+    ) -> dict[str, Any]:
+        """
+        Validate model invariants (Guard ABC interface).
+
+        Args:
+            model: Model to validate
+            adapter: ModelAdapter instance
+            context: Validation context
+
+        Returns:
+            Dictionary with validation results
+        """
+        if not self.prepared:
+            # Auto-prepare if not already done
+            self.prepare(model, adapter, None, {})
+
+        outcome = self.finalize(model)
+
+        return {
+            "passed": outcome.passed,
+            "action": outcome.action,
+            "violations": outcome.violations,
+            "metrics": outcome.metrics,
+        }
+
+    def finalize(self, model: Any) -> GuardOutcome:
+        """
+        Finalize invariants guard by checking for violations.
+
+        Args:
+            model: Model to validate
+
+        Returns:
+            GuardOutcome with validation results
+        """
+        if not self.prepared:
+            return GuardOutcome(
+                name=self.name,
+                passed=False,
+                action="warn",
+                violations=[{"type": "not_prepared", "message": "Guard not prepared"}],
+                metrics={},
+            )
+
+        # Check current invariants
+        current_checks = self._capture_invariants(model, None)
+        violations: list[dict[str, Any]] = []
+        tokenizer_mismatches: list[dict[str, Any]] = []
+
+        # Non-finite detection
+        non_finite_locations = self._detect_non_finite(model)
+        if non_finite_locations:
+            violations.append(
+                {
+                    "type": "non_finite_tensor",
+                    "locations": non_finite_locations,
+                    "message": "Non-finite parameter or buffer values detected",
+                }
+            )
+
+        # LayerNorm coverage check
+        baseline_layer_norms = set(self.baseline_checks.get("layer_norm_paths", ()))
+        current_layer_norms = set(current_checks.get("layer_norm_paths", ()))
+        missing_layer_norms = sorted(baseline_layer_norms - current_layer_norms)
+        if missing_layer_norms:
+            violations.append(
+                {
+                    "type": "layer_norm_missing",
+                    "missing": missing_layer_norms,
+                    "message": "Expected LayerNorm modules are missing after edit",
+                }
+            )
+
+        # Tokenizer / vocab alignment
+        baseline_vocab_sizes = self.baseline_checks.get("embedding_vocab_sizes")
+        current_vocab_sizes = current_checks.get("embedding_vocab_sizes")
+        if isinstance(baseline_vocab_sizes, dict):
+            for module_name, baseline_size in baseline_vocab_sizes.items():
+                current_size = None
+                if isinstance(current_vocab_sizes, dict):
+                    current_size = current_vocab_sizes.get(module_name)
+                if current_size is None or int(current_size) != int(baseline_size):
+                    mismatch = {
+                        "module": module_name,
+                        "baseline": int(baseline_size),
+                        "current": None if current_size is None else int(current_size),
+                    }
+                    tokenizer_mismatches.append(mismatch)
+                    violations.append(
+                        {
+                            "type": "tokenizer_mismatch",
+                            "message": "Embedding vocabulary size changed",
+                            **mismatch,
+                        }
+                    )
+
+        # Compare remaining invariants with baseline
+        handled_keys = {
+            "layer_norm_paths",
+            "embedding_vocab_sizes",
+            "config_vocab_size",
+        }
+
+        for check_name, baseline_value in self.baseline_checks.items():
+            if check_name in handled_keys:
+                continue
+
+            current_value = current_checks.get(check_name)
+
+            if current_value != baseline_value:
+                violations.append(
+                    {
+                        "type": "invariant_violation",
+                        "check": check_name,
+                        "baseline": baseline_value,
+                        "current": current_value,
+                        "message": f"Invariant {check_name} changed from {baseline_value} to {current_value}",
+                    }
+                )
+
+        # Classify violations by severity
+        fatal_violation_types = {"non_finite_tensor", "tokenizer_mismatch"}
+        if self.strict_mode:
+            fatal_violation_types.update({"layer_norm_missing", "invariant_violation"})
+
+        fatal_violations: list[dict[str, Any]] = []
+        warning_violations: list[dict[str, Any]] = []
+
+        for violation in violations:
+            violation_type = violation.get("type")
+            severity = "fatal" if violation_type in fatal_violation_types else "warning"
+            annotated = violation.copy()
+            annotated.setdefault("severity", severity)
+            if severity == "fatal":
+                fatal_violations.append(annotated)
+            else:
+                warning_violations.append(annotated)
+
+        annotated_violations = fatal_violations + warning_violations
+
+        # Determine if passed based on fatal violations and configured action
+        fatal_count = len(fatal_violations)
+        warning_count = len(warning_violations)
+
+        if fatal_count:
+            passed = False
+            if self.on_fail in {"abort", "rollback"}:
+                action = self.on_fail
+            else:
+                action = "abort"
+        elif warning_count:
+            if self.on_fail in {"abort", "rollback"}:
+                passed = False
+                action = self.on_fail
+            else:
+                passed = True
+                action = "warn"
+        else:
+            passed = True
+            action = "none"
+
+        metrics: dict[str, Any] = {
+            "checks_performed": len(self.baseline_checks),
+            "violations_found": len(annotated_violations),
+            "fatal_violations": fatal_count,
+            "warning_violations": warning_count,
+        }
+        if non_finite_locations:
+            metrics["non_finite_found"] = len(non_finite_locations)
+        if missing_layer_norms:
+            metrics["layer_norm_missing"] = missing_layer_norms
+        if tokenizer_mismatches:
+            metrics["tokenizer_mismatches"] = tokenizer_mismatches
+
+        return GuardOutcome(
+            name=self.name,
+            passed=passed,
+            action=action,
+            violations=annotated_violations,
+            metrics=metrics,
+        )
+
+    def _capture_invariants(self, model: Any, adapter: Any | None) -> dict[str, Any]:
+        """
+        Capture model invariants for comparison.
+
+        Args:
+            model: Model to analyze
+            adapter: ModelAdapter (optional)
+
+        Returns:
+            Dictionary of invariant checks
+        """
+        checks = {}
+
+        # Check parameter count
+        try:
+            param_count = sum(p.numel() for p in model.parameters())
+            checks["parameter_count"] = param_count
+        except Exception:
+            checks["parameter_count"] = -1
+
+        # Record LayerNorm module paths for later comparison
+        layer_norm_paths: list[str] = []
+        try:
+            for name, module in model.named_modules():
+                if isinstance(module, nn.LayerNorm):
+                    layer_norm_paths.append(name)
+        except Exception:
+            layer_norm_paths = []
+        checks["layer_norm_paths"] = tuple(layer_norm_paths)
+
+        # Capture embedding vocab sizes (num_embeddings) for tokenizer alignment
+        embedding_vocab_sizes: dict[str, int] = {}
+        try:
+            for name, module in model.named_modules():
+                if isinstance(module, nn.Embedding):
+                    try:
+                        embedding_vocab_sizes[name] = int(module.num_embeddings)
+                    except Exception:
+                        weight = getattr(module, "weight", None)
+                        if getattr(weight, "shape", None):
+                            embedding_vocab_sizes[name] = int(weight.shape[0])
+        except Exception:
+            embedding_vocab_sizes = {}
+        if embedding_vocab_sizes:
+            checks["embedding_vocab_sizes"] = embedding_vocab_sizes
+
+        config = getattr(model, "config", None)
+        config_vocab = getattr(config, "vocab_size", None)
+        try:
+            if config_vocab is not None:
+                checks["config_vocab_size"] = int(config_vocab)
+        except Exception:
+            pass
+
+        # Check weight tying (for language models)
+        weight_tying_flags: dict[str, bool] = {}
+
+        def _is_tied(left: Any, right: Any) -> bool:
+            if left is None or right is None:
+                return False
+            try:
+                return left.data_ptr() == right.data_ptr()
+            except Exception:
+                return False
+
+        # GPT-2 style (transformer.wte <-> lm_head)
+        try:
+            transformer = getattr(model, "transformer", None)
+            lm_head = getattr(model, "lm_head", None)
+            embed_weight = getattr(getattr(transformer, "wte", None), "weight", None)
+            head_weight = getattr(lm_head, "weight", None)
+            if embed_weight is not None and head_weight is not None:
+                weight_tying_flags["gpt2"] = _is_tied(embed_weight, head_weight)
+        except Exception:
+            pass
+
+        # BERT style (bert.embeddings.word_embeddings <-> cls.predictions.decoder)
+        try:
+            bert = getattr(model, "bert", None)
+            embeddings = getattr(bert, "embeddings", None)
+            word_embeddings = getattr(embeddings, "word_embeddings", None)
+            decoder = getattr(
+                getattr(getattr(model, "cls", None), "predictions", None),
+                "decoder",
+                None,
+            )
+            embed_weight = getattr(word_embeddings, "weight", None)
+            decoder_weight = getattr(decoder, "weight", None)
+            if embed_weight is not None and decoder_weight is not None:
+                weight_tying_flags["bert"] = _is_tied(embed_weight, decoder_weight)
+        except Exception:
+            pass
+
+        # LLaMA style (model.embed_tokens <-> lm_head)
+        try:
+            llama_model = getattr(model, "model", None)
+            embed_tokens = getattr(llama_model, "embed_tokens", None)
+            embed_weight = getattr(embed_tokens, "weight", None)
+            llama_head_weight = getattr(getattr(model, "lm_head", None), "weight", None)
+            if embed_weight is not None and llama_head_weight is not None:
+                weight_tying_flags["llama"] = _is_tied(embed_weight, llama_head_weight)
+        except Exception:
+            pass
+
+        if weight_tying_flags:
+            checks["weight_tying"] = all(weight_tying_flags.values())
+            checks["weight_tying_arches"] = weight_tying_flags
+        else:
+            checks["weight_tying"] = None
+
+        # Check model structure hash (basic)
+        try:
+            structure_items = []
+            for name, module in model.named_modules():
+                structure_items.append(f"{name}:{type(module).__name__}")
+            structure_hash = hash(tuple(structure_items))
+            checks["structure_hash"] = structure_hash
+        except Exception:
+            checks["structure_hash"] = 0
+
+        # Profile-specific invariants
+        if getattr(self, "profile_checks", None):
+            for name in self.profile_checks:
+                checks[f"profile::{name}"] = self._evaluate_profile_check(model, name)
+
+        return checks
+
+    def _detect_non_finite(self, model: Any) -> list[str]:
+        """Detect parameters or buffers containing non-finite values."""
+        locations: list[str] = []
+        try:
+            for name, param in model.named_parameters():
+                try:
+                    if not torch.isfinite(param).all():
+                        locations.append(f"parameter::{name}")
+                except Exception:
+                    continue
+            for name, buffer in model.named_buffers():
+                try:
+                    if not torch.isfinite(buffer).all():
+                        locations.append(f"buffer::{name}")
+                except Exception:
+                    continue
+        except Exception:
+            return locations
+        return locations
+
+    def _evaluate_profile_check(self, model: Any, name: str) -> bool:
+        name = str(name).lower()
+
+        if name == "mlm_mask_alignment":
+            config = getattr(model, "config", None)
+            model_type = getattr(config, "model_type", "") if config else ""
+            has_cls_decoder = bool(
+                getattr(
+                    getattr(getattr(model, "cls", None), "predictions", None),
+                    "decoder",
+                    None,
+                )
+            )
+            return "bert" in model_type or has_cls_decoder
+
+        if name in {"rope_rotary_embedding", "rotary_embedding"}:
+            # Detect rotary embeddings used by LLaMA-style models
+            if hasattr(model, "model") and hasattr(model.model, "layers"):
+                first_layer = model.model.layers[0] if model.model.layers else None
+            else:
+                first_layer = None
+            rotary = None
+            if first_layer is not None:
+                rotary = getattr(
+                    getattr(first_layer, "self_attn", None), "rotary_emb", None
+                )
+            return rotary is not None
+
+        if name in {"causal_masking", "causal"}:
+            config = getattr(model, "config", None)
+            if config and getattr(config, "is_decoder", False):
+                return True
+            model_type = getattr(config, "model_type", "") if config else ""
+            return any(
+                keyword in model_type
+                for keyword in ("gpt", "llama", "mistral", "opt", "phi")
+            )
+
+        return True
+
+
+def check_adapter_aware_invariants(
+    model: Any, verbose: bool = False
+) -> tuple[bool, dict[str, Any]]:
+    """
+    Check model invariants with adapter awareness.
+
+    Args:
+        model: Model to check
+        verbose: Whether to print detailed information
+
+    Returns:
+        (all_passed, results) tuple
+    """
+    results: dict[str, Any] = {"adapter_type": "none", "checks": {}, "violations": []}
+    all_passed = True
+    # Standard model checks only
+    standard_checks: dict[str, dict[str, Any]] = _check_standard_invariants(model)
+    results["checks"].update(standard_checks)
+    for check_name, check_result in standard_checks.items():
+        if not check_result.get("passed", True):
+            all_passed = False
+            results["violations"].append(
+                {
+                    "type": "standard_violation",
+                    "check": check_name,
+                    "message": check_result.get("message", "Check failed"),
+                }
+            )
+    return all_passed, results
+
+
+def _detect_adapter_type(model: Any) -> str:
+    """Detect adapter type (disabled). Always returns 'none'."""
+    return "none"
+
+
+def _check_standard_invariants(model: Any) -> dict[str, dict[str, Any]]:
+    """Check standard model invariants."""
+    checks: dict[str, dict[str, Any]] = {}
+
+    # Check parameter count is reasonable
+    try:
+        param_count = sum(p.numel() for p in model.parameters())
+        checks["parameter_count"] = {
+            "passed": param_count > 0,
+            "count": param_count,
+            "message": f"Parameter count: {param_count}",
+        }
+    except Exception as e:
+        checks["parameter_count"] = {
+            "passed": False,
+            "message": f"Could not count parameters: {e}",
+        }
+
+    # Check for NaN parameters
+    try:
+        has_nan = False
+        for param in model.parameters():
+            if hasattr(param, "isnan") and param.isnan().any():
+                has_nan = True
+                break
+
+        checks["no_nan_parameters"] = {
+            "passed": not has_nan,
+            "message": "NaN parameters detected" if has_nan else "No NaN parameters",
+        }
+    except Exception as e:
+        checks["no_nan_parameters"] = {
+            "passed": False,
+            "message": f"Could not check for NaN: {e}",
+        }
+
+    return checks
+
+
+def check_all_invariants(model: Any, threshold: float = 1e-6) -> GuardOutcome:
+    """
+    Check all basic model invariants.
+
+    Args:
+        model: PyTorch model to check
+        threshold: Numerical threshold for invariant checks
+
+    Returns:
+        GuardOutcome: Result of invariant checking
+    """
+    violations = []
+
+    # Basic model structure checks
+    if not hasattr(model, "named_parameters"):
+        violations.append(
+            {
+                "type": "structure_violation",
+                "message": "Model missing named_parameters method",
+            }
+        )
+        return GuardOutcome(
+            name="check_all_invariants",
+            passed=False,
+            action="reject",
+            violations=violations,
+            metrics={},
+        )
+
+    # Check for NaN/Inf in parameters
+    for name, param in model.named_parameters():
+        if hasattr(param.data, "isnan") and param.data.isnan().any():
+            violations.append(
+                {
+                    "type": "nan_violation",
+                    "parameter": name,
+                    "message": f"NaN detected in parameter {name}",
+                }
+            )
+        if hasattr(param.data, "isinf") and param.data.isinf().any():
+            violations.append(
+                {
+                    "type": "inf_violation",
+                    "parameter": name,
+                    "message": f"Inf detected in parameter {name}",
+                }
+            )
+
+    # Check parameter ranges are reasonable
+    for name, param in model.named_parameters():
+        if hasattr(param.data, "abs") and hasattr(param.data, "max"):
+            max_val = param.data.abs().max()
+            if hasattr(max_val, "item"):
+                max_val = max_val.item()
+
+            if max_val > 1000:
+                violations.append(
+                    {
+                        "type": "range_violation",
+                        "parameter": name,
+                        "max_value": max_val,
+                        "message": f"Parameter {name} has unusually large values (max: {max_val})",
+                    }
+                )
+            if max_val < threshold:
+                violations.append(
+                    {
+                        "type": "range_violation",
+                        "parameter": name,
+                        "max_value": max_val,
+                        "message": f"Parameter {name} has unusually small values (max: {max_val})",
+                    }
+                )
+
+    passed = len(violations) == 0
+    action = "continue" if passed else "reject"
+
+    return GuardOutcome(
+        name="check_all_invariants",
+        passed=passed,
+        action=action,
+        violations=violations,
+        metrics={
+            "parameters_checked": sum(1 for _ in model.named_parameters()),
+            "violations_found": len(violations),
+        },
+    )
+
+
+def assert_invariants(model: Any, threshold: float = 1e-6) -> None:
+    """
+    Assert that all model invariants hold, raising exception if not.
+
+    Args:
+        model: PyTorch model to check
+        threshold: Numerical threshold for invariant checks
+
+    Raises:
+        AssertionError: If any invariants are violated
+    """
+    result = check_all_invariants(model, threshold)
+    if not result.passed:
+        violation_messages = [v.get("message", str(v)) for v in result.violations or []]
+        raise AssertionError(
+            f"Model invariants violated: {'; '.join(violation_messages)}"
+        )
+
+
+__all__ = [
+    "InvariantsGuard",
+    "check_adapter_aware_invariants",
+    "check_all_invariants",
+    "assert_invariants",
+]
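
For context, a minimal usage sketch of the guard defined above, based only on this file. The `TinyModel` and the simulated NaN-producing edit are illustrative assumptions, not part of the package; the `prepare`/`finalize` flow, the `GuardOutcome` fields, and the `assert_invariants` helper are taken directly from the code in this diff.

```python
# Minimal sketch. Assumptions: TinyModel and the simulated NaN edit are illustrative;
# InvariantsGuard, assert_invariants, and the GuardOutcome fields are as defined above.
import torch
import torch.nn as nn

from invarlock.guards.invariants import InvariantsGuard, assert_invariants


class TinyModel(nn.Module):
    """Stand-in for a real model: embedding + LayerNorm + head."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.norm = nn.LayerNorm(16)
        self.head = nn.Linear(16, 100)


model = TinyModel()

# Capture baseline invariants before the edit.
guard = InvariantsGuard(strict_mode=False, on_fail="warn")
guard.prepare(model, adapter=None, calib=None, policy={})

# ... apply the real edit here (quantization, pruning, ...); we simulate a bad one.
with torch.no_grad():
    model.head.weight[0, 0] = float("nan")

# Re-check invariants after the edit: a non-finite tensor is a fatal violation,
# so the outcome fails and the action escalates to "abort".
outcome = guard.finalize(model)
print(outcome.passed, outcome.action)            # False abort
print([v["type"] for v in outcome.violations])   # ['non_finite_tensor']

# Module-level helper: raises AssertionError when basic invariants fail.
try:
    assert_invariants(model)
except AssertionError as exc:
    print(exc)
```

With `on_fail="warn"`, warning-level violations alone would still pass with action "warn"; fatal ones (non-finite tensors, tokenizer mismatches) always fail and escalate to "abort" unless "rollback" or "abort" is configured explicitly.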