invarlock 0.2.0-py3-none-any.whl → 0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +1 -1
- invarlock/_data/runtime/profiles/ci_cpu.yaml +5 -0
- invarlock/adapters/__init__.py +13 -0
- invarlock/adapters/auto.py +149 -22
- invarlock/adapters/capabilities.py +421 -0
- invarlock/adapters/hf_llama.py +2 -2
- invarlock/adapters/hf_mixin.py +122 -1
- invarlock/cli/commands/doctor.py +7 -1
- invarlock/cli/commands/run.py +148 -2
- invarlock/core/registry.py +34 -6
- invarlock/guards/variance.py +41 -6
- invarlock/plugins/hf_awq_adapter.py +22 -1
- invarlock/plugins/hf_bnb_adapter.py +117 -22
- invarlock/plugins/hf_gptq_adapter.py +24 -1
- invarlock/reporting/certificate.py +155 -15
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/METADATA +2 -2
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/RECORD +21 -20
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/WHEEL +0 -0
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/entry_points.txt +0 -0
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.2.0.dist-info → invarlock-0.3.1.dist-info}/top_level.txt +0 -0
invarlock/adapters/hf_mixin.py
CHANGED
@@ -4,10 +4,12 @@ Shared HuggingFace adapter mixin.
 
 Provides reusable functionality for InvarLock's HuggingFace adapters:
 - Device resolution helpers
+- Safe device movement for quantized models
 - Snapshot/restore with device awareness
 - Chunked snapshot helpers to reduce peak memory usage
 - Lightweight config serialization
 - Weight-tying detection plumbing
+- Quantization detection and capabilities
 """
 
 from __future__ import annotations
@@ -17,12 +19,15 @@ import json
 import os
 import tempfile
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import torch
 
 from invarlock.security import is_secure_path
 
+if TYPE_CHECKING:
+    from .capabilities import ModelCapabilities, QuantizationConfig
+
 SCALAR_TYPES = (int, float, str, bool)
 
 
@@ -91,6 +96,122 @@ class HFAdapterMixin:
 
         return torch.device(device_str)
 
+    def _safe_to_device(
+        self,
+        model: torch.nn.Module,
+        device: str | torch.device | None = "auto",
+        capabilities: ModelCapabilities | None = None,
+    ) -> torch.nn.Module:
+        """
+        Safely move model to device, respecting quantization constraints.
+
+        For quantized models (BNB, AWQ, GPTQ), device movement may be
+        impossible or already handled by the loading mechanism. This
+        method checks the model's capabilities before attempting .to().
+
+        Args:
+            model: The model to move.
+            device: Target device ("auto", "cuda", "mps", "cpu").
+            capabilities: Pre-computed capabilities, or None to auto-detect.
+
+        Returns:
+            The model (possibly on the new device, or unchanged if not movable).
+        """
+        target_device = self._resolve_device(device)
+
+        # Auto-detect capabilities if not provided
+        if capabilities is None:
+            capabilities = self._detect_capabilities(model)
+
+        # Check if model can be moved
+        if capabilities is not None and not capabilities.device_movable:
+            # Model handles its own device placement (e.g., BNB, AWQ, GPTQ)
+            # Log this decision for debugging but don't attempt .to()
+            return model
+
+        # Safe to move
+        return model.to(target_device)
+
+    def _detect_capabilities(self, model: torch.nn.Module) -> ModelCapabilities | None:
+        """
+        Detect model capabilities from a loaded model instance.
+
+        Args:
+            model: Loaded model instance.
+
+        Returns:
+            ModelCapabilities if detection succeeds, None otherwise.
+        """
+        try:
+            from .capabilities import detect_capabilities_from_model
+
+            return detect_capabilities_from_model(model)
+        except ImportError:
+            return None
+
+    def _is_quantized_model(self, model: torch.nn.Module) -> bool:
+        """
+        Check if a model is quantized (BNB, AWQ, GPTQ).
+
+        This is a quick heuristic check that doesn't require full
+        capability detection.
+
+        Args:
+            model: Model to check.
+
+        Returns:
+            True if the model appears to be quantized.
+        """
+        config = getattr(model, "config", None)
+        if config is None:
+            return False
+
+        # Check for quantization_config attribute
+        quant_cfg = getattr(config, "quantization_config", None)
+        if quant_cfg is not None:
+            return True
+
+        # Check for BNB-specific attributes on the model
+        if hasattr(model, "is_loaded_in_8bit") and model.is_loaded_in_8bit:
+            return True
+        if hasattr(model, "is_loaded_in_4bit") and model.is_loaded_in_4bit:
+            return True
+
+        # Check for quantized module types in the model
+        for module in model.modules():
+            module_name = module.__class__.__name__.lower()
+            if any(
+                q in module_name
+                for q in ["linear8bit", "linear4bit", "quantlinear", "awqlinear"]
+            ):
+                return True
+
+        return False
+
+    def _detect_quantization_config(
+        self, model: torch.nn.Module
+    ) -> QuantizationConfig | None:
+        """
+        Detect quantization configuration from a model.
+
+        Args:
+            model: Model to inspect.
+
+        Returns:
+            QuantizationConfig if quantization detected, None otherwise.
+        """
+        try:
+            from .capabilities import detect_quantization_from_config
+
+            config = getattr(model, "config", None)
+            if config is not None:
+                quant_cfg = detect_quantization_from_config(config)
+                if quant_cfg.is_quantized():
+                    return quant_cfg
+        except ImportError:
+            pass
+        return None
+
     # ------------------------------------------------------------------
     # HF save/export helpers
     # ------------------------------------------------------------------
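A minimal sketch of how an adapter built on this mixin might use the new helper. The adapter class, its load_model signature, and the AutoModelForCausalLM call are illustrative assumptions, not part of this diff; only _safe_to_device and its capability check come from the code above.

from transformers import AutoModelForCausalLM

from invarlock.adapters.hf_mixin import HFAdapterMixin
from invarlock.core.api import ModelAdapter


class ExampleAdapter(HFAdapterMixin, ModelAdapter):  # hypothetical adapter for illustration
    def load_model(self, model_id: str, device: str = "auto", **kwargs):
        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        # For quantized checkpoints, detected capabilities report
        # device_movable=False and _safe_to_device() returns the model
        # untouched instead of calling .to(target_device).
        return self._safe_to_device(model, device)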
invarlock/cli/commands/doctor.py
CHANGED
@@ -326,8 +326,14 @@ def doctor_command(
     try:
         import torch
 
+        torch_version = getattr(torch, "__version__", None)
         if not json_out:
-
+            if torch_version:
+                console.print(f"[green]✅ PyTorch {torch_version}[/green]")
+            else:
+                console.print(
+                    "[yellow]⚠️ PyTorch present but version unavailable[/yellow]"
+                )
 
         # Device information
         from ..device import get_device_info
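For reference, the guarded lookup added above behaves as follows (a tiny illustration; the version string is made up):

import torch

torch_version = getattr(torch, "__version__", None)
# Typically a string such as "2.4.0"; None only for unusual builds that omit
# __version__, which now takes the yellow "version unavailable" branch in doctor.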
invarlock/cli/commands/run.py
CHANGED
@@ -81,6 +81,137 @@ GUARD_OVERHEAD_THRESHOLD = 0.01
 SPLIT_ALIASES: tuple[str, ...] = ("validation", "val", "dev", "eval", "test")
 
 
+def _coerce_mapping(obj: object) -> dict[str, Any]:
+    """Best-effort conversion of config-like objects to plain dicts."""
+
+    if isinstance(obj, dict):
+        return obj
+    try:
+        raw = getattr(obj, "_data", None)
+        if isinstance(raw, dict):
+            return raw
+    except Exception:
+        pass
+    try:
+        dumped = obj.model_dump()  # type: ignore[attr-defined]
+        if isinstance(dumped, dict):
+            return dumped
+    except Exception:
+        pass
+    try:
+        data = vars(obj)
+        if isinstance(data, dict):
+            return data
+    except Exception:
+        pass
+    return {}
+
+
+def _resolve_pm_acceptance_range(
+    cfg: InvarLockConfig | dict[str, Any] | None,
+) -> dict[str, float]:
+    """Resolve primary-metric acceptance bounds from config/env with safe defaults."""
+
+    base_min = 0.95
+    base_max = 1.10
+
+    cfg_min = None
+    cfg_max = None
+    try:
+        cfg_map = _coerce_mapping(cfg) if cfg is not None else {}
+        pm_section = cfg_map.get("primary_metric") if isinstance(cfg_map, dict) else {}
+        pm_map = _coerce_mapping(pm_section)
+        acceptance = (
+            pm_map.get("acceptance_range") if isinstance(pm_map, dict) else None
+        )
+        if isinstance(acceptance, dict):
+            if acceptance.get("min") is not None:
+                try:
+                    cfg_min = float(acceptance["min"])
+                except (TypeError, ValueError):
+                    cfg_min = None
+            if acceptance.get("max") is not None:
+                try:
+                    cfg_max = float(acceptance["max"])
+                except (TypeError, ValueError):
+                    cfg_max = None
+    except Exception:
+        cfg_min = None
+        cfg_max = None
+
+    def _parse_env(name: str) -> float | None:
+        try:
+            raw = os.environ.get(name, "")
+            if raw is None or str(raw).strip() == "":
+                return None
+            return float(raw)
+        except Exception:
+            return None
+
+    env_min = _parse_env("INVARLOCK_PM_ACCEPTANCE_MIN")
+    env_max = _parse_env("INVARLOCK_PM_ACCEPTANCE_MAX")
+
+    has_explicit = any(v is not None for v in (cfg_min, cfg_max, env_min, env_max))
+    if not has_explicit:
+        return {}
+
+    min_val = (
+        env_min if env_min is not None else cfg_min if cfg_min is not None else base_min
+    )
+    max_val = (
+        env_max if env_max is not None else cfg_max if cfg_max is not None else base_max
+    )
+
+    try:
+        if min_val is not None and min_val <= 0:
+            min_val = base_min
+    except Exception:
+        min_val = base_min
+    try:
+        if max_val is not None and max_val <= 0:
+            max_val = base_max
+    except Exception:
+        max_val = base_max
+
+    try:
+        if max_val is not None and min_val is not None and max_val < min_val:
+            max_val = min_val
+    except Exception:
+        max_val = base_max
+
+    return {"min": float(min_val), "max": float(max_val)}
+
+
+def _free_model_memory(model: object | None) -> None:
+    """Best-effort cleanup to release GPU memory for a model object."""
+    if model is None:
+        return
+    try:
+        import gc
+
+        del model
+        gc.collect()
+        if torch is not None and torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+    except Exception:
+        # Cleanup should never raise; fallback is to proceed without cache purge
+        pass
+
+
+def _should_measure_overhead(profile_normalized: str) -> tuple[bool, bool]:
+    """Return (measure_guard_overhead, skip_overhead) derived from env/profile."""
+
+    skip_overhead_env = (
+        os.environ.get("INVARLOCK_SKIP_OVERHEAD_CHECK", "").strip().lower()
+    )
+    skip_overhead = skip_overhead_env in {"1", "true", "yes"}
+    measure_guard_overhead = (
+        profile_normalized in {"ci", "release"} and not skip_overhead
+    )
+    return measure_guard_overhead, skip_overhead
+
+
 def _choose_dataset_split(
     *, requested: str | None, available: list[str] | None
 ) -> tuple[str, bool]:
@@ -1671,6 +1802,7 @@ def run_command(
         "edit": edit_meta,
         "guards": guard_metadata,
     }
+    pm_acceptance_range = _resolve_pm_acceptance_range(cfg)
 
     console.print(f"🔌 Adapter: {adapter.name}")
 
@@ -1746,6 +1878,10 @@ def run_command(
         "plugins": plugin_provenance,
         "run_id": run_id,
    }
+    run_context.setdefault("primary_metric", {})["acceptance_range"] = (
+        pm_acceptance_range
+    )
+    run_context["pm_acceptance_range"] = pm_acceptance_range
     run_context["model_profile"] = {
         "family": model_profile.family,
         "default_loss": model_profile.default_loss,
@@ -2756,18 +2892,26 @@ def run_command(
 
                 restore_fn = _restore2
             else:
-                # reload path
+                # reload path - properly free GPU memory before setting to None
+                _free_model_memory(model)
                 model = None
                 restore_fn = None
     except Exception:
         # On any failure, fall back to reload-per-attempt path
+        _free_model_memory(model)
         model = None
         restore_fn = None
 
     # RETRY LOOP - All report processing inside loop
     attempt = 1
     profile_normalized = (profile or "").lower()
-    measure_guard_overhead =
+    measure_guard_overhead, skip_overhead = _should_measure_overhead(
+        profile_normalized
+    )
+    if skip_overhead and profile_normalized in {"ci", "release"}:
+        console.print(
+            "[yellow]⚠️ Overhead check skipped via INVARLOCK_SKIP_OVERHEAD_CHECK[/yellow]"
+        )
 
     while True:
         # Reset RNG streams each attempt to guarantee determinism across retries
@@ -2933,6 +3077,8 @@ def run_command(
         if env_flags:
             meta_payload["env_flags"] = env_flags
         report["meta"].update(meta_payload)
+        if pm_acceptance_range:
+            report["meta"]["pm_acceptance_range"] = pm_acceptance_range
         report["meta"]["model_profile"] = {
             "family": model_profile.family,
             "default_loss": model_profile.default_loss,
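The resolution order implemented above is environment variable over config value over the built-in defaults (0.95 / 1.10), with an empty dict returned when nothing is set explicitly. A hedged illustration with made-up values; _resolve_pm_acceptance_range is a private helper of invarlock/cli/commands/run.py, imported here only for demonstration:

import os

from invarlock.cli.commands.run import _resolve_pm_acceptance_range

os.environ["INVARLOCK_PM_ACCEPTANCE_MIN"] = "0.97"   # env overrides the config min
cfg = {"primary_metric": {"acceptance_range": {"min": 0.90, "max": 1.05}}}
assert _resolve_pm_acceptance_range(cfg) == {"min": 0.97, "max": 1.05}

os.environ.pop("INVARLOCK_PM_ACCEPTANCE_MIN")
assert _resolve_pm_acceptance_range(None) == {}      # nothing explicit -> empty dict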
invarlock/core/registry.py
CHANGED
@@ -117,14 +117,24 @@ class CoreRegistry:
             module: str,
             class_name: str,
             status: str = "Available (fallback)",
+            required_deps: list[str] | None = None,
         ) -> None:
             if name not in registry:
+                # Check runtime dependencies for optional plugins
+                actual_available = True
+                actual_status = status
+                if required_deps:
+                    missing = self._check_runtime_dependencies(required_deps)
+                    if missing:
+                        actual_available = False
+                        actual_status = f"Needs extra: {', '.join(missing)}"
+
                 registry[name] = PluginInfo(
                     name=name,
                     module=module,
                     class_name=class_name,
-                    available=
-                    status=
+                    available=actual_available,
+                    status=actual_status,
                     package="invarlock",
                     version=INVARLOCK_VERSION,
                 )
@@ -147,27 +157,30 @@ class CoreRegistry:
         _fallback(
             self._adapters, "hf_mlm_auto", "invarlock.adapters", "HF_MLM_Auto_Adapter"
         )
-        # Optional plugin adapters (
+        # Optional plugin adapters (verify runtime dependencies)
         _fallback(
             self._adapters,
             "hf_gptq",
             "invarlock.plugins.hf_gptq_adapter",
             "HF_GPTQ_Adapter",
-            status="Available (
+            status="Available (plugin)",
+            required_deps=["auto_gptq"],
         )
         _fallback(
             self._adapters,
             "hf_awq",
             "invarlock.plugins.hf_awq_adapter",
             "HF_AWQ_Adapter",
-            status="Available (
+            status="Available (plugin)",
+            required_deps=["autoawq"],
         )
         _fallback(
             self._adapters,
             "hf_bnb",
             "invarlock.plugins.hf_bnb_adapter",
             "HF_BNB_Adapter",
-            status="Available (
+            status="Available (plugin)",
+            required_deps=["bitsandbytes"],
         )
 
         # Register built-in edits (quant-only core) and internal no-op
@@ -181,6 +194,21 @@ class CoreRegistry:
         _fallback(self._guards, "rmt", "invarlock.guards", "RMTGuard")
         _fallback(self._guards, "hello_guard", "invarlock.plugins", "HelloGuard")
 
+    def _check_runtime_dependencies(self, deps: list[str]) -> list[str]:
+        """
+        Check if runtime dependencies are actually importable.
+
+        Returns:
+            List of missing dependency names.
+        """
+        missing = []
+        for dep in deps:
+            try:
+                importlib.import_module(dep)
+            except ImportError:
+                missing.append(dep)
+        return missing
+
     def _create_plugin_info(
         self, entry_point: EntryPoint, plugin_type: str
     ) -> PluginInfo:
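The practical effect is that optional adapters now get an honest status: when a plugin's runtime dependency cannot be imported, the fallback entry is registered as unavailable with a "Needs extra: <dependency>" status instead of "Available (plugin)". A hedged sketch using the new private helper; whether the fallbacks are registered at construction time is an assumption here:

from invarlock.core.registry import CoreRegistry

registry = CoreRegistry()  # assumed to register fallback entries on construction
missing = registry._check_runtime_dependencies(["bitsandbytes", "json"])
# -> ["bitsandbytes"] when bitsandbytes is not installed; "json" always imports.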
invarlock/guards/variance.py
CHANGED
@@ -39,6 +39,30 @@ from .policies import VariancePolicyDict
 __all__ = ["equalise_residual_variance", "equalise_branch_variance", "VarianceGuard"]
 
 
+def _safe_mean(
+    samples: list[float] | np.ndarray, default: float | None = None
+) -> float | None:
+    """
+    Compute mean of samples, returning default if empty.
+
+    Avoids numpy RuntimeWarning "Mean of empty slice" when samples is empty
+    or contains no valid values.
+
+    Args:
+        samples: List or array of float values.
+        default: Value to return if samples is empty.
+
+    Returns:
+        Mean value or default if samples is empty.
+    """
+    if samples is None:
+        return default
+    arr = np.asarray(samples)
+    if arr.size == 0:
+        return default
+    return float(np.nanmean(arr))
+
+
 try:  # Optional dependency: tqdm (progress bars)
     from tqdm.auto import tqdm as _tqdm
 except Exception:  # pragma: no cover - exercised only when tqdm is absent
@@ -1472,7 +1496,14 @@ class VarianceGuard(Guard):
 
         if coverage >= min_coverage and not self._scales:
             ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
-            ppl_no_ve_mean =
+            ppl_no_ve_mean = _safe_mean(ppl_no_ve_samples)
+            if ppl_no_ve_mean is None:
+                # No valid samples - cannot compute mean
+                self._ratio_ci = None
+                predictive_state["reason"] = "no_valid_samples"
+                self._predictive_gate_state = predictive_state
+                self._stats["predictive_gate"] = predictive_state.copy()
+                return
             self.set_ab_results(
                 ppl_no_ve=ppl_no_ve_mean,
                 ppl_with_ve=ppl_no_ve_mean,
@@ -1527,8 +1558,12 @@ class VarianceGuard(Guard):
                 n_bootstrap=500,
                 seed=calib_seed,
             )
-            ppl_no_ve_mean =
-            ppl_with_ve_mean =
+            ppl_no_ve_mean = _safe_mean(ppl_no_ve_samples)
+            ppl_with_ve_mean = _safe_mean(ppl_with_ve_samples)
+            if ppl_no_ve_mean is None or ppl_with_ve_mean is None:
+                # Fallback if means couldn't be computed
+                ppl_no_ve_mean = ppl_no_ve_mean or 0.0
+                ppl_with_ve_mean = ppl_with_ve_mean or 0.0
             self.set_ab_results(
                 ppl_no_ve=ppl_no_ve_mean,
                 ppl_with_ve=ppl_with_ve_mean,
@@ -2118,7 +2153,7 @@ class VarianceGuard(Guard):
 
         if coverage >= min_coverage and not self._scales:
             ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
-            ppl_no_ve_mean =
+            ppl_no_ve_mean = _safe_mean(ppl_no_ve_samples, default=0.0)
             self.set_ab_results(
                 ppl_no_ve=ppl_no_ve_mean,
                 ppl_with_ve=ppl_no_ve_mean,
@@ -2158,8 +2193,8 @@ class VarianceGuard(Guard):
                 n_bootstrap=500,
                 seed=calib_seed,
             )
-            ppl_no_ve_mean =
-            ppl_with_ve_mean =
+            ppl_no_ve_mean = _safe_mean(ppl_no_ve_samples, default=0.0)
+            ppl_with_ve_mean = _safe_mean(ppl_with_ve_samples, default=0.0)
             self.set_ab_results(
                 ppl_no_ve=ppl_no_ve_mean,
                 ppl_with_ve=ppl_with_ve_mean,
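A small illustration of the new helper's behaviour; _safe_mean is module-level in invarlock/guards/variance.py, so the import below should work even though it is private:

import numpy as np

from invarlock.guards.variance import _safe_mean

assert _safe_mean([]) is None                  # empty input -> default, no numpy warning
assert _safe_mean([], default=0.0) == 0.0
assert _safe_mean([2.0, np.nan, 4.0]) == 3.0   # nanmean skips NaN entries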
invarlock/plugins/hf_awq_adapter.py
CHANGED
@@ -4,12 +4,16 @@ HuggingFace AWQ Adapter (plugin)
 
 Optional adapter for loading AWQ-quantized causal LMs from the Hub.
 Requires the `autoawq` extra on supported platforms (typically Linux/CUDA).
+
+AWQ models are pre-quantized and typically handle device placement internally
+during loading. This adapter does NOT call .to() on the loaded model.
 """
 
 from __future__ import annotations
 
 from typing import Any
 
+from invarlock.adapters.capabilities import ModelCapabilities
 from invarlock.adapters.hf_mixin import HFAdapterMixin
 from invarlock.core.api import ModelAdapter
 from invarlock.core.error_utils import wrap_errors
@@ -56,7 +60,24 @@ class HF_AWQ_Adapter(HFAdapterMixin, ModelAdapter):
             trust_remote_code=True,
             **{k: v for k, v in kwargs.items() if k != "device"},
         )
-
+
+        # AWQ models are pre-quantized; use safe device movement
+        # which respects the model's device constraints
+        return self._safe_to_device(
+            model, device, capabilities=ModelCapabilities.for_awq()
+        )
+
+    def get_capabilities(self, model: Any) -> ModelCapabilities:
+        """Return capabilities for an AWQ-quantized model."""
+        config = getattr(model, "config", None)
+        group_size = 128  # Default AWQ group size
+        if config is not None:
+            quant_cfg = getattr(config, "quantization_config", None)
+            if isinstance(quant_cfg, dict):
+                group_size = quant_cfg.get("group_size", 128)
+            elif quant_cfg is not None:
+                group_size = getattr(quant_cfg, "group_size", 128)
+        return ModelCapabilities.for_awq(group_size=group_size)
 
     def can_handle(self, model: Any) -> bool:
         cfg = getattr(model, "config", None)
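The group_size lookup in get_capabilities() tolerates both dict-shaped and object-shaped quantization configs. A standalone restatement for sanity-checking; the function name below is illustrative and not part of the package:

def awq_group_size(config) -> int:
    """Mirror of the group_size lookup in HF_AWQ_Adapter.get_capabilities()."""
    quant_cfg = getattr(config, "quantization_config", None)
    if isinstance(quant_cfg, dict):
        return quant_cfg.get("group_size", 128)
    if quant_cfg is not None:
        return getattr(quant_cfg, "group_size", 128)
    return 128  # AWQ default when no quantization_config is present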