invarlock 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/guards/invariants.py
@@ -0,0 +1,640 @@
"""
InvarLock Guards - Invariants
=========================

Invariant checking for model edits to ensure structural integrity.
"""

from typing import Any

import torch
import torch.nn as nn

from invarlock.core.api import Guard
from invarlock.core.types import GuardOutcome


class InvariantsGuard(Guard):
    """
    Guard for checking model invariants and structural integrity.
    """

    name = "invariants"

    def __init__(self, strict_mode: bool = False, on_fail: str = "warn"):
        """
        Initialize invariants guard.

        Args:
            strict_mode: Whether to use strict validation
            on_fail: Action to take on failure ("warn", "rollback", "abort")
        """
        self.strict_mode = strict_mode
        self.on_fail = on_fail
        self.prepared = False
        self.baseline_checks: dict[str, Any] = {}
        self.profile_checks: tuple[str, ...] = ()

    def prepare(
        self, model: Any, adapter: Any, calib: Any, policy: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Prepare invariants guard by capturing baseline state.

        Args:
            model: Model to prepare for
            adapter: ModelAdapter instance
            calib: Calibration data (unused)
            policy: Policy configuration

        Returns:
            Preparation results
        """
        self.prepared = True

        profile_checks = (
            policy.get("profile_checks") if isinstance(policy, dict) else None
        )
        if isinstance(profile_checks, list | tuple | set):
            self.profile_checks = tuple(str(check) for check in profile_checks)
        else:
            self.profile_checks = ()

        # Capture baseline invariants
        self.baseline_checks = self._capture_invariants(model, adapter)

        return {
            "ready": True,
            "baseline_checks": len(self.baseline_checks),
            "strict_mode": self.strict_mode,
        }

    def before_edit(self, model: Any) -> None:
        """Execute before edit (no action needed for invariants)."""
        pass

    def after_edit(self, model: Any) -> None:
        """Execute after edit (no action needed for invariants)."""
        pass

    def validate(
        self, model: Any, adapter: Any, context: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Validate model invariants (Guard ABC interface).

        Args:
            model: Model to validate
            adapter: ModelAdapter instance
            context: Validation context

        Returns:
            Dictionary with validation results
        """
        if not self.prepared:
            # Auto-prepare if not already done
            self.prepare(model, adapter, None, {})

        outcome = self.finalize(model)

        return {
            "passed": outcome.passed,
            "action": outcome.action,
            "violations": outcome.violations,
            "metrics": outcome.metrics,
        }

    def finalize(self, model: Any) -> GuardOutcome:
        """
        Finalize invariants guard by checking for violations.

        Args:
            model: Model to validate

        Returns:
            GuardOutcome with validation results
        """
        if not self.prepared:
            return GuardOutcome(
                name=self.name,
                passed=False,
                action="warn",
                violations=[{"type": "not_prepared", "message": "Guard not prepared"}],
                metrics={},
            )

        # Check current invariants
        current_checks = self._capture_invariants(model, None)
        violations: list[dict[str, Any]] = []
        tokenizer_mismatches: list[dict[str, Any]] = []

        # Non-finite detection
        non_finite_locations = self._detect_non_finite(model)
        if non_finite_locations:
            violations.append(
                {
                    "type": "non_finite_tensor",
                    "locations": non_finite_locations,
                    "message": "Non-finite parameter or buffer values detected",
                }
            )

        # LayerNorm coverage check
        baseline_layer_norms = set(self.baseline_checks.get("layer_norm_paths", ()))
        current_layer_norms = set(current_checks.get("layer_norm_paths", ()))
        missing_layer_norms = sorted(baseline_layer_norms - current_layer_norms)
        if missing_layer_norms:
            violations.append(
                {
                    "type": "layer_norm_missing",
                    "missing": missing_layer_norms,
                    "message": "Expected LayerNorm modules are missing after edit",
                }
            )

        # Tokenizer / vocab alignment
        baseline_vocab_sizes = self.baseline_checks.get("embedding_vocab_sizes")
        current_vocab_sizes = current_checks.get("embedding_vocab_sizes")
        if isinstance(baseline_vocab_sizes, dict):
            for module_name, baseline_size in baseline_vocab_sizes.items():
                current_size = None
                if isinstance(current_vocab_sizes, dict):
                    current_size = current_vocab_sizes.get(module_name)
                if current_size is None or int(current_size) != int(baseline_size):
                    mismatch = {
                        "module": module_name,
                        "baseline": int(baseline_size),
                        "current": None if current_size is None else int(current_size),
                    }
                    tokenizer_mismatches.append(mismatch)
                    violations.append(
                        {
                            "type": "tokenizer_mismatch",
                            "message": "Embedding vocabulary size changed",
                            **mismatch,
                        }
                    )

        # Compare remaining invariants with baseline
        handled_keys = {
            "layer_norm_paths",
            "embedding_vocab_sizes",
            "config_vocab_size",
        }

        for check_name, baseline_value in self.baseline_checks.items():
            if check_name in handled_keys:
                continue

            current_value = current_checks.get(check_name)

            if current_value != baseline_value:
                violations.append(
                    {
                        "type": "invariant_violation",
                        "check": check_name,
                        "baseline": baseline_value,
                        "current": current_value,
                        "message": f"Invariant {check_name} changed from {baseline_value} to {current_value}",
                    }
                )

        # Classify violations by severity
        fatal_violation_types = {"non_finite_tensor", "tokenizer_mismatch"}
        if self.strict_mode:
            fatal_violation_types.update({"layer_norm_missing", "invariant_violation"})

        fatal_violations: list[dict[str, Any]] = []
        warning_violations: list[dict[str, Any]] = []

        for violation in violations:
            violation_type = violation.get("type")
            severity = "fatal" if violation_type in fatal_violation_types else "warning"
            annotated = violation.copy()
            annotated.setdefault("severity", severity)
            if severity == "fatal":
                fatal_violations.append(annotated)
            else:
                warning_violations.append(annotated)

        annotated_violations = fatal_violations + warning_violations

        # Determine if passed based on fatal violations and configured action
        fatal_count = len(fatal_violations)
        warning_count = len(warning_violations)

        if fatal_count:
            passed = False
            if self.on_fail in {"abort", "rollback"}:
                action = self.on_fail
            else:
                action = "abort"
        elif warning_count:
            if self.on_fail in {"abort", "rollback"}:
                passed = False
                action = self.on_fail
            else:
                passed = True
                action = "warn"
        else:
            passed = True
            action = "none"

        metrics: dict[str, Any] = {
            "checks_performed": len(self.baseline_checks),
            "violations_found": len(annotated_violations),
            "fatal_violations": fatal_count,
            "warning_violations": warning_count,
        }
        if non_finite_locations:
            metrics["non_finite_found"] = len(non_finite_locations)
        if missing_layer_norms:
            metrics["layer_norm_missing"] = missing_layer_norms
        if tokenizer_mismatches:
            metrics["tokenizer_mismatches"] = tokenizer_mismatches

        return GuardOutcome(
            name=self.name,
            passed=passed,
            action=action,
            violations=annotated_violations,
            metrics=metrics,
        )

    def _capture_invariants(self, model: Any, adapter: Any | None) -> dict[str, Any]:
        """
        Capture model invariants for comparison.

        Args:
            model: Model to analyze
            adapter: ModelAdapter (optional)

        Returns:
            Dictionary of invariant checks
        """
        checks = {}

        # Check parameter count
        try:
            param_count = sum(p.numel() for p in model.parameters())
            checks["parameter_count"] = param_count
        except Exception:
            checks["parameter_count"] = -1

        # Record LayerNorm module paths for later comparison
        layer_norm_paths: list[str] = []
        try:
            for name, module in model.named_modules():
                if isinstance(module, nn.LayerNorm):
                    layer_norm_paths.append(name)
        except Exception:
            layer_norm_paths = []
        checks["layer_norm_paths"] = tuple(layer_norm_paths)

        # Capture embedding vocab sizes (num_embeddings) for tokenizer alignment
        embedding_vocab_sizes: dict[str, int] = {}
        try:
            for name, module in model.named_modules():
                if isinstance(module, nn.Embedding):
                    try:
                        embedding_vocab_sizes[name] = int(module.num_embeddings)
                    except Exception:
                        weight = getattr(module, "weight", None)
                        if getattr(weight, "shape", None):
                            embedding_vocab_sizes[name] = int(weight.shape[0])
        except Exception:
            embedding_vocab_sizes = {}
        if embedding_vocab_sizes:
            checks["embedding_vocab_sizes"] = embedding_vocab_sizes

        config = getattr(model, "config", None)
        config_vocab = getattr(config, "vocab_size", None)
        try:
            if config_vocab is not None:
                checks["config_vocab_size"] = int(config_vocab)
        except Exception:
            pass

        # Check weight tying (for language models)
        weight_tying_flags: dict[str, bool] = {}

        def _is_tied(left: Any, right: Any) -> bool:
            if left is None or right is None:
                return False
            try:
                return left.data_ptr() == right.data_ptr()
            except Exception:
                return False

        # GPT-2 style (transformer.wte <-> lm_head)
        try:
            transformer = getattr(model, "transformer", None)
            lm_head = getattr(model, "lm_head", None)
            embed_weight = getattr(getattr(transformer, "wte", None), "weight", None)
            head_weight = getattr(lm_head, "weight", None)
            if embed_weight is not None and head_weight is not None:
                weight_tying_flags["gpt2"] = _is_tied(embed_weight, head_weight)
        except Exception:
            pass

        # BERT style (bert.embeddings.word_embeddings <-> cls.predictions.decoder)
        try:
            bert = getattr(model, "bert", None)
            embeddings = getattr(bert, "embeddings", None)
            word_embeddings = getattr(embeddings, "word_embeddings", None)
            decoder = getattr(
                getattr(getattr(model, "cls", None), "predictions", None),
                "decoder",
                None,
            )
            embed_weight = getattr(word_embeddings, "weight", None)
            decoder_weight = getattr(decoder, "weight", None)
            if embed_weight is not None and decoder_weight is not None:
                weight_tying_flags["bert"] = _is_tied(embed_weight, decoder_weight)
        except Exception:
            pass

        # LLaMA style (model.embed_tokens <-> lm_head)
        try:
            llama_model = getattr(model, "model", None)
            embed_tokens = getattr(llama_model, "embed_tokens", None)
            embed_weight = getattr(embed_tokens, "weight", None)
            llama_head_weight = getattr(getattr(model, "lm_head", None), "weight", None)
            if embed_weight is not None and llama_head_weight is not None:
                weight_tying_flags["llama"] = _is_tied(embed_weight, llama_head_weight)
        except Exception:
            pass

        if weight_tying_flags:
            checks["weight_tying"] = all(weight_tying_flags.values())
            checks["weight_tying_arches"] = weight_tying_flags
        else:
            checks["weight_tying"] = None

        # Check model structure hash (basic)
        try:
            structure_items = []
            for name, module in model.named_modules():
                structure_items.append(f"{name}:{type(module).__name__}")
            structure_hash = hash(tuple(structure_items))
            checks["structure_hash"] = structure_hash
        except Exception:
            checks["structure_hash"] = 0

        # Profile-specific invariants
        if getattr(self, "profile_checks", None):
            for name in self.profile_checks:
                checks[f"profile::{name}"] = self._evaluate_profile_check(model, name)

        return checks

    def _detect_non_finite(self, model: Any) -> list[str]:
        """Detect parameters or buffers containing non-finite values."""
        locations: list[str] = []
        try:
            for name, param in model.named_parameters():
                try:
                    if not torch.isfinite(param).all():
                        locations.append(f"parameter::{name}")
                except Exception:
                    continue
            for name, buffer in model.named_buffers():
                try:
                    if not torch.isfinite(buffer).all():
                        locations.append(f"buffer::{name}")
                except Exception:
                    continue
        except Exception:
            return locations
        return locations

    def _evaluate_profile_check(self, model: Any, name: str) -> bool:
        name = str(name).lower()

        if name == "mlm_mask_alignment":
            config = getattr(model, "config", None)
            model_type = getattr(config, "model_type", "") if config else ""
            has_cls_decoder = bool(
                getattr(
                    getattr(getattr(model, "cls", None), "predictions", None),
                    "decoder",
                    None,
                )
            )
            return "bert" in model_type or has_cls_decoder

        if name in {"rope_rotary_embedding", "rotary_embedding"}:
            # Detect rotary embeddings used by LLaMA-style models
            if hasattr(model, "model") and hasattr(model.model, "layers"):
                first_layer = model.model.layers[0] if model.model.layers else None
            else:
                first_layer = None
            rotary = None
            if first_layer is not None:
                rotary = getattr(
                    getattr(first_layer, "self_attn", None), "rotary_emb", None
                )
            return rotary is not None

        if name in {"causal_masking", "causal"}:
            config = getattr(model, "config", None)
            if config and getattr(config, "is_decoder", False):
                return True
            model_type = getattr(config, "model_type", "") if config else ""
            return any(
                keyword in model_type
                for keyword in ("gpt", "llama", "mistral", "opt", "phi")
            )

        return True


def check_adapter_aware_invariants(
    model: Any, verbose: bool = False
) -> tuple[bool, dict[str, Any]]:
    """
    Check model invariants with adapter awareness.

    Args:
        model: Model to check
        verbose: Whether to print detailed information

    Returns:
        (all_passed, results) tuple
    """
    results: dict[str, Any] = {"adapter_type": "none", "checks": {}, "violations": []}
    all_passed = True
    # Standard model checks only
    standard_checks: dict[str, dict[str, Any]] = _check_standard_invariants(model)
    results["checks"].update(standard_checks)
    for check_name, check_result in standard_checks.items():
        if not check_result.get("passed", True):
            all_passed = False
            results["violations"].append(
                {
                    "type": "standard_violation",
                    "check": check_name,
                    "message": check_result.get("message", "Check failed"),
                }
            )
    return all_passed, results


def _detect_adapter_type(model: Any) -> str:
    """Detect adapter type (disabled). Always returns 'none'."""
    return "none"


def _check_standard_invariants(model: Any) -> dict[str, dict[str, Any]]:
    """Check standard model invariants."""
    checks: dict[str, dict[str, Any]] = {}

    # Check parameter count is reasonable
    try:
        param_count = sum(p.numel() for p in model.parameters())
        checks["parameter_count"] = {
            "passed": param_count > 0,
            "count": param_count,
            "message": f"Parameter count: {param_count}",
        }
    except Exception as e:
        checks["parameter_count"] = {
            "passed": False,
            "message": f"Could not count parameters: {e}",
        }

    # Check for NaN parameters
    try:
        has_nan = False
        for param in model.parameters():
            if hasattr(param, "isnan") and param.isnan().any():
                has_nan = True
                break

        checks["no_nan_parameters"] = {
            "passed": not has_nan,
            "message": "NaN parameters detected" if has_nan else "No NaN parameters",
        }
    except Exception as e:
        checks["no_nan_parameters"] = {
            "passed": False,
            "message": f"Could not check for NaN: {e}",
        }

    return checks


def check_all_invariants(model: Any, threshold: float = 1e-6) -> GuardOutcome:
    """
    Check all basic model invariants.

    Args:
        model: PyTorch model to check
        threshold: Numerical threshold for invariant checks

    Returns:
        GuardOutcome: Result of invariant checking
    """
    violations = []

    # Basic model structure checks
    if not hasattr(model, "named_parameters"):
        violations.append(
            {
                "type": "structure_violation",
                "message": "Model missing named_parameters method",
            }
        )
        return GuardOutcome(
            name="check_all_invariants",
            passed=False,
            action="reject",
            violations=violations,
            metrics={},
        )

    # Check for NaN/Inf in parameters
    for name, param in model.named_parameters():
        if hasattr(param.data, "isnan") and param.data.isnan().any():
            violations.append(
                {
                    "type": "nan_violation",
                    "parameter": name,
                    "message": f"NaN detected in parameter {name}",
                }
            )
        if hasattr(param.data, "isinf") and param.data.isinf().any():
            violations.append(
                {
                    "type": "inf_violation",
                    "parameter": name,
                    "message": f"Inf detected in parameter {name}",
                }
            )

    # Check parameter ranges are reasonable
    for name, param in model.named_parameters():
        if hasattr(param.data, "abs") and hasattr(param.data, "max"):
            max_val = param.data.abs().max()
            if hasattr(max_val, "item"):
                max_val = max_val.item()

            if max_val > 1000:
                violations.append(
                    {
                        "type": "range_violation",
                        "parameter": name,
                        "max_value": max_val,
                        "message": f"Parameter {name} has unusually large values (max: {max_val})",
                    }
                )
            if max_val < threshold:
                violations.append(
                    {
                        "type": "range_violation",
                        "parameter": name,
                        "max_value": max_val,
                        "message": f"Parameter {name} has unusually small values (max: {max_val})",
                    }
                )

    passed = len(violations) == 0
    action = "continue" if passed else "reject"

    return GuardOutcome(
        name="check_all_invariants",
        passed=passed,
        action=action,
        violations=violations,
        metrics={
            "parameters_checked": sum(1 for _ in model.named_parameters()),
            "violations_found": len(violations),
        },
    )


def assert_invariants(model: Any, threshold: float = 1e-6) -> None:
    """
    Assert that all model invariants hold, raising exception if not.

    Args:
        model: PyTorch model to check
        threshold: Numerical threshold for invariant checks

    Raises:
        AssertionError: If any invariants are violated
    """
    result = check_all_invariants(model, threshold)
    if not result.passed:
        violation_messages = [v.get("message", str(v)) for v in result.violations or []]
        raise AssertionError(
            f"Model invariants violated: {'; '.join(violation_messages)}"
        )


__all__ = [
    "InvariantsGuard",
    "check_adapter_aware_invariants",
    "check_all_invariants",
    "assert_invariants",
]
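For orientation only, here is a minimal sketch of how the `InvariantsGuard` lifecycle shown in this diff appears intended to be driven: `prepare()` captures baseline invariants before an edit, and `finalize()` re-captures them afterwards and reports violations. This sketch is not part of the package; the toy model and the skipped edit step are placeholders.

```python
# Illustrative sketch only -- not shipped in invarlock 0.2.0.
# The toy model is a placeholder; a real run would pass the model being
# edited plus the project's adapter/calibration/policy objects.
import torch.nn as nn

from invarlock.guards.invariants import InvariantsGuard

model = nn.Sequential(nn.Embedding(100, 16), nn.LayerNorm(16), nn.Linear(16, 100))

guard = InvariantsGuard(strict_mode=True, on_fail="rollback")
guard.prepare(model, adapter=None, calib=None, policy={"profile_checks": []})

# ... apply a weight edit (quantization, pruning, ...) to `model` here ...

outcome = guard.finalize(model)
print(outcome.passed, outcome.action)       # e.g. True "none" when nothing changed
print(outcome.metrics["checks_performed"])  # number of baseline invariants compared
```

Within the package the class registers itself under `name = "invariants"`, so policy configuration can refer to the guard by that name.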
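A second sketch, likewise illustrative rather than shipped code, exercises the standalone helpers on a deliberately corrupted parameter to show a failing outcome.

```python
# Illustrative sketch only -- not shipped in invarlock 0.2.0.
import torch
import torch.nn as nn

from invarlock.guards.invariants import assert_invariants, check_all_invariants

model = nn.Linear(8, 8)
with torch.no_grad():
    model.weight[0, 0] = float("nan")     # corrupt one weight on purpose

outcome = check_all_invariants(model)
print(outcome.passed, outcome.action)     # False "reject"
print(outcome.violations[0]["type"])      # "nan_violation"

try:
    assert_invariants(model)
except AssertionError as err:
    print(err)  # e.g. "Model invariants violated: NaN detected in parameter weight"
```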