invarlock-0.2.0-py3-none-any.whl
This diff shows the contents of a publicly released package version as it appears in its public registry. It is provided for informational purposes only and reflects the changes between package versions as published.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/adapters/hf_gpt2.py
@@ -0,0 +1,403 @@
+"""
+HuggingFace GPT-2 Model Adapter
+===============================
+
+ModelAdapter implementation for HuggingFace GPT-2 architecture models.
+
+This adapter provides enhanced HuggingFace integration including:
+- Better model detection for HF model variants
+- Proper handling of transformers library specifics
+- Device-aware state serialization with HF model handling
+- Weight tying preservation (lm_head ↔ wte)
+- Split size and layer naming convention support
+"""
+
+import os
+from types import SimpleNamespace
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from invarlock.core.api import ModelAdapter
+from invarlock.core.error_utils import wrap_errors
+from invarlock.core.exceptions import AdapterError, DependencyError, ModelLoadError
+
+from .hf_mixin import HFAdapterMixin
+
+LIGHT_IMPORT = os.getenv("INVARLOCK_LIGHT_IMPORT", "").strip().lower() in {
+    "1",
+    "true",
+    "yes",
+}
+
+TensorType = torch.Tensor
+ModuleType = nn.Module
+
+
+class HF_GPT2_Adapter(HFAdapterMixin, ModelAdapter):
+    """
+    HuggingFace-specific ModelAdapter implementation for GPT-2 models.
+
+    Supports HuggingFace GPT2Model and GPT2LMHeadModel variants with:
+    - Enhanced HF model detection and validation
+    - Device-aware state serialization
+    - Weight tying preservation across snapshot/restore cycles
+    - Proper handling of Conv1D layers and split_size conventions
+    """
+
+    name = "hf_gpt2"
+
+    def load_model(self, model_id: str, device: str = "auto") -> ModuleType | Any:
+        """
+        Load a HuggingFace GPT-2 model.
+
+        Args:
+            model_id: Model identifier (e.g. "gpt2", "gpt2-medium")
+            device: Target device ("auto", "cuda", "mps", "cpu")
+
+        Returns:
+            Loaded GPT-2 model
+        """
+        # Lazy import to allow dependency mapping; in light-import mode fall back to a stub
+        try:
+            with wrap_errors(
+                DependencyError,
+                "E203",
+                "DEPENDENCY-MISSING: transformers",
+                lambda e: {"dependency": "transformers"},
+            ):
+                from transformers import AutoModelForCausalLM  # type: ignore
+
+            with wrap_errors(
+                ModelLoadError,
+                "E201",
+                "MODEL-LOAD-FAILED: transformers AutoModelForCausalLM",
+                lambda e: {"model_id": model_id},
+            ):
+                model = AutoModelForCausalLM.from_pretrained(model_id)
+
+            target_device = self._resolve_device(device)
+            return model.to(target_device)
+        except DependencyError:
+            if LIGHT_IMPORT:
+                # Minimal stand-in that satisfies downstream interface requirements
+                stub = SimpleNamespace(name="hf_gpt2_stub")
+                stub.to = lambda *_a, **_k: stub  # type: ignore[attr-defined]
+                return stub
+            raise
+
+    def can_handle(self, model: ModuleType | Any) -> bool:
+        """
+        Check if this adapter can handle the given model.
+
+        Enhanced detection for HuggingFace GPT-2 models with validation
+        of expected structure and configuration.
+
+        Args:
+            model: The model to check
+
+        Returns:
+            True if this is a HuggingFace GPT-2 compatible model
+        """
+        # Check for HuggingFace GPT-2 class names (avoid importing classes at module import time)
+        model_name = model.__class__.__name__
+        if model_name in ["GPT2Model", "GPT2LMHeadModel"]:
+            # Verify it has HF config
+            if hasattr(model, "config") and hasattr(model.config, "model_type"):
+                return model.config.model_type == "gpt2"
+
+        # Structural validation for GPT-2-like models
+        if hasattr(model, "config") and hasattr(model, "transformer"):
+            config = model.config
+            transformer = model.transformer
+
+            # Check for GPT-2 configuration attributes
+            if (
+                hasattr(config, "n_layer")
+                and hasattr(config, "n_head")
+                and hasattr(config, "hidden_size")
+                and hasattr(transformer, "h")
+            ):
+                # Validate transformer structure
+                try:
+                    h_layers = transformer.h
+                    if hasattr(h_layers, "__len__") and len(h_layers) > 0:
+                        layer = h_layers[0]
+                    elif hasattr(h_layers, "__iter__"):
+                        # Handle iterables without len() (like Mock objects in tests)
+                        try:
+                            layer = next(iter(h_layers))
+                        except (StopIteration, TypeError):
+                            return False
+                    else:
+                        return False
+
+                    # Check for GPT-2 layer structure with HF conventions
+                    if (
+                        hasattr(layer, "attn")
+                        and hasattr(layer, "mlp")
+                        and hasattr(layer.attn, "c_attn")
+                        and hasattr(layer.attn, "c_proj")
+                        and hasattr(layer.mlp, "c_fc")
+                        and hasattr(layer.mlp, "c_proj")
+                    ):
+                        return True
+                except (AttributeError, TypeError):
+                    return False
+
+        # Check for bare GPT2Model structure (less common but possible)
+        if hasattr(model, "h") and hasattr(model, "config"):
+            if hasattr(model.config, "n_layer") and len(model.h) > 0:
+                layer = model.h[0]
+                if (
+                    hasattr(layer, "attn")
+                    and hasattr(layer, "mlp")
+                    and hasattr(layer.attn, "c_attn")
+                    and hasattr(layer.mlp, "c_fc")
+                ):
+                    return True
+
+        return False
+
+    def describe(self, model: ModuleType | Any) -> dict[str, Any]:
+        """
+        Get structural description of the HuggingFace GPT-2 model.
+
+        Returns the required format for validation gates:
+        - n_layer: int
+        - heads_per_layer: List[int]
+        - mlp_dims: List[int]
+        - tying: Dict[str, str] (weight tying map)
+
+        Args:
+            model: The HuggingFace GPT-2 model to describe
+
+        Returns:
+            Dictionary with model structure info in required format
+        """
+        # Determine model structure
+        if hasattr(model, "transformer"):
+            # GPT2LMHeadModel structure
+            transformer = model.transformer
+            layers = transformer.h
+            config = model.config
+        elif hasattr(model, "h"):
+            # Direct GPT2Model structure
+            layers = model.h
+            config = model.config
+            transformer = model
+        else:
+            raise AdapterError(
+                code="E202",
+                message=(
+                    "ADAPTER-STRUCTURE-INVALID: unrecognized HuggingFace GPT-2 model structure"
+                ),
+                details={"model_class": model.__class__.__name__},
+            )
+
+        # Extract basic configuration
+        n_layers = len(layers)
+        n_heads = getattr(
+            config, "n_head", getattr(config, "num_attention_heads", None)
+        )
+        hidden_size = getattr(config, "hidden_size", getattr(config, "d_model", None))
+        vocab_size = getattr(config, "vocab_size", None)
+
+        if n_heads is None or hidden_size is None:
+            raise AdapterError(
+                code="E202",
+                message=(
+                    "ADAPTER-STRUCTURE-INVALID: missing n_heads or hidden_size in config"
+                ),
+                details={"model_class": model.__class__.__name__},
+            )
+
+        # Get device info
+        try:
+            device = next(model.parameters()).device
+        except StopIteration:
+            device = torch.device("cpu")
+
+        # Calculate total parameters
+        total_params = sum(p.numel() for p in model.parameters())
+
+        # Get MLP dimensions for each layer
+        mlp_dims = []
+        heads_per_layer = []
+
+        for layer_idx in range(n_layers):
+            layer = layers[layer_idx]
+
+            # For GPT-2, all layers have the same head count
+            heads_per_layer.append(n_heads)
+
+            # Get MLP intermediate dimension
+            # HuggingFace GPT-2 uses Conv1D layers where weight shape is (in_features, out_features)
+            if hasattr(layer.mlp.c_fc, "weight"):
+                if hasattr(layer.mlp.c_fc, "nf"):  # Conv1D layer
+                    mlp_dim = layer.mlp.c_fc.nf  # out_features for Conv1D
+                else:
+                    # Regular linear layer: (out_features, in_features)
+                    mlp_dim = layer.mlp.c_fc.weight.shape[0]
+            else:
+                # Fallback to config
+                mlp_dim = getattr(config, "n_inner", hidden_size * 4)
+
+            mlp_dims.append(mlp_dim)
+
+        # Detect weight tying (lm_head ↔ wte)
+        tying_map = {}
+        if hasattr(model, "lm_head") and hasattr(transformer, "wte"):
+            # Check if the weights are the same tensor (tied)
+            if model.lm_head.weight is transformer.wte.weight:
+                tying_map["lm_head.weight"] = "transformer.wte.weight"
+
+        # Build the required description format
+        description = {
+            # Required fields for validation gates
+            "n_layer": n_layers,
+            "heads_per_layer": heads_per_layer,
+            "mlp_dims": mlp_dims,
+            "tying": tying_map,  # Use 'tying' instead of 'weight_tying' as per spec
+            # Additional useful information
+            "model_type": "gpt2",
+            "model_class": model.__class__.__name__,
+            "n_heads": n_heads,
+            "hidden_size": hidden_size,
+            "vocab_size": vocab_size,
+            "total_params": total_params,
+            "device": str(device),
+            # HuggingFace specific info
+            "hf_model_type": getattr(config, "model_type", "gpt2"),
+            "hf_config_class": config.__class__.__name__
+            if hasattr(config, "__class__")
+            else "unknown",
+            # Architecture details
+            "architecture": {
+                "has_lm_head": hasattr(model, "lm_head"),
+                "has_transformer_wrapper": hasattr(model, "transformer"),
+                "layer_norm_type": "pre",  # GPT-2 uses pre-layer norm
+                "activation": getattr(config, "activation_function", "gelu_new"),
+                "positional_encoding": "learned",  # GPT-2 uses learned position embeddings
+                "use_bias": getattr(config, "use_bias", True),
+                "split_size": getattr(config, "split_size", None),
+            },
+        }
+
+        return description
+
+    def _extract_weight_tying_info(self, model: ModuleType | Any) -> dict[str, str]:
+        """
+        Extract weight tying relationships from the model.
+
+        Args:
+            model: The model to analyze
+
+        Returns:
+            Dictionary mapping tied parameter names to their source parameter names
+        """
+        tying_info = {}
+
+        # Check for lm_head ↔ wte tying (most common in GPT-2)
+        if hasattr(model, "lm_head") and hasattr(model, "transformer"):
+            if hasattr(model.transformer, "wte"):
+                if model.lm_head.weight is model.transformer.wte.weight:
+                    tying_info["lm_head.weight"] = "transformer.wte.weight"
+
+        # Could be extended for other tying relationships
+        return tying_info
+
+    def _restore_weight_tying(
+        self, model: ModuleType | Any, tied_param: str, source_param: str
+    ) -> None:
+        """
+        Restore a weight tying relationship between parameters.
+
+        Args:
+            model: The model to modify
+            tied_param: Name of the parameter that should be tied
+            source_param: Name of the source parameter to tie to
+        """
+        # This is a placeholder for weight tying restoration logic
+        # In practice, this would need to handle the specific tying relationships
+        # For now, we just warn about broken tying
+        print(
+            f"Warning: Weight tying relationship {tied_param} -> {source_param} may have been broken during restore"
+        )
+
+    def validate_split_size(self, model: ModuleType | Any) -> bool:
+        """
+        Validate that split_size handling is correct for HuggingFace models.
+
+        Args:
+            model: The model to validate
+
+        Returns:
+            True if split_size is handled correctly
+        """
+        if not hasattr(model, "config"):
+            return True  # No config to validate
+
+        config = model.config
+        split_size = getattr(config, "split_size", None)
+
+        if split_size is None:
+            return True  # No split_size specified
+
+        # Validate that c_attn layers respect split_size
+        try:
+            desc = self.describe(model)
+            if desc["n_layer"] > 0:
+                # Check first layer as representative
+                if hasattr(model, "transformer"):
+                    layer = model.transformer.h[0]
+                else:
+                    layer = model.h[0]
+
+                c_attn = layer.attn.c_attn
+                if hasattr(c_attn, "weight"):
+                    # For Conv1D: weight shape is (in_features, out_features)
+                    # out_features should be 3 * hidden_size for combined Q,K,V
+                    expected_out = 3 * desc["hidden_size"]
+                    actual_out = (
+                        c_attn.weight.shape[1]
+                        if hasattr(c_attn, "nf")
+                        else c_attn.weight.shape[0]
+                    )
+
+                    return actual_out == expected_out
+
+            return True
+
+        except Exception:
+            return False
+
+    def get_layer_modules(
+        self, model: ModuleType | Any, layer_idx: int
+    ) -> dict[str, ModuleType | Any]:
+        """
+        Get the modules for a specific layer (utility method).
+
+        Args:
+            model: The HuggingFace GPT-2 model
+            layer_idx: Index of the layer to get modules for
+
+        Returns:
+            Dictionary mapping module names to modules
+        """
+        if hasattr(model, "transformer"):
+            layer = model.transformer.h[layer_idx]
+        else:
+            layer = model.h[layer_idx]
+
+        modules = {
+            "attn.c_attn": layer.attn.c_attn,  # Combined Q,K,V projection
+            "attn.c_proj": layer.attn.c_proj,  # Output projection
+            "mlp.c_fc": layer.mlp.c_fc,  # Feed-forward expansion
+            "mlp.c_proj": layer.mlp.c_proj,  # Feed-forward projection
+            "ln_1": layer.ln_1,  # Layer norm 1
+            "ln_2": layer.ln_2,  # Layer norm 2
+        }
+
+        return modules