invarlock-0.2.0-py3-none-any.whl → invarlock-0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,10 +4,12 @@ Shared HuggingFace adapter mixin.
 
 Provides reusable functionality for InvarLock's HuggingFace adapters:
 - Device resolution helpers
+- Safe device movement for quantized models
 - Snapshot/restore with device awareness
 - Chunked snapshot helpers to reduce peak memory usage
 - Lightweight config serialization
 - Weight-tying detection plumbing
+- Quantization detection and capabilities
 """
 
 from __future__ import annotations
@@ -17,12 +19,15 @@ import json
 import os
 import tempfile
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import torch
 
 from invarlock.security import is_secure_path
 
+if TYPE_CHECKING:
+    from .capabilities import ModelCapabilities, QuantizationConfig
+
 SCALAR_TYPES = (int, float, str, bool)
 
 
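The new `TYPE_CHECKING` guard defers the `.capabilities` import so the annotations used below resolve for type checkers without importing the module at runtime (useful when the import is optional or would be cyclic). A minimal, self-contained sketch of the pattern; the import path is illustrative:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so an optional
    # or cyclic module cannot break imports. Path is illustrative.
    from invarlock.adapters.capabilities import ModelCapabilities


def describe(caps: ModelCapabilities | None) -> str:
    # With `from __future__ import annotations` the hint above stays a
    # string at runtime, so no real import of ModelCapabilities occurs.
    return "movable" if caps is not None and caps.device_movable else "pinned"


print(describe(None))  # -> "pinned"
```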
@@ -91,6 +96,122 @@ class HFAdapterMixin:
 
         return torch.device(device_str)
 
+    def _safe_to_device(
+        self,
+        model: torch.nn.Module,
+        device: str | torch.device | None = "auto",
+        capabilities: ModelCapabilities | None = None,
+    ) -> torch.nn.Module:
+        """
+        Safely move model to device, respecting quantization constraints.
+
+        For quantized models (BNB, AWQ, GPTQ), device movement may be
+        impossible or already handled by the loading mechanism. This
+        method checks the model's capabilities before attempting .to().
+
+        Args:
+            model: The model to move.
+            device: Target device ("auto", "cuda", "mps", "cpu").
+            capabilities: Pre-computed capabilities, or None to auto-detect.
+
+        Returns:
+            The model (possibly on the new device, or unchanged if not movable).
+        """
+        target_device = self._resolve_device(device)
+
+        # Auto-detect capabilities if not provided
+        if capabilities is None:
+            capabilities = self._detect_capabilities(model)
+
+        # Check if model can be moved
+        if capabilities is not None and not capabilities.device_movable:
+            # Model handles its own device placement (e.g., BNB, AWQ, GPTQ)
+            # Log this decision for debugging but don't attempt .to()
+            return model
+
+        # Safe to move
+        return model.to(target_device)
+
+    def _detect_capabilities(self, model: torch.nn.Module) -> ModelCapabilities | None:
+        """
+        Detect model capabilities from a loaded model instance.
+
+        Args:
+            model: Loaded model instance.
+
+        Returns:
+            ModelCapabilities if detection succeeds, None otherwise.
+        """
+        try:
+            from .capabilities import detect_capabilities_from_model
+
+            return detect_capabilities_from_model(model)
+        except ImportError:
+            return None
+
+    def _is_quantized_model(self, model: torch.nn.Module) -> bool:
+        """
+        Check if a model is quantized (BNB, AWQ, GPTQ).
+
+        This is a quick heuristic check that doesn't require full
+        capability detection.
+
+        Args:
+            model: Model to check.
+
+        Returns:
+            True if the model appears to be quantized.
+        """
+        config = getattr(model, "config", None)
+        if config is None:
+            return False
+
+        # Check for quantization_config attribute
+        quant_cfg = getattr(config, "quantization_config", None)
+        if quant_cfg is not None:
+            return True
+
+        # Check for BNB-specific attributes on the model
+        if hasattr(model, "is_loaded_in_8bit") and model.is_loaded_in_8bit:
+            return True
+        if hasattr(model, "is_loaded_in_4bit") and model.is_loaded_in_4bit:
+            return True
+
+        # Check for quantized module types in the model
+        for module in model.modules():
+            module_name = module.__class__.__name__.lower()
+            if any(
+                q in module_name
+                for q in ["linear8bit", "linear4bit", "quantlinear", "awqlinear"]
+            ):
+                return True
+
+        return False
+
+    def _detect_quantization_config(
+        self, model: torch.nn.Module
+    ) -> QuantizationConfig | None:
+        """
+        Detect quantization configuration from a model.
+
+        Args:
+            model: Model to inspect.
+
+        Returns:
+            QuantizationConfig if quantization detected, None otherwise.
+        """
+        try:
+            from .capabilities import detect_quantization_from_config
+
+            config = getattr(model, "config", None)
+            if config is not None:
+                quant_cfg = detect_quantization_from_config(config)
+                if quant_cfg.is_quantized():
+                    return quant_cfg
+        except ImportError:
+            pass
+        return None
+
     # ------------------------------------------------------------------
     # HF save/export helpers
     # ------------------------------------------------------------------
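Together, these helpers let adapter code request device movement without special-casing each quantization backend. The fragment below is a hypothetical call site, not code from the package; `adapter` and `model` are illustrative names:

```python
# Hypothetical call site for an adapter built on HFAdapterMixin.
caps = adapter._detect_capabilities(model)  # None if .capabilities is absent

if adapter._is_quantized_model(model):
    # Quantized weights (BNB/AWQ/GPTQ): _safe_to_device returns the model
    # unchanged when capabilities report device_movable=False.
    model = adapter._safe_to_device(model, "cuda", capabilities=caps)
else:
    # Plain fp16/fp32 model: behaves like model.to(resolved_device).
    model = adapter._safe_to_device(model, "auto")
```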
@@ -326,8 +326,14 @@ def doctor_command(
     try:
         import torch
 
+        torch_version = getattr(torch, "__version__", None)
         if not json_out:
-            console.print(f"[green]✅ PyTorch {torch.__version__}[/green]")
+            if torch_version:
+                console.print(f"[green]✅ PyTorch {torch_version}[/green]")
+            else:
+                console.print(
+                    "[yellow]⚠️ PyTorch present but version unavailable[/yellow]"
+                )
 
         # Device information
         from ..device import get_device_info
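The `getattr` guard covers stripped or stubbed torch installs where `__version__` is missing. A tiny self-contained illustration of the same defensive lookup:

```python
import types

torch_like = types.SimpleNamespace()  # stand-in for a stripped install
version = getattr(torch_like, "__version__", None)
print(version)  # -> None, instead of raising AttributeError
```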
@@ -81,6 +81,137 @@ GUARD_OVERHEAD_THRESHOLD = 0.01
 SPLIT_ALIASES: tuple[str, ...] = ("validation", "val", "dev", "eval", "test")
 
 
+def _coerce_mapping(obj: object) -> dict[str, Any]:
+    """Best-effort conversion of config-like objects to plain dicts."""
+
+    if isinstance(obj, dict):
+        return obj
+    try:
+        raw = getattr(obj, "_data", None)
+        if isinstance(raw, dict):
+            return raw
+    except Exception:
+        pass
+    try:
+        dumped = obj.model_dump()  # type: ignore[attr-defined]
+        if isinstance(dumped, dict):
+            return dumped
+    except Exception:
+        pass
+    try:
+        data = vars(obj)
+        if isinstance(data, dict):
+            return data
+    except Exception:
+        pass
+    return {}
+
+
+def _resolve_pm_acceptance_range(
+    cfg: InvarLockConfig | dict[str, Any] | None,
+) -> dict[str, float]:
+    """Resolve primary-metric acceptance bounds from config/env with safe defaults."""
+
+    base_min = 0.95
+    base_max = 1.10
+
+    cfg_min = None
+    cfg_max = None
+    try:
+        cfg_map = _coerce_mapping(cfg) if cfg is not None else {}
+        pm_section = cfg_map.get("primary_metric") if isinstance(cfg_map, dict) else {}
+        pm_map = _coerce_mapping(pm_section)
+        acceptance = (
+            pm_map.get("acceptance_range") if isinstance(pm_map, dict) else None
+        )
+        if isinstance(acceptance, dict):
+            if acceptance.get("min") is not None:
+                try:
+                    cfg_min = float(acceptance["min"])
+                except (TypeError, ValueError):
+                    cfg_min = None
+            if acceptance.get("max") is not None:
+                try:
+                    cfg_max = float(acceptance["max"])
+                except (TypeError, ValueError):
+                    cfg_max = None
+    except Exception:
+        cfg_min = None
+        cfg_max = None
+
+    def _parse_env(name: str) -> float | None:
+        try:
+            raw = os.environ.get(name, "")
+            if raw is None or str(raw).strip() == "":
+                return None
+            return float(raw)
+        except Exception:
+            return None
+
+    env_min = _parse_env("INVARLOCK_PM_ACCEPTANCE_MIN")
+    env_max = _parse_env("INVARLOCK_PM_ACCEPTANCE_MAX")
+
+    has_explicit = any(v is not None for v in (cfg_min, cfg_max, env_min, env_max))
+    if not has_explicit:
+        return {}
+
+    min_val = (
+        env_min if env_min is not None else cfg_min if cfg_min is not None else base_min
+    )
+    max_val = (
+        env_max if env_max is not None else cfg_max if cfg_max is not None else base_max
+    )
+
+    try:
+        if min_val is not None and min_val <= 0:
+            min_val = base_min
+    except Exception:
+        min_val = base_min
+    try:
+        if max_val is not None and max_val <= 0:
+            max_val = base_max
+    except Exception:
+        max_val = base_max
+
+    try:
+        if max_val is not None and min_val is not None and max_val < min_val:
+            max_val = min_val
+    except Exception:
+        max_val = base_max
+
+    return {"min": float(min_val), "max": float(max_val)}
+
+
+def _free_model_memory(model: object | None) -> None:
+    """Best-effort cleanup to release GPU memory for a model object."""
+    if model is None:
+        return
+    try:
+        import gc
+
+        del model
+        gc.collect()
+        if torch is not None and torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+    except Exception:
+        # Cleanup should never raise; fallback is to proceed without cache purge
+        pass
+
+
+def _should_measure_overhead(profile_normalized: str) -> tuple[bool, bool]:
+    """Return (measure_guard_overhead, skip_overhead) derived from env/profile."""
+
+    skip_overhead_env = (
+        os.environ.get("INVARLOCK_SKIP_OVERHEAD_CHECK", "").strip().lower()
+    )
+    skip_overhead = skip_overhead_env in {"1", "true", "yes"}
+    measure_guard_overhead = (
+        profile_normalized in {"ci", "release"} and not skip_overhead
+    )
+    return measure_guard_overhead, skip_overhead
+
+
 def _choose_dataset_split(
     *, requested: str | None, available: list[str] | None
 ) -> tuple[str, bool]:
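The resolution order is environment variable, then config value, then the 0.95/1.10 defaults, with an empty dict when neither source sets a bound. A condensed, self-contained sketch of that precedence; the full helper's positivity checks are reduced to a single min/max clamp here:

```python
import os


def resolve_acceptance(cfg_min: float | None, cfg_max: float | None) -> dict:
    # Condensed precedence from _resolve_pm_acceptance_range:
    # env var > config value > default, and all-implicit means "no range".
    def env(name: str) -> float | None:
        raw = os.environ.get(name, "").strip()
        return float(raw) if raw else None

    env_min = env("INVARLOCK_PM_ACCEPTANCE_MIN")
    env_max = env("INVARLOCK_PM_ACCEPTANCE_MAX")
    if all(v is None for v in (cfg_min, cfg_max, env_min, env_max)):
        return {}
    min_val = env_min if env_min is not None else cfg_min if cfg_min is not None else 0.95
    max_val = env_max if env_max is not None else cfg_max if cfg_max is not None else 1.10
    return {"min": min_val, "max": max(min_val, max_val)}


os.environ.pop("INVARLOCK_PM_ACCEPTANCE_MIN", None)
os.environ.pop("INVARLOCK_PM_ACCEPTANCE_MAX", None)
assert resolve_acceptance(None, None) == {}  # nothing explicit -> no range
assert resolve_acceptance(0.9, None) == {"min": 0.9, "max": 1.10}  # max defaults
os.environ["INVARLOCK_PM_ACCEPTANCE_MAX"] = "1.25"
assert resolve_acceptance(0.9, None) == {"min": 0.9, "max": 1.25}  # env wins
```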
@@ -1671,6 +1802,7 @@ def run_command(
         "edit": edit_meta,
         "guards": guard_metadata,
     }
+    pm_acceptance_range = _resolve_pm_acceptance_range(cfg)
 
     console.print(f"🔌 Adapter: {adapter.name}")
 
@@ -1746,6 +1878,10 @@ def run_command(
         "plugins": plugin_provenance,
         "run_id": run_id,
     }
+    run_context.setdefault("primary_metric", {})["acceptance_range"] = (
+        pm_acceptance_range
+    )
+    run_context["pm_acceptance_range"] = pm_acceptance_range
     run_context["model_profile"] = {
         "family": model_profile.family,
         "default_loss": model_profile.default_loss,
@@ -2756,18 +2892,26 @@ def run_command(
 
                 restore_fn = _restore2
             else:
-                # reload path
+                # reload path - properly free GPU memory before setting to None
+                _free_model_memory(model)
                 model = None
                 restore_fn = None
         except Exception:
             # On any failure, fall back to reload-per-attempt path
+            _free_model_memory(model)
             model = None
             restore_fn = None
 
     # RETRY LOOP - All report processing inside loop
     attempt = 1
     profile_normalized = (profile or "").lower()
-    measure_guard_overhead = profile_normalized in {"ci", "release"}
+    measure_guard_overhead, skip_overhead = _should_measure_overhead(
+        profile_normalized
+    )
+    if skip_overhead and profile_normalized in {"ci", "release"}:
+        console.print(
+            "[yellow]⚠️ Overhead check skipped via INVARLOCK_SKIP_OVERHEAD_CHECK[/yellow]"
+        )
 
     while True:
         # Reset RNG streams each attempt to guarantee determinism across retries
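The escape hatch only matters for profiles that measure overhead at all; dev-style profiles never do. A self-contained sketch of the resulting truth table, mirroring `_should_measure_overhead` above:

```python
import os


def should_measure_overhead(profile_normalized: str) -> tuple[bool, bool]:
    # Mirror of _should_measure_overhead from the hunk above.
    raw = os.environ.get("INVARLOCK_SKIP_OVERHEAD_CHECK", "").strip().lower()
    skip_overhead = raw in {"1", "true", "yes"}
    measure = profile_normalized in {"ci", "release"} and not skip_overhead
    return measure, skip_overhead


os.environ.pop("INVARLOCK_SKIP_OVERHEAD_CHECK", None)
assert should_measure_overhead("dev") == (False, False)  # dev never measures
assert should_measure_overhead("ci") == (True, False)  # ci measures by default
os.environ["INVARLOCK_SKIP_OVERHEAD_CHECK"] = "1"
assert should_measure_overhead("release") == (False, True)  # flag skips + warns
```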
@@ -2933,6 +3077,8 @@ def run_command(
         if env_flags:
             meta_payload["env_flags"] = env_flags
         report["meta"].update(meta_payload)
+        if pm_acceptance_range:
+            report["meta"]["pm_acceptance_range"] = pm_acceptance_range
         report["meta"]["model_profile"] = {
             "family": model_profile.family,
             "default_loss": model_profile.default_loss,
@@ -117,14 +117,24 @@ class CoreRegistry:
             module: str,
             class_name: str,
             status: str = "Available (fallback)",
+            required_deps: list[str] | None = None,
        ) -> None:
            if name not in registry:
+                # Check runtime dependencies for optional plugins
+                actual_available = True
+                actual_status = status
+                if required_deps:
+                    missing = self._check_runtime_dependencies(required_deps)
+                    if missing:
+                        actual_available = False
+                        actual_status = f"Needs extra: {', '.join(missing)}"
+
                registry[name] = PluginInfo(
                    name=name,
                    module=module,
                    class_name=class_name,
-                    available=True,
-                    status=status,
+                    available=actual_available,
+                    status=actual_status,
                    package="invarlock",
                    version=INVARLOCK_VERSION,
                )
@@ -147,27 +157,30 @@ class CoreRegistry:
         _fallback(
             self._adapters, "hf_mlm_auto", "invarlock.adapters", "HF_MLM_Auto_Adapter"
         )
-        # Optional plugin adapters (available when modules present)
+        # Optional plugin adapters (verify runtime dependencies)
         _fallback(
             self._adapters,
             "hf_gptq",
             "invarlock.plugins.hf_gptq_adapter",
             "HF_GPTQ_Adapter",
-            status="Available (fallback plugin)",
+            status="Available (plugin)",
+            required_deps=["auto_gptq"],
         )
         _fallback(
             self._adapters,
             "hf_awq",
             "invarlock.plugins.hf_awq_adapter",
             "HF_AWQ_Adapter",
-            status="Available (fallback plugin)",
+            status="Available (plugin)",
+            required_deps=["autoawq"],
         )
         _fallback(
             self._adapters,
             "hf_bnb",
             "invarlock.plugins.hf_bnb_adapter",
             "HF_BNB_Adapter",
-            status="Available (fallback plugin)",
+            status="Available (plugin)",
+            required_deps=["bitsandbytes"],
         )
 
         # Register built-in edits (quant-only core) and internal no-op
@@ -181,6 +194,21 @@ class CoreRegistry:
         _fallback(self._guards, "rmt", "invarlock.guards", "RMTGuard")
         _fallback(self._guards, "hello_guard", "invarlock.plugins", "HelloGuard")
 
+    def _check_runtime_dependencies(self, deps: list[str]) -> list[str]:
+        """
+        Check if runtime dependencies are actually importable.
+
+        Returns:
+            List of missing dependency names.
+        """
+        missing = []
+        for dep in deps:
+            try:
+                importlib.import_module(dep)
+            except ImportError:
+                missing.append(dep)
+        return missing
+
     def _create_plugin_info(
         self, entry_point: EntryPoint, plugin_type: str
     ) -> PluginInfo:
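The practical effect: a plugin whose module ships with invarlock but whose backend wheel is missing now reports `Needs extra: …` instead of claiming availability. A self-contained mirror of the dependency probe:

```python
import importlib


def check_runtime_dependencies(deps: list[str]) -> list[str]:
    # Same logic as CoreRegistry._check_runtime_dependencies: probe each
    # backend import and report the ones that fail.
    missing = []
    for dep in deps:
        try:
            importlib.import_module(dep)
        except ImportError:
            missing.append(dep)
    return missing


# On a machine without the quantization extras this prints e.g.
# ['bitsandbytes'], which the registry renders as "Needs extra: bitsandbytes".
print(check_runtime_dependencies(["json", "bitsandbytes"]))
```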
@@ -39,6 +39,30 @@ from .policies import VariancePolicyDict
 __all__ = ["equalise_residual_variance", "equalise_branch_variance", "VarianceGuard"]
 
 
+def _safe_mean(
+    samples: list[float] | np.ndarray, default: float | None = None
+) -> float | None:
+    """
+    Compute mean of samples, returning default if empty.
+
+    Avoids numpy RuntimeWarning "Mean of empty slice" when samples is empty
+    or contains no valid values.
+
+    Args:
+        samples: List or array of float values.
+        default: Value to return if samples is empty.
+
+    Returns:
+        Mean value or default if samples is empty.
+    """
+    if samples is None:
+        return default
+    arr = np.asarray(samples)
+    if arr.size == 0:
+        return default
+    return float(np.nanmean(arr))
+
+
 try:  # Optional dependency: tqdm (progress bars)
     from tqdm.auto import tqdm as _tqdm
 except Exception:  # pragma: no cover - exercised only when tqdm is absent
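The edge cases the call sites below rely on, as a self-contained mirror of `_safe_mean`:

```python
import numpy as np


def safe_mean(samples, default=None):
    # Mirror of _safe_mean from the hunk above.
    if samples is None:
        return default
    arr = np.asarray(samples)
    return default if arr.size == 0 else float(np.nanmean(arr))


assert safe_mean([]) is None  # empty -> default, no "Mean of empty slice" warning
assert safe_mean([], default=0.0) == 0.0  # some call sites pass default=0.0
assert safe_mean([2.0, float("nan"), 4.0]) == 3.0  # NaNs ignored via nanmean
```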
@@ -1472,7 +1496,14 @@ class VarianceGuard(Guard):
 
         if coverage >= min_coverage and not self._scales:
             ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
-            ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
+            ppl_no_ve_mean = _safe_mean(ppl_no_ve_samples)
+            if ppl_no_ve_mean is None:
+                # No valid samples - cannot compute mean
+                self._ratio_ci = None
+                predictive_state["reason"] = "no_valid_samples"
+                self._predictive_gate_state = predictive_state
+                self._stats["predictive_gate"] = predictive_state.copy()
+                return
             self.set_ab_results(
                 ppl_no_ve=ppl_no_ve_mean,
                 ppl_with_ve=ppl_no_ve_mean,
@@ -1527,8 +1558,12 @@ class VarianceGuard(Guard):
                 n_bootstrap=500,
                 seed=calib_seed,
             )
-            ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
-            ppl_with_ve_mean = float(np.mean(ppl_with_ve_samples))
+            ppl_no_ve_mean = _safe_mean(ppl_no_ve_samples)
+            ppl_with_ve_mean = _safe_mean(ppl_with_ve_samples)
+            if ppl_no_ve_mean is None or ppl_with_ve_mean is None:
+                # Fallback if means couldn't be computed
+                ppl_no_ve_mean = ppl_no_ve_mean or 0.0
+                ppl_with_ve_mean = ppl_with_ve_mean or 0.0
             self.set_ab_results(
                 ppl_no_ve=ppl_no_ve_mean,
                 ppl_with_ve=ppl_with_ve_mean,
@@ -2118,7 +2153,7 @@ class VarianceGuard(Guard):
 
         if coverage >= min_coverage and not self._scales:
             ppl_no_ve_samples = ppl_no_ve_samples[:coverage]
-            ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
+            ppl_no_ve_mean = _safe_mean(ppl_no_ve_samples, default=0.0)
             self.set_ab_results(
                 ppl_no_ve=ppl_no_ve_mean,
                 ppl_with_ve=ppl_no_ve_mean,
@@ -2158,8 +2193,8 @@ class VarianceGuard(Guard):
                 n_bootstrap=500,
                 seed=calib_seed,
             )
-            ppl_no_ve_mean = float(np.mean(ppl_no_ve_samples))
-            ppl_with_ve_mean = float(np.mean(ppl_with_ve_samples))
+            ppl_no_ve_mean = _safe_mean(ppl_no_ve_samples, default=0.0)
+            ppl_with_ve_mean = _safe_mean(ppl_with_ve_samples, default=0.0)
             self.set_ab_results(
                 ppl_no_ve=ppl_no_ve_mean,
                 ppl_with_ve=ppl_with_ve_mean,
@@ -4,12 +4,16 @@ HuggingFace AWQ Adapter (plugin)
 
 Optional adapter for loading AWQ-quantized causal LMs from the Hub.
 Requires the `autoawq` extra on supported platforms (typically Linux/CUDA).
+
+AWQ models are pre-quantized and typically handle device placement internally
+during loading. This adapter does NOT call .to() on the loaded model.
 """
 
 from __future__ import annotations
 
 from typing import Any
 
+from invarlock.adapters.capabilities import ModelCapabilities
 from invarlock.adapters.hf_mixin import HFAdapterMixin
 from invarlock.core.api import ModelAdapter
 from invarlock.core.error_utils import wrap_errors
@@ -56,7 +60,24 @@ class HF_AWQ_Adapter(HFAdapterMixin, ModelAdapter):
             trust_remote_code=True,
             **{k: v for k, v in kwargs.items() if k != "device"},
         )
-        return model.to(self._resolve_device(device))
+
+        # AWQ models are pre-quantized; use safe device movement
+        # which respects the model's device constraints
+        return self._safe_to_device(
+            model, device, capabilities=ModelCapabilities.for_awq()
+        )
+
+    def get_capabilities(self, model: Any) -> ModelCapabilities:
+        """Return capabilities for an AWQ-quantized model."""
+        config = getattr(model, "config", None)
+        group_size = 128  # Default AWQ group size
+        if config is not None:
+            quant_cfg = getattr(config, "quantization_config", None)
+            if isinstance(quant_cfg, dict):
+                group_size = quant_cfg.get("group_size", 128)
+            elif quant_cfg is not None:
+                group_size = getattr(quant_cfg, "group_size", 128)
+        return ModelCapabilities.for_awq(group_size=group_size)
 
     def can_handle(self, model: Any) -> bool:
         cfg = getattr(model, "config", None)
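The `group_size` lookup tolerates both dict-style and object-style `quantization_config`. A self-contained illustration with stand-in configs; real HF configs can expose either form:

```python
from types import SimpleNamespace

# Stand-ins for a loaded AWQ model's config.
dict_cfg = SimpleNamespace(quantization_config={"group_size": 64})
obj_cfg = SimpleNamespace(quantization_config=SimpleNamespace(group_size=32))

for cfg in (dict_cfg, obj_cfg):
    qc = getattr(cfg, "quantization_config", None)
    if isinstance(qc, dict):
        group_size = qc.get("group_size", 128)  # dict form
    else:
        group_size = getattr(qc, "group_size", 128)  # object form
    print(group_size)  # -> 64, then 32
```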