invarlock 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invarlock/__init__.py CHANGED
@@ -12,7 +12,7 @@ For torch-dependent functionality, see subpackages under `invarlock.*`:
  - `invarlock.eval`: Metrics, guard-overhead checks, and certification
  """
 
- __version__ = "0.2.0"
+ __version__ = "0.3.1"
 
  # Core exports - torch-independent
  from .config import CFG, Defaults, get_default_config
@@ -11,5 +11,10 @@ dataset:
    final_n: 120
    stride: 512
 
+ primary_metric:
+   acceptance_range:
+     min: 0.95
+     max: 1.15
+
  context:
    telemetry_profile: "ci_cpu"
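
The new `primary_metric.acceptance_range` block bounds how far the primary metric may drift before a run is rejected. A minimal sketch of how such a gate could be applied, assuming the metric is expressed as a candidate/baseline ratio; the `check_acceptance` helper below is hypothetical and not part of the package:

```python
# Hypothetical gate mirroring the acceptance_range values above.
# Assumes the primary metric is reported as candidate/baseline, so 1.0 means "unchanged".
def check_acceptance(ratio: float, min_ratio: float = 0.95, max_ratio: float = 1.15) -> bool:
    """Return True if the ratio falls inside the configured acceptance range."""
    return min_ratio <= ratio <= max_ratio


assert check_acceptance(1.02)       # within [0.95, 1.15]
assert not check_acceptance(1.30)   # regression beyond the upper bound
```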
@@ -20,6 +20,13 @@ from .base import (
  from .base import (
      PerformanceMetrics as BasePerformanceMetrics,
  )
+ from .capabilities import (
+     ModelCapabilities,
+     QuantizationConfig,
+     QuantizationMethod,
+     detect_capabilities_from_model,
+     detect_quantization_from_config,
+ )
 
  _LAZY_MAP = {
      "HF_BERT_Adapter": ".hf_bert",
@@ -99,4 +106,10 @@ __all__ = [
      "quality_label",
      "_RemovedComponent",
      "INVARLOCK_CORE_ABI",
+     # Capabilities
+     "ModelCapabilities",
+     "QuantizationConfig",
+     "QuantizationMethod",
+     "detect_capabilities_from_model",
+     "detect_quantization_from_config",
  ]
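
With these exports in place, the capability types become part of the package's public surface. A one-line usage sketch, assuming this hunk belongs to the adapters package `__init__` (the extract does not name the file):

```python
# Assumes the exports above land in invarlock.adapters; adjust the path if the layout differs.
from invarlock.adapters import (
    ModelCapabilities,
    QuantizationConfig,
    QuantizationMethod,
    detect_capabilities_from_model,
)
```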
@@ -1,6 +1,8 @@
  from __future__ import annotations
 
  import importlib as _importlib
+ import json
+ from pathlib import Path
  from typing import Any
 
  from invarlock.core.api import ModelAdapter
@@ -8,51 +10,176 @@ from invarlock.core.api import ModelAdapter
  from ..cli.adapter_auto import resolve_auto_adapter
 
 
+ def _detect_quantization_from_path(model_id: str) -> str | None:
+     """
+     Detect quantization method from a local checkpoint path.
+ 
+     Returns:
+         Quantization adapter name ("hf_bnb", "hf_awq", "hf_gptq") or None.
+     """
+     path = Path(model_id)
+     if not path.exists():
+         return None
+ 
+     config_path = path / "config.json"
+     if not config_path.exists():
+         return None
+ 
+     try:
+         config_data = json.loads(config_path.read_text())
+         quant_cfg = config_data.get("quantization_config", {})
+ 
+         if not quant_cfg:
+             return None
+ 
+         quant_method = quant_cfg.get("quant_method", "").lower()
+ 
+         if quant_method == "awq":
+             return "hf_awq"
+         elif quant_method == "gptq":
+             return "hf_gptq"
+         elif (
+             quant_method == "bitsandbytes"
+             or quant_cfg.get("load_in_8bit")
+             or quant_cfg.get("load_in_4bit")
+         ):
+             return "hf_bnb"
+ 
+     except Exception:
+         pass
+ 
+     return None
+ 
+ 
+ def _detect_quantization_from_model(model: Any) -> str | None:
+     """
+     Detect quantization method from a loaded model instance.
+ 
+     Returns:
+         Quantization adapter name ("hf_bnb", "hf_awq", "hf_gptq") or None.
+     """
+     config = getattr(model, "config", None)
+     if config is None:
+         return None
+ 
+     quant_cfg = getattr(config, "quantization_config", None)
+     if quant_cfg is None:
+         # Check for BNB attributes on the model itself
+         if getattr(model, "is_loaded_in_8bit", False) or getattr(
+             model, "is_loaded_in_4bit", False
+         ):
+             return "hf_bnb"
+         return None
+ 
+     # Handle dict-style config
+     if isinstance(quant_cfg, dict):
+         quant_method = quant_cfg.get("quant_method", "").lower()
+         if quant_method == "awq":
+             return "hf_awq"
+         elif quant_method == "gptq":
+             return "hf_gptq"
+         elif (
+             quant_method == "bitsandbytes"
+             or quant_cfg.get("load_in_8bit")
+             or quant_cfg.get("load_in_4bit")
+         ):
+             return "hf_bnb"
+     else:
+         # Object-style config
+         cfg_class = quant_cfg.__class__.__name__
+         if cfg_class in ("AWQConfig",):
+             return "hf_awq"
+         elif cfg_class in ("GPTQConfig",):
+             return "hf_gptq"
+         elif cfg_class in ("BitsAndBytesConfig", "BnbConfig"):
+             return "hf_bnb"
+         # Check attributes
+         if getattr(quant_cfg, "load_in_8bit", False) or getattr(
+             quant_cfg, "load_in_4bit", False
+         ):
+             return "hf_bnb"
+ 
+     return None
+ 
+ 
  class _DelegatingAdapter(ModelAdapter):
      name = "auto_adapter"
 
      def __init__(self) -> None:
          self._delegate: ModelAdapter | None = None
 
-     def _ensure_delegate_from_id(self, model_id: str) -> ModelAdapter:
-         if self._delegate is not None:
-             return self._delegate
-         resolved = resolve_auto_adapter(model_id)
-         if resolved == "hf_llama":
+     def _load_adapter(self, adapter_name: str) -> ModelAdapter:
+         """Load an adapter by name."""
+         if adapter_name == "hf_llama":
              HF_LLaMA_Adapter = _importlib.import_module(
                  ".hf_llama", __package__
              ).HF_LLaMA_Adapter
-             self._delegate = HF_LLaMA_Adapter()
-         elif resolved == "hf_bert":
+             return HF_LLaMA_Adapter()
+         elif adapter_name == "hf_bert":
              HF_BERT_Adapter = _importlib.import_module(
                  ".hf_bert", __package__
              ).HF_BERT_Adapter
-             self._delegate = HF_BERT_Adapter()
+             return HF_BERT_Adapter()
+         elif adapter_name == "hf_gpt2":
+             HF_GPT2_Adapter = _importlib.import_module(
+                 ".hf_gpt2", __package__
+             ).HF_GPT2_Adapter
+             return HF_GPT2_Adapter()
+         elif adapter_name == "hf_bnb":
+             HF_BNB_Adapter = _importlib.import_module(
+                 "invarlock.plugins.hf_bnb_adapter"
+             ).HF_BNB_Adapter
+             return HF_BNB_Adapter()
+         elif adapter_name == "hf_awq":
+             HF_AWQ_Adapter = _importlib.import_module(
+                 "invarlock.plugins.hf_awq_adapter"
+             ).HF_AWQ_Adapter
+             return HF_AWQ_Adapter()
+         elif adapter_name == "hf_gptq":
+             HF_GPTQ_Adapter = _importlib.import_module(
+                 "invarlock.plugins.hf_gptq_adapter"
+             ).HF_GPTQ_Adapter
+             return HF_GPTQ_Adapter()
          else:
+             # Default to GPT2 adapter
              HF_GPT2_Adapter = _importlib.import_module(
                  ".hf_gpt2", __package__
              ).HF_GPT2_Adapter
-             self._delegate = HF_GPT2_Adapter()
+             return HF_GPT2_Adapter()
+ 
+     def _ensure_delegate_from_id(self, model_id: str) -> ModelAdapter:
+         if self._delegate is not None:
+             return self._delegate
+ 
+         # First check for quantization in local checkpoint
+         quant_adapter = _detect_quantization_from_path(model_id)
+         if quant_adapter:
+             self._delegate = self._load_adapter(quant_adapter)
+             return self._delegate
+ 
+         # Fall back to architecture-based resolution
+         resolved = resolve_auto_adapter(model_id)
+         self._delegate = self._load_adapter(resolved)
          return self._delegate
 
      def _ensure_delegate_from_model(self, model: Any) -> ModelAdapter:
-         # Best-effort: inspect class name
+         if self._delegate is not None:
+             return self._delegate
+ 
+         # First check for quantization on the loaded model
+         quant_adapter = _detect_quantization_from_model(model)
+         if quant_adapter:
+             self._delegate = self._load_adapter(quant_adapter)
+             return self._delegate
+ 
+         # Fall back to class name inspection
          cls_name = getattr(model, "__class__", type(model)).__name__.lower()
          if any(k in cls_name for k in ["llama", "mistral", "qwen", "yi"]):
-             HF_LLaMA_Adapter = _importlib.import_module(
-                 ".hf_llama", __package__
-             ).HF_LLaMA_Adapter
-             self._delegate = HF_LLaMA_Adapter()
+             self._delegate = self._load_adapter("hf_llama")
          elif any(k in cls_name for k in ["bert", "roberta", "albert", "deberta"]):
-             HF_BERT_Adapter = _importlib.import_module(
-                 ".hf_bert", __package__
-             ).HF_BERT_Adapter
-             self._delegate = HF_BERT_Adapter()
+             self._delegate = self._load_adapter("hf_bert")
          else:
-             HF_GPT2_Adapter = _importlib.import_module(
-                 ".hf_gpt2", __package__
-             ).HF_GPT2_Adapter
-             self._delegate = HF_GPT2_Adapter()
+             self._delegate = self._load_adapter("hf_gpt2")
          return self._delegate
 
      def can_handle(self, model: Any) -> bool:  # pragma: no cover - trivial
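
The path-based detection above only reads `config.json` and maps `quantization_config.quant_method` onto an adapter name. A short illustrative sketch of that behavior on a pre-quantized checkpoint; the import path of the auto-adapter module is an assumption (the extract does not name the file), and the temporary directory is purely for demonstration:

```python
# Illustrative only: exercise _detect_quantization_from_path on a minimal on-disk
# checkpoint containing just a config.json with an AWQ marker.
import json
import tempfile
from pathlib import Path

from invarlock.adapters.auto import _detect_quantization_from_path  # assumed module path

with tempfile.TemporaryDirectory() as tmp:
    cfg = {"quantization_config": {"quant_method": "awq", "bits": 4}}
    (Path(tmp) / "config.json").write_text(json.dumps(cfg))
    print(_detect_quantization_from_path(tmp))  # -> "hf_awq"
```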
@@ -0,0 +1,421 @@
+ """
+ Model Capabilities
+ ==================
+ 
+ Dataclasses for declaring model capabilities and quantization configuration.
+ Used by adapters to advertise model properties that affect device handling,
+ snapshot/restore behavior, and evaluation strategies.
+ """
+ 
+ from __future__ import annotations
+ 
+ from dataclasses import dataclass, field
+ from enum import Enum
+ from typing import Any
+ 
+ 
+ class QuantizationMethod(Enum):
+     """Supported quantization methods."""
+ 
+     NONE = "none"
+     BNB_8BIT = "bnb_8bit"
+     BNB_4BIT = "bnb_4bit"
+     AWQ = "awq"
+     GPTQ = "gptq"
+     ONNX = "onnx"
+ 
+ 
+ @dataclass(frozen=True)
+ class QuantizationConfig:
+     """
+     Quantization configuration for a loaded model.
+ 
+     Attributes:
+         method: The quantization method used.
+         bits: Bit-width of the quantization (e.g., 4, 8, 16).
+         group_size: Group size for grouped quantization (AWQ/GPTQ).
+         from_checkpoint: True if model was loaded from pre-quantized checkpoint.
+         double_quant: Whether double quantization is enabled (BNB 4-bit).
+         compute_dtype: Data type for computation (e.g., "float16", "bfloat16").
+     """
+ 
+     method: QuantizationMethod = QuantizationMethod.NONE
+     bits: int = 16
+     group_size: int | None = None
+     from_checkpoint: bool = False
+     double_quant: bool = False
+     compute_dtype: str | None = None
+ 
+     def is_quantized(self) -> bool:
+         """Return True if the model is quantized."""
+         return self.method != QuantizationMethod.NONE
+ 
+     def is_bnb(self) -> bool:
+         """Return True if using BitsAndBytes quantization."""
+         return self.method in (QuantizationMethod.BNB_8BIT, QuantizationMethod.BNB_4BIT)
+ 
+ 
+ @dataclass
+ class ModelCapabilities:
+     """
+     Declared capabilities of a loaded model.
+ 
+     Used to inform safe device handling, snapshot/restore strategies,
+     and evaluation metric selection.
+ 
+     Attributes:
+         quantization: Quantization configuration (if any).
+         device_movable: Whether model.to(device) is safe to call.
+             False for BNB models which handle device placement internally.
+         weight_tied: Mapping of tied parameter names to their source.
+             Example: {"lm_head.weight": "model.embed_tokens.weight"}
+         primary_metric_kind: Primary evaluation metric type.
+             Examples: "ppl_causal", "ppl_mlm", "accuracy", "bleu".
+         supports_kv_cache: Whether model supports key-value caching.
+         supports_flash_attention: Whether model supports Flash Attention.
+         max_sequence_length: Maximum supported sequence length.
+         supports_gradient_checkpointing: Whether model supports gradient checkpointing.
+     """
+ 
+     quantization: QuantizationConfig = field(
+         default_factory=lambda: QuantizationConfig()
+     )
+     device_movable: bool = True
+     weight_tied: dict[str, str] = field(default_factory=dict)
+     primary_metric_kind: str = "ppl_causal"
+     supports_kv_cache: bool = True
+     supports_flash_attention: bool = False
+     max_sequence_length: int | None = None
+     supports_gradient_checkpointing: bool = True
+ 
+     @classmethod
+     def for_fp16_model(cls) -> ModelCapabilities:
+         """Create capabilities for a standard FP16 model."""
+         return cls(
+             quantization=QuantizationConfig(method=QuantizationMethod.NONE, bits=16),
+             device_movable=True,
+         )
+ 
+     @classmethod
+     def for_bnb_8bit(cls, from_checkpoint: bool = False) -> ModelCapabilities:
+         """Create capabilities for a BitsAndBytes 8-bit model."""
+         return cls(
+             quantization=QuantizationConfig(
+                 method=QuantizationMethod.BNB_8BIT,
+                 bits=8,
+                 from_checkpoint=from_checkpoint,
+             ),
+             device_movable=False,  # BNB handles device placement
+         )
+ 
+     @classmethod
+     def for_bnb_4bit(
+         cls, from_checkpoint: bool = False, double_quant: bool = False
+     ) -> ModelCapabilities:
+         """Create capabilities for a BitsAndBytes 4-bit model."""
+         return cls(
+             quantization=QuantizationConfig(
+                 method=QuantizationMethod.BNB_4BIT,
+                 bits=4,
+                 from_checkpoint=from_checkpoint,
+                 double_quant=double_quant,
+             ),
+             device_movable=False,  # BNB handles device placement
+         )
+ 
+     @classmethod
+     def for_awq(
+         cls, group_size: int = 128, from_checkpoint: bool = True
+     ) -> ModelCapabilities:
+         """Create capabilities for an AWQ model."""
+         return cls(
+             quantization=QuantizationConfig(
+                 method=QuantizationMethod.AWQ,
+                 bits=4,
+                 group_size=group_size,
+                 from_checkpoint=from_checkpoint,
+             ),
+             device_movable=False,  # AWQ may have device constraints
+         )
+ 
+     @classmethod
+     def for_gptq(
+         cls, bits: int = 4, group_size: int = 128, from_checkpoint: bool = True
+     ) -> ModelCapabilities:
+         """Create capabilities for a GPTQ model."""
+         return cls(
+             quantization=QuantizationConfig(
+                 method=QuantizationMethod.GPTQ,
+                 bits=bits,
+                 group_size=group_size,
+                 from_checkpoint=from_checkpoint,
+             ),
+             device_movable=False,  # GPTQ may have device constraints
+         )
+ 
+ 
+ def detect_quantization_from_config(config: Any) -> QuantizationConfig:
+     """
+     Detect quantization configuration from a HuggingFace model config.
+ 
+     Checks for quantization_config in the model's config and returns
+     the appropriate QuantizationConfig.
+ 
+     Args:
+         config: HuggingFace model config object
+ 
+     Returns:
+         QuantizationConfig describing the model's quantization state
+     """
+     if config is None:
+         return QuantizationConfig()
+ 
+     # Check for quantization_config attribute (BNB, AWQ, GPTQ)
+     quant_cfg = getattr(config, "quantization_config", None)
+     if quant_cfg is None:
+         return QuantizationConfig()
+ 
+     # Handle dict-style config (common in saved checkpoints)
+     if isinstance(quant_cfg, dict):
+         quant_method = quant_cfg.get("quant_method", "").lower()
+         load_in_8bit = quant_cfg.get("load_in_8bit", False)
+         load_in_4bit = quant_cfg.get("load_in_4bit", False)
+         bits = quant_cfg.get("bits", 16)
+         group_size = quant_cfg.get("group_size")
+         double_quant = quant_cfg.get("bnb_4bit_use_double_quant", False)
+         compute_dtype = quant_cfg.get("bnb_4bit_compute_dtype")
+ 
+         if quant_method == "awq":
+             return QuantizationConfig(
+                 method=QuantizationMethod.AWQ,
+                 bits=bits,
+                 group_size=group_size,
+                 from_checkpoint=True,
+             )
+         elif quant_method == "gptq":
+             return QuantizationConfig(
+                 method=QuantizationMethod.GPTQ,
+                 bits=bits,
+                 group_size=group_size,
+                 from_checkpoint=True,
+             )
+         elif load_in_8bit or quant_method == "bitsandbytes" and bits == 8:
+             return QuantizationConfig(
+                 method=QuantizationMethod.BNB_8BIT,
+                 bits=8,
+                 from_checkpoint=True,
+             )
+         elif load_in_4bit or quant_method == "bitsandbytes" and bits == 4:
+             return QuantizationConfig(
+                 method=QuantizationMethod.BNB_4BIT,
+                 bits=4,
+                 from_checkpoint=True,
+                 double_quant=double_quant,
+                 compute_dtype=str(compute_dtype) if compute_dtype else None,
+             )
+ 
+     # Handle object-style config (e.g., BitsAndBytesConfig)
+     # Check by class name to avoid import dependency
+     cfg_class = quant_cfg.__class__.__name__
+ 
+     if cfg_class in ("BitsAndBytesConfig", "BnbConfig"):
+         load_in_8bit = getattr(quant_cfg, "load_in_8bit", False)
+         load_in_4bit = getattr(quant_cfg, "load_in_4bit", False)
+         double_quant = getattr(quant_cfg, "bnb_4bit_use_double_quant", False)
+         compute_dtype = getattr(quant_cfg, "bnb_4bit_compute_dtype", None)
+ 
+         if load_in_8bit:
+             return QuantizationConfig(
+                 method=QuantizationMethod.BNB_8BIT,
+                 bits=8,
+                 from_checkpoint=True,
+             )
+         elif load_in_4bit:
+             return QuantizationConfig(
+                 method=QuantizationMethod.BNB_4BIT,
+                 bits=4,
+                 from_checkpoint=True,
+                 double_quant=double_quant,
+                 compute_dtype=str(compute_dtype) if compute_dtype else None,
+             )
+ 
+     if cfg_class in ("AWQConfig",):
+         bits = getattr(quant_cfg, "bits", 4)
+         group_size = getattr(quant_cfg, "group_size", 128)
+         return QuantizationConfig(
+             method=QuantizationMethod.AWQ,
+             bits=bits,
+             group_size=group_size,
+             from_checkpoint=True,
+         )
+ 
+     if cfg_class in ("GPTQConfig",):
+         bits = getattr(quant_cfg, "bits", 4)
+         group_size = getattr(quant_cfg, "group_size", 128)
+         return QuantizationConfig(
+             method=QuantizationMethod.GPTQ,
+             bits=bits,
+             group_size=group_size,
+             from_checkpoint=True,
+         )
+ 
+     return QuantizationConfig()
+ 
+ 
+ def detect_capabilities_from_model(model: Any) -> ModelCapabilities:
+     """
+     Detect model capabilities from a loaded model instance.
+ 
+     Inspects the model's config, state, and structure to determine
+     its capabilities including quantization state.
+ 
+     Args:
+         model: Loaded model instance (typically HuggingFace PreTrainedModel)
+ 
+     Returns:
+         ModelCapabilities describing the model's capabilities
+     """
+     config = getattr(model, "config", None)
+     quant_config = detect_quantization_from_config(config)
+ 
+     # Check for BNB attributes on the model itself (may not be in config)
+     # Transformers sets these flags on loaded BNB models even if config.quantization_config
+     # doesn't reflect the quantization state (e.g., for saved BNB checkpoints)
+     # Note: We check `is True` explicitly to avoid MagicMock truthiness
+     if not quant_config.is_quantized():
+         is_8bit = getattr(model, "is_loaded_in_8bit", None)
+         is_4bit = getattr(model, "is_loaded_in_4bit", None)
+         if is_8bit is True:
+             quant_config = QuantizationConfig(
+                 method=QuantizationMethod.BNB_8BIT,
+                 bits=8,
+                 from_checkpoint=True,
+             )
+         elif is_4bit is True:
+             quant_config = QuantizationConfig(
+                 method=QuantizationMethod.BNB_4BIT,
+                 bits=4,
+                 from_checkpoint=True,
+             )
+ 
+     # Also check for quantized module types that indicate BNB usage
+     # Only attempt this if model has a callable modules() method (torch.nn.Module)
+     if not quant_config.is_quantized():
+         modules_method = getattr(model, "modules", None)
+         if callable(modules_method):
+             try:
+                 for module in modules_method():
+                     module_name = module.__class__.__name__
+                     if module_name in ("Linear8bitLt", "Linear4bit"):
+                         if "8bit" in module_name:
+                             quant_config = QuantizationConfig(
+                                 method=QuantizationMethod.BNB_8BIT,
+                                 bits=8,
+                                 from_checkpoint=True,
+                             )
+                         else:
+                             quant_config = QuantizationConfig(
+                                 method=QuantizationMethod.BNB_4BIT,
+                                 bits=4,
+                                 from_checkpoint=True,
+                             )
+                         break
+             except (TypeError, StopIteration):
+                 pass
+ 
+     # Determine if device is movable
+     device_movable = not quant_config.is_bnb()
+ 
+     # For AWQ/GPTQ, check if model has been quantized in a way that
+     # prevents device movement
+     if quant_config.method in (QuantizationMethod.AWQ, QuantizationMethod.GPTQ):
+         # These are typically loaded on-device and shouldn't be moved
+         device_movable = False
+ 
+     # Detect weight tying
+     weight_tied = _detect_weight_tying(model)
+ 
+     # Detect primary metric kind
+     primary_metric = _detect_primary_metric(model)
+ 
+     # Detect other capabilities
+     max_seq_len = getattr(config, "max_position_embeddings", None)
+     supports_flash = (
+         getattr(config, "_attn_implementation", None) == "flash_attention_2"
+     )
+ 
+     return ModelCapabilities(
+         quantization=quant_config,
+         device_movable=device_movable,
+         weight_tied=weight_tied,
+         primary_metric_kind=primary_metric,
+         max_sequence_length=max_seq_len,
+         supports_flash_attention=supports_flash,
+     )
+ 
+ 
+ def _detect_weight_tying(model: Any) -> dict[str, str]:
+     """Detect weight tying relationships in the model."""
+     tying: dict[str, str] = {}
+ 
+     # Common weight tying patterns
+     # LLaMA/Mistral: lm_head.weight ↔ model.embed_tokens.weight
+     if hasattr(model, "lm_head") and hasattr(model, "model"):
+         inner = model.model
+         if hasattr(inner, "embed_tokens"):
+             lm_head_weight = getattr(model.lm_head, "weight", None)
+             embed_weight = getattr(inner.embed_tokens, "weight", None)
+             if lm_head_weight is not None and embed_weight is not None:
+                 if lm_head_weight is embed_weight:
+                     tying["lm_head.weight"] = "model.embed_tokens.weight"
+ 
+     # GPT-2: lm_head.weight ↔ transformer.wte.weight
+     if hasattr(model, "lm_head") and hasattr(model, "transformer"):
+         xformer = model.transformer
+         if hasattr(xformer, "wte"):
+             lm_head_weight = getattr(model.lm_head, "weight", None)
+             wte_weight = getattr(xformer.wte, "weight", None)
+             if lm_head_weight is not None and wte_weight is not None:
+                 if lm_head_weight is wte_weight:
+                     tying["lm_head.weight"] = "transformer.wte.weight"
+ 
+     return tying
+ 
+ 
+ def _detect_primary_metric(model: Any) -> str:
+     """Detect the primary evaluation metric type for this model."""
+     config = getattr(model, "config", None)
+     if config is None:
+         return "ppl_causal"
+ 
+     model_type = getattr(config, "model_type", "").lower()
+     architectures = getattr(config, "architectures", []) or []
+     arch_str = " ".join(architectures).lower()
+ 
+     # Encoder-only models (BERT-like)
+     if any(k in model_type for k in ["bert", "roberta", "albert", "deberta"]):
+         if "masked" in arch_str or "mlm" in arch_str:
+             return "ppl_mlm"
+         if "classification" in arch_str or "sequence" in arch_str:
+             return "accuracy"
+         return "ppl_mlm"
+ 
+     # Encoder-decoder models (T5-like)
+     if any(k in model_type for k in ["t5", "bart", "marian", "pegasus"]):
+         if "translation" in arch_str or "mt" in arch_str:
+             return "bleu"
+         if "summarization" in arch_str:
+             return "rouge"
+         return "ppl_seq2seq"
+ 
+     # Decoder-only models (GPT-like, LLaMA-like)
+     return "ppl_causal"
+ 
+ 
+ __all__ = [
+     "QuantizationMethod",
+     "QuantizationConfig",
+     "ModelCapabilities",
+     "detect_quantization_from_config",
+     "detect_capabilities_from_model",
+ ]
@@ -69,8 +69,8 @@ class HF_LLaMA_Adapter(HFAdapterMixin, ModelAdapter):
          ):
              model = AutoModelForCausalLM.from_pretrained(model_id)
 
-         target_device = self._resolve_device(device)
-         return model.to(target_device)
+         # Use safe device movement that respects quantization constraints
+         return self._safe_to_device(model, device)
 
      def can_handle(self, model: ModuleType | Any) -> bool:
          """