invarlock 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- invarlock/__init__.py +1 -1
- invarlock/_data/runtime/profiles/ci_cpu.yaml +5 -0
- invarlock/adapters/__init__.py +13 -0
- invarlock/adapters/auto.py +149 -22
- invarlock/adapters/capabilities.py +421 -0
- invarlock/adapters/hf_llama.py +2 -2
- invarlock/adapters/hf_mixin.py +122 -1
- invarlock/cli/commands/doctor.py +7 -1
- invarlock/cli/commands/run.py +148 -2
- invarlock/core/registry.py +34 -6
- invarlock/guards/variance.py +41 -6
- invarlock/plugins/hf_awq_adapter.py +22 -1
- invarlock/plugins/hf_bnb_adapter.py +117 -22
- invarlock/plugins/hf_gptq_adapter.py +24 -1
- invarlock/reporting/certificate.py +155 -15
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/METADATA +2 -2
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/RECORD +21 -20
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/WHEEL +0 -0
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/entry_points.txt +0 -0
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/top_level.txt +0 -0
invarlock/__init__.py
CHANGED
@@ -12,7 +12,7 @@ For torch-dependent functionality, see subpackages under `invarlock.*`:
 - `invarlock.eval`: Metrics, guard-overhead checks, and certification
 """
 
-__version__ = "0.2.0"
+__version__ = "0.3.1"
 
 # Core exports - torch-independent
 from .config import CFG, Defaults, get_default_config
invarlock/adapters/__init__.py
CHANGED
@@ -20,6 +20,13 @@ from .base import (
 from .base import (
     PerformanceMetrics as BasePerformanceMetrics,
 )
+from .capabilities import (
+    ModelCapabilities,
+    QuantizationConfig,
+    QuantizationMethod,
+    detect_capabilities_from_model,
+    detect_quantization_from_config,
+)
 
 _LAZY_MAP = {
     "HF_BERT_Adapter": ".hf_bert",
@@ -99,4 +106,10 @@ __all__ = [
     "quality_label",
     "_RemovedComponent",
    "INVARLOCK_CORE_ABI",
+    # Capabilities
+    "ModelCapabilities",
+    "QuantizationConfig",
+    "QuantizationMethod",
+    "detect_capabilities_from_model",
+    "detect_quantization_from_config",
 ]
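The capability types introduced in this release are re-exported from `invarlock.adapters`, so callers no longer need to reach into the `capabilities` submodule. A minimal sketch of the new import surface, assuming the 0.3.1 wheel is installed (the values below are illustrative, not package fixtures):

```python
from invarlock.adapters import ModelCapabilities, QuantizationConfig, QuantizationMethod

# Hand-built declaration for a pre-quantized 4-bit BitsAndBytes checkpoint.
caps = ModelCapabilities.for_bnb_4bit(from_checkpoint=True, double_quant=True)
assert caps.quantization.method is QuantizationMethod.BNB_4BIT
assert caps.device_movable is False  # BNB models manage their own device placement

# The default QuantizationConfig describes an unquantized 16-bit model.
default = QuantizationConfig()
assert not default.is_quantized() and default.bits == 16
```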
invarlock/adapters/auto.py
CHANGED
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import importlib as _importlib
+import json
+from pathlib import Path
 from typing import Any
 
 from invarlock.core.api import ModelAdapter
@@ -8,51 +10,176 @@ from invarlock.core.api import ModelAdapter
 from ..cli.adapter_auto import resolve_auto_adapter
 
 
+def _detect_quantization_from_path(model_id: str) -> str | None:
+    """
+    Detect quantization method from a local checkpoint path.
+
+    Returns:
+        Quantization adapter name ("hf_bnb", "hf_awq", "hf_gptq") or None.
+    """
+    path = Path(model_id)
+    if not path.exists():
+        return None
+
+    config_path = path / "config.json"
+    if not config_path.exists():
+        return None
+
+    try:
+        config_data = json.loads(config_path.read_text())
+        quant_cfg = config_data.get("quantization_config", {})
+
+        if not quant_cfg:
+            return None
+
+        quant_method = quant_cfg.get("quant_method", "").lower()
+
+        if quant_method == "awq":
+            return "hf_awq"
+        elif quant_method == "gptq":
+            return "hf_gptq"
+        elif (
+            quant_method == "bitsandbytes"
+            or quant_cfg.get("load_in_8bit")
+            or quant_cfg.get("load_in_4bit")
+        ):
+            return "hf_bnb"
+
+    except Exception:
+        pass
+
+    return None
+
+
+def _detect_quantization_from_model(model: Any) -> str | None:
+    """
+    Detect quantization method from a loaded model instance.
+
+    Returns:
+        Quantization adapter name ("hf_bnb", "hf_awq", "hf_gptq") or None.
+    """
+    config = getattr(model, "config", None)
+    if config is None:
+        return None
+
+    quant_cfg = getattr(config, "quantization_config", None)
+    if quant_cfg is None:
+        # Check for BNB attributes on the model itself
+        if getattr(model, "is_loaded_in_8bit", False) or getattr(
+            model, "is_loaded_in_4bit", False
+        ):
+            return "hf_bnb"
+        return None
+
+    # Handle dict-style config
+    if isinstance(quant_cfg, dict):
+        quant_method = quant_cfg.get("quant_method", "").lower()
+        if quant_method == "awq":
+            return "hf_awq"
+        elif quant_method == "gptq":
+            return "hf_gptq"
+        elif (
+            quant_method == "bitsandbytes"
+            or quant_cfg.get("load_in_8bit")
+            or quant_cfg.get("load_in_4bit")
+        ):
+            return "hf_bnb"
+    else:
+        # Object-style config
+        cfg_class = quant_cfg.__class__.__name__
+        if cfg_class in ("AWQConfig",):
+            return "hf_awq"
+        elif cfg_class in ("GPTQConfig",):
+            return "hf_gptq"
+        elif cfg_class in ("BitsAndBytesConfig", "BnbConfig"):
+            return "hf_bnb"
+        # Check attributes
+        if getattr(quant_cfg, "load_in_8bit", False) or getattr(
+            quant_cfg, "load_in_4bit", False
+        ):
+            return "hf_bnb"
+
+    return None
+
+
 class _DelegatingAdapter(ModelAdapter):
     name = "auto_adapter"
 
     def __init__(self) -> None:
         self._delegate: ModelAdapter | None = None
 
-    def _ensure_delegate_from_id(self, model_id: str) -> ModelAdapter:
-        if self._delegate is not None:
-            return self._delegate
-        resolved = resolve_auto_adapter(model_id)
-        if resolved == "hf_llama":
+    def _load_adapter(self, adapter_name: str) -> ModelAdapter:
+        """Load an adapter by name."""
+        if adapter_name == "hf_llama":
             HF_LLaMA_Adapter = _importlib.import_module(
                 ".hf_llama", __package__
             ).HF_LLaMA_Adapter
-            self._delegate = HF_LLaMA_Adapter()
-        elif resolved == "hf_bert":
+            return HF_LLaMA_Adapter()
+        elif adapter_name == "hf_bert":
             HF_BERT_Adapter = _importlib.import_module(
                 ".hf_bert", __package__
             ).HF_BERT_Adapter
-            self._delegate = HF_BERT_Adapter()
+            return HF_BERT_Adapter()
+        elif adapter_name == "hf_gpt2":
+            HF_GPT2_Adapter = _importlib.import_module(
+                ".hf_gpt2", __package__
+            ).HF_GPT2_Adapter
+            return HF_GPT2_Adapter()
+        elif adapter_name == "hf_bnb":
+            HF_BNB_Adapter = _importlib.import_module(
+                "invarlock.plugins.hf_bnb_adapter"
+            ).HF_BNB_Adapter
+            return HF_BNB_Adapter()
+        elif adapter_name == "hf_awq":
+            HF_AWQ_Adapter = _importlib.import_module(
+                "invarlock.plugins.hf_awq_adapter"
+            ).HF_AWQ_Adapter
+            return HF_AWQ_Adapter()
+        elif adapter_name == "hf_gptq":
+            HF_GPTQ_Adapter = _importlib.import_module(
+                "invarlock.plugins.hf_gptq_adapter"
+            ).HF_GPTQ_Adapter
+            return HF_GPTQ_Adapter()
         else:
+            # Default to GPT2 adapter
             HF_GPT2_Adapter = _importlib.import_module(
                 ".hf_gpt2", __package__
             ).HF_GPT2_Adapter
-            self._delegate = HF_GPT2_Adapter()
+            return HF_GPT2_Adapter()
+
+    def _ensure_delegate_from_id(self, model_id: str) -> ModelAdapter:
+        if self._delegate is not None:
+            return self._delegate
+
+        # First check for quantization in local checkpoint
+        quant_adapter = _detect_quantization_from_path(model_id)
+        if quant_adapter:
+            self._delegate = self._load_adapter(quant_adapter)
+            return self._delegate
+
+        # Fall back to architecture-based resolution
+        resolved = resolve_auto_adapter(model_id)
+        self._delegate = self._load_adapter(resolved)
         return self._delegate
 
     def _ensure_delegate_from_model(self, model: Any) -> ModelAdapter:
-
+        if self._delegate is not None:
+            return self._delegate
+
+        # First check for quantization on the loaded model
+        quant_adapter = _detect_quantization_from_model(model)
+        if quant_adapter:
+            self._delegate = self._load_adapter(quant_adapter)
+            return self._delegate
+
+        # Fall back to class name inspection
         cls_name = getattr(model, "__class__", type(model)).__name__.lower()
         if any(k in cls_name for k in ["llama", "mistral", "qwen", "yi"]):
-            HF_LLaMA_Adapter = _importlib.import_module(
-                ".hf_llama", __package__
-            ).HF_LLaMA_Adapter
-            self._delegate = HF_LLaMA_Adapter()
+            self._delegate = self._load_adapter("hf_llama")
        elif any(k in cls_name for k in ["bert", "roberta", "albert", "deberta"]):
-            HF_BERT_Adapter = _importlib.import_module(
-                ".hf_bert", __package__
-            ).HF_BERT_Adapter
-            self._delegate = HF_BERT_Adapter()
+            self._delegate = self._load_adapter("hf_bert")
         else:
-            HF_GPT2_Adapter = _importlib.import_module(
-                ".hf_gpt2", __package__
-            ).HF_GPT2_Adapter
-            self._delegate = HF_GPT2_Adapter()
+            self._delegate = self._load_adapter("hf_gpt2")
         return self._delegate
 
     def can_handle(self, model: Any) -> bool:  # pragma: no cover - trivial
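With these changes, adapter auto-resolution first looks for a `quantization_config` block in a local checkpoint's `config.json` and routes to the matching quantized adapter (`hf_awq`, `hf_gptq`, or `hf_bnb`); only when nothing is found does it fall back to `resolve_auto_adapter`. A rough sketch of that routing, exercising the private `_detect_quantization_from_path` helper against a throwaway temporary directory (the checkpoint contents are made up for illustration):

```python
import json
import tempfile
from pathlib import Path

from invarlock.adapters.auto import _detect_quantization_from_path

with tempfile.TemporaryDirectory() as tmp:
    # Minimal stand-in for a pre-quantized AWQ checkpoint's config.json.
    (Path(tmp) / "config.json").write_text(
        json.dumps({"quantization_config": {"quant_method": "awq", "bits": 4}})
    )
    assert _detect_quantization_from_path(tmp) == "hf_awq"

# A path with no config.json (or no quantization_config) yields None, so
# resolution falls back to architecture-based resolve_auto_adapter(model_id).
assert _detect_quantization_from_path("/nonexistent/path") is None
```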
invarlock/adapters/capabilities.py
ADDED
@@ -0,0 +1,421 @@
+"""
+Model Capabilities
+==================
+
+Dataclasses for declaring model capabilities and quantization configuration.
+Used by adapters to advertise model properties that affect device handling,
+snapshot/restore behavior, and evaluation strategies.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+
+class QuantizationMethod(Enum):
+    """Supported quantization methods."""
+
+    NONE = "none"
+    BNB_8BIT = "bnb_8bit"
+    BNB_4BIT = "bnb_4bit"
+    AWQ = "awq"
+    GPTQ = "gptq"
+    ONNX = "onnx"
+
+
+@dataclass(frozen=True)
+class QuantizationConfig:
+    """
+    Quantization configuration for a loaded model.
+
+    Attributes:
+        method: The quantization method used.
+        bits: Bit-width of the quantization (e.g., 4, 8, 16).
+        group_size: Group size for grouped quantization (AWQ/GPTQ).
+        from_checkpoint: True if model was loaded from pre-quantized checkpoint.
+        double_quant: Whether double quantization is enabled (BNB 4-bit).
+        compute_dtype: Data type for computation (e.g., "float16", "bfloat16").
+    """
+
+    method: QuantizationMethod = QuantizationMethod.NONE
+    bits: int = 16
+    group_size: int | None = None
+    from_checkpoint: bool = False
+    double_quant: bool = False
+    compute_dtype: str | None = None
+
+    def is_quantized(self) -> bool:
+        """Return True if the model is quantized."""
+        return self.method != QuantizationMethod.NONE
+
+    def is_bnb(self) -> bool:
+        """Return True if using BitsAndBytes quantization."""
+        return self.method in (QuantizationMethod.BNB_8BIT, QuantizationMethod.BNB_4BIT)
+
+
+@dataclass
+class ModelCapabilities:
+    """
+    Declared capabilities of a loaded model.
+
+    Used to inform safe device handling, snapshot/restore strategies,
+    and evaluation metric selection.
+
+    Attributes:
+        quantization: Quantization configuration (if any).
+        device_movable: Whether model.to(device) is safe to call.
+            False for BNB models which handle device placement internally.
+        weight_tied: Mapping of tied parameter names to their source.
+            Example: {"lm_head.weight": "model.embed_tokens.weight"}
+        primary_metric_kind: Primary evaluation metric type.
+            Examples: "ppl_causal", "ppl_mlm", "accuracy", "bleu".
+        supports_kv_cache: Whether model supports key-value caching.
+        supports_flash_attention: Whether model supports Flash Attention.
+        max_sequence_length: Maximum supported sequence length.
+        supports_gradient_checkpointing: Whether model supports gradient checkpointing.
+    """
+
+    quantization: QuantizationConfig = field(
+        default_factory=lambda: QuantizationConfig()
+    )
+    device_movable: bool = True
+    weight_tied: dict[str, str] = field(default_factory=dict)
+    primary_metric_kind: str = "ppl_causal"
+    supports_kv_cache: bool = True
+    supports_flash_attention: bool = False
+    max_sequence_length: int | None = None
+    supports_gradient_checkpointing: bool = True
+
+    @classmethod
+    def for_fp16_model(cls) -> ModelCapabilities:
+        """Create capabilities for a standard FP16 model."""
+        return cls(
+            quantization=QuantizationConfig(method=QuantizationMethod.NONE, bits=16),
+            device_movable=True,
+        )
+
+    @classmethod
+    def for_bnb_8bit(cls, from_checkpoint: bool = False) -> ModelCapabilities:
+        """Create capabilities for a BitsAndBytes 8-bit model."""
+        return cls(
+            quantization=QuantizationConfig(
+                method=QuantizationMethod.BNB_8BIT,
+                bits=8,
+                from_checkpoint=from_checkpoint,
+            ),
+            device_movable=False,  # BNB handles device placement
+        )
+
+    @classmethod
+    def for_bnb_4bit(
+        cls, from_checkpoint: bool = False, double_quant: bool = False
+    ) -> ModelCapabilities:
+        """Create capabilities for a BitsAndBytes 4-bit model."""
+        return cls(
+            quantization=QuantizationConfig(
+                method=QuantizationMethod.BNB_4BIT,
+                bits=4,
+                from_checkpoint=from_checkpoint,
+                double_quant=double_quant,
+            ),
+            device_movable=False,  # BNB handles device placement
+        )
+
+    @classmethod
+    def for_awq(
+        cls, group_size: int = 128, from_checkpoint: bool = True
+    ) -> ModelCapabilities:
+        """Create capabilities for an AWQ model."""
+        return cls(
+            quantization=QuantizationConfig(
+                method=QuantizationMethod.AWQ,
+                bits=4,
+                group_size=group_size,
+                from_checkpoint=from_checkpoint,
+            ),
+            device_movable=False,  # AWQ may have device constraints
+        )
+
+    @classmethod
+    def for_gptq(
+        cls, bits: int = 4, group_size: int = 128, from_checkpoint: bool = True
+    ) -> ModelCapabilities:
+        """Create capabilities for a GPTQ model."""
+        return cls(
+            quantization=QuantizationConfig(
+                method=QuantizationMethod.GPTQ,
+                bits=bits,
+                group_size=group_size,
+                from_checkpoint=from_checkpoint,
+            ),
+            device_movable=False,  # GPTQ may have device constraints
+        )
+
+
+def detect_quantization_from_config(config: Any) -> QuantizationConfig:
+    """
+    Detect quantization configuration from a HuggingFace model config.
+
+    Checks for quantization_config in the model's config and returns
+    the appropriate QuantizationConfig.
+
+    Args:
+        config: HuggingFace model config object
+
+    Returns:
+        QuantizationConfig describing the model's quantization state
+    """
+    if config is None:
+        return QuantizationConfig()
+
+    # Check for quantization_config attribute (BNB, AWQ, GPTQ)
+    quant_cfg = getattr(config, "quantization_config", None)
+    if quant_cfg is None:
+        return QuantizationConfig()
+
+    # Handle dict-style config (common in saved checkpoints)
+    if isinstance(quant_cfg, dict):
+        quant_method = quant_cfg.get("quant_method", "").lower()
+        load_in_8bit = quant_cfg.get("load_in_8bit", False)
+        load_in_4bit = quant_cfg.get("load_in_4bit", False)
+        bits = quant_cfg.get("bits", 16)
+        group_size = quant_cfg.get("group_size")
+        double_quant = quant_cfg.get("bnb_4bit_use_double_quant", False)
+        compute_dtype = quant_cfg.get("bnb_4bit_compute_dtype")
+
+        if quant_method == "awq":
+            return QuantizationConfig(
+                method=QuantizationMethod.AWQ,
+                bits=bits,
+                group_size=group_size,
+                from_checkpoint=True,
+            )
+        elif quant_method == "gptq":
+            return QuantizationConfig(
+                method=QuantizationMethod.GPTQ,
+                bits=bits,
+                group_size=group_size,
+                from_checkpoint=True,
+            )
+        elif load_in_8bit or quant_method == "bitsandbytes" and bits == 8:
+            return QuantizationConfig(
+                method=QuantizationMethod.BNB_8BIT,
+                bits=8,
+                from_checkpoint=True,
+            )
+        elif load_in_4bit or quant_method == "bitsandbytes" and bits == 4:
+            return QuantizationConfig(
+                method=QuantizationMethod.BNB_4BIT,
+                bits=4,
+                from_checkpoint=True,
+                double_quant=double_quant,
+                compute_dtype=str(compute_dtype) if compute_dtype else None,
+            )
+
+    # Handle object-style config (e.g., BitsAndBytesConfig)
+    # Check by class name to avoid import dependency
+    cfg_class = quant_cfg.__class__.__name__
+
+    if cfg_class in ("BitsAndBytesConfig", "BnbConfig"):
+        load_in_8bit = getattr(quant_cfg, "load_in_8bit", False)
+        load_in_4bit = getattr(quant_cfg, "load_in_4bit", False)
+        double_quant = getattr(quant_cfg, "bnb_4bit_use_double_quant", False)
+        compute_dtype = getattr(quant_cfg, "bnb_4bit_compute_dtype", None)
+
+        if load_in_8bit:
+            return QuantizationConfig(
+                method=QuantizationMethod.BNB_8BIT,
+                bits=8,
+                from_checkpoint=True,
+            )
+        elif load_in_4bit:
+            return QuantizationConfig(
+                method=QuantizationMethod.BNB_4BIT,
+                bits=4,
+                from_checkpoint=True,
+                double_quant=double_quant,
+                compute_dtype=str(compute_dtype) if compute_dtype else None,
+            )
+
+    if cfg_class in ("AWQConfig",):
+        bits = getattr(quant_cfg, "bits", 4)
+        group_size = getattr(quant_cfg, "group_size", 128)
+        return QuantizationConfig(
+            method=QuantizationMethod.AWQ,
+            bits=bits,
+            group_size=group_size,
+            from_checkpoint=True,
+        )
+
+    if cfg_class in ("GPTQConfig",):
+        bits = getattr(quant_cfg, "bits", 4)
+        group_size = getattr(quant_cfg, "group_size", 128)
+        return QuantizationConfig(
+            method=QuantizationMethod.GPTQ,
+            bits=bits,
+            group_size=group_size,
+            from_checkpoint=True,
+        )
+
+    return QuantizationConfig()
+
+
+def detect_capabilities_from_model(model: Any) -> ModelCapabilities:
+    """
+    Detect model capabilities from a loaded model instance.
+
+    Inspects the model's config, state, and structure to determine
+    its capabilities including quantization state.
+
+    Args:
+        model: Loaded model instance (typically HuggingFace PreTrainedModel)
+
+    Returns:
+        ModelCapabilities describing the model's capabilities
+    """
+    config = getattr(model, "config", None)
+    quant_config = detect_quantization_from_config(config)
+
+    # Check for BNB attributes on the model itself (may not be in config)
+    # Transformers sets these flags on loaded BNB models even if config.quantization_config
+    # doesn't reflect the quantization state (e.g., for saved BNB checkpoints)
+    # Note: We check `is True` explicitly to avoid MagicMock truthiness
+    if not quant_config.is_quantized():
+        is_8bit = getattr(model, "is_loaded_in_8bit", None)
+        is_4bit = getattr(model, "is_loaded_in_4bit", None)
+        if is_8bit is True:
+            quant_config = QuantizationConfig(
+                method=QuantizationMethod.BNB_8BIT,
+                bits=8,
+                from_checkpoint=True,
+            )
+        elif is_4bit is True:
+            quant_config = QuantizationConfig(
+                method=QuantizationMethod.BNB_4BIT,
+                bits=4,
+                from_checkpoint=True,
+            )
+
+    # Also check for quantized module types that indicate BNB usage
+    # Only attempt this if model has a callable modules() method (torch.nn.Module)
+    if not quant_config.is_quantized():
+        modules_method = getattr(model, "modules", None)
+        if callable(modules_method):
+            try:
+                for module in modules_method():
+                    module_name = module.__class__.__name__
+                    if module_name in ("Linear8bitLt", "Linear4bit"):
+                        if "8bit" in module_name:
+                            quant_config = QuantizationConfig(
+                                method=QuantizationMethod.BNB_8BIT,
+                                bits=8,
+                                from_checkpoint=True,
+                            )
+                        else:
+                            quant_config = QuantizationConfig(
+                                method=QuantizationMethod.BNB_4BIT,
+                                bits=4,
+                                from_checkpoint=True,
+                            )
+                        break
+            except (TypeError, StopIteration):
+                pass
+
+    # Determine if device is movable
+    device_movable = not quant_config.is_bnb()
+
+    # For AWQ/GPTQ, check if model has been quantized in a way that
+    # prevents device movement
+    if quant_config.method in (QuantizationMethod.AWQ, QuantizationMethod.GPTQ):
+        # These are typically loaded on-device and shouldn't be moved
+        device_movable = False
+
+    # Detect weight tying
+    weight_tied = _detect_weight_tying(model)
+
+    # Detect primary metric kind
+    primary_metric = _detect_primary_metric(model)
+
+    # Detect other capabilities
+    max_seq_len = getattr(config, "max_position_embeddings", None)
+    supports_flash = (
+        getattr(config, "_attn_implementation", None) == "flash_attention_2"
+    )
+
+    return ModelCapabilities(
+        quantization=quant_config,
+        device_movable=device_movable,
+        weight_tied=weight_tied,
+        primary_metric_kind=primary_metric,
+        max_sequence_length=max_seq_len,
+        supports_flash_attention=supports_flash,
+    )
+
+
+def _detect_weight_tying(model: Any) -> dict[str, str]:
+    """Detect weight tying relationships in the model."""
+    tying: dict[str, str] = {}
+
+    # Common weight tying patterns
+    # LLaMA/Mistral: lm_head.weight ↔ model.embed_tokens.weight
+    if hasattr(model, "lm_head") and hasattr(model, "model"):
+        inner = model.model
+        if hasattr(inner, "embed_tokens"):
+            lm_head_weight = getattr(model.lm_head, "weight", None)
+            embed_weight = getattr(inner.embed_tokens, "weight", None)
+            if lm_head_weight is not None and embed_weight is not None:
+                if lm_head_weight is embed_weight:
+                    tying["lm_head.weight"] = "model.embed_tokens.weight"
+
+    # GPT-2: lm_head.weight ↔ transformer.wte.weight
+    if hasattr(model, "lm_head") and hasattr(model, "transformer"):
+        xformer = model.transformer
+        if hasattr(xformer, "wte"):
+            lm_head_weight = getattr(model.lm_head, "weight", None)
+            wte_weight = getattr(xformer.wte, "weight", None)
+            if lm_head_weight is not None and wte_weight is not None:
+                if lm_head_weight is wte_weight:
+                    tying["lm_head.weight"] = "transformer.wte.weight"
+
+    return tying
+
+
+def _detect_primary_metric(model: Any) -> str:
+    """Detect the primary evaluation metric type for this model."""
+    config = getattr(model, "config", None)
+    if config is None:
+        return "ppl_causal"
+
+    model_type = getattr(config, "model_type", "").lower()
+    architectures = getattr(config, "architectures", []) or []
+    arch_str = " ".join(architectures).lower()
+
+    # Encoder-only models (BERT-like)
+    if any(k in model_type for k in ["bert", "roberta", "albert", "deberta"]):
+        if "masked" in arch_str or "mlm" in arch_str:
+            return "ppl_mlm"
+        if "classification" in arch_str or "sequence" in arch_str:
+            return "accuracy"
+        return "ppl_mlm"
+
+    # Encoder-decoder models (T5-like)
+    if any(k in model_type for k in ["t5", "bart", "marian", "pegasus"]):
+        if "translation" in arch_str or "mt" in arch_str:
+            return "bleu"
+        if "summarization" in arch_str:
+            return "rouge"
+        return "ppl_seq2seq"
+
+    # Decoder-only models (GPT-like, LLaMA-like)
+    return "ppl_causal"
+
+
+__all__ = [
+    "QuantizationMethod",
+    "QuantizationConfig",
+    "ModelCapabilities",
+    "detect_quantization_from_config",
+    "detect_capabilities_from_model",
+]
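The new module centralizes quantization detection behind one declarative API. A short sketch of how `detect_quantization_from_config` reads a dict-style `quantization_config`, using a stand-in config object (the field values are made up for illustration):

```python
from types import SimpleNamespace

from invarlock.adapters.capabilities import (
    QuantizationMethod,
    detect_quantization_from_config,
)

# Dict-style quantization_config, as typically found in a saved GPTQ checkpoint.
fake_config = SimpleNamespace(
    quantization_config={"quant_method": "gptq", "bits": 4, "group_size": 128}
)
qcfg = detect_quantization_from_config(fake_config)
assert qcfg.method is QuantizationMethod.GPTQ
assert qcfg.bits == 4 and qcfg.group_size == 128 and qcfg.from_checkpoint

# Configs without quantization_config fall back to the unquantized default.
plain = detect_quantization_from_config(SimpleNamespace())
assert not plain.is_quantized()
```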
invarlock/adapters/hf_llama.py
CHANGED
@@ -69,8 +69,8 @@ class HF_LLaMA_Adapter(HFAdapterMixin, ModelAdapter):
         ):
             model = AutoModelForCausalLM.from_pretrained(model_id)
 
-
-        return model.to(device)
+        # Use safe device movement that respects quantization constraints
+        return self._safe_to_device(model, device)
 
     def can_handle(self, model: ModuleType | Any) -> bool:
         """