invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,396 @@
1
+ """Minimal CLI config implementation for invarlock.cli.
2
+
3
+ Provides a lightweight, dict-backed configuration object plus helpers used by
4
+ the CLI commands (load_config, apply_profile, apply_edit_override, resolve_edit_kind).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import copy
10
+ import os
11
+ from dataclasses import dataclass, field
12
+ from importlib import resources as _ires
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ import yaml
17
+
18
+
19
+ def _deep_merge(a: dict, b: dict) -> dict:
20
+ out = copy.deepcopy(a)
21
+ for k, v in b.items():
22
+ if isinstance(v, dict) and isinstance(out.get(k), dict):
23
+ out[k] = _deep_merge(out[k], v)
24
+ else:
25
+ out[k] = copy.deepcopy(v)
26
+ return out
27
+
28
+
29
+ class _Obj:
30
+ def __init__(self, data: Any):
31
+ self._data = data
32
+
33
+ def __getattr__(self, item):
34
+ # Only return values for existing keys; otherwise raise AttributeError
35
+ # so hasattr/getattr(..., default) behave correctly.
36
+ if item in self._data:
37
+ v = self._data[item]
38
+ if isinstance(v, dict):
39
+ return _Obj(v)
40
+ return v
41
+ raise AttributeError(item)
42
+
43
+ def __getitem__(self, key): # enable dict-like access in tests
44
+ if isinstance(self._data, dict):
45
+ return self._data[key]
46
+ raise TypeError("Object is not subscriptable")
47
+
48
+ # Provide dict-like helpers where tests use mapping semantics
49
+ def get(self, key: str, default: Any = None) -> Any:
50
+ if isinstance(self._data, dict):
51
+ return self._data.get(key, default)
52
+ return default
53
+
54
+ def items(self): # pragma: no cover - convenience for debug/tests
55
+ if isinstance(self._data, dict):
56
+ return self._data.items()
57
+ return []
58
+
59
+
60
+ @dataclass
61
+ class InvarLockConfig:
62
+ """Lightweight, dict-backed config with ergonomic attribute access.
63
+
64
+ Accepts either a single `data` mapping or keyword sections like `model=`,
65
+ `edit=`, `dataset=`, etc., and stores them internally as a dict.
66
+ """
67
+
68
+ data: dict[str, Any] = field(default_factory=dict)
69
+
70
+ def __init__(self, data: dict[str, Any] | None = None, **sections: Any) -> None:
71
+ if data is not None and sections:
72
+ merged = _deep_merge(data, sections)
73
+ self.data = merged
74
+ elif data is not None:
75
+ self.data = copy.deepcopy(data)
76
+ else:
77
+ self.data = copy.deepcopy(sections)
78
+
79
+ # Basic validation hooks for well-known edits (none required here)
80
+
81
+ def model_dump(self) -> dict[str, Any]:
82
+ return copy.deepcopy(self.data)
83
+
84
+ def __getattr__(self, item):
85
+ if item in self.data:
86
+ v = self.data[item]
87
+ if isinstance(v, dict):
88
+ return _Obj(v)
89
+ return v
90
+ raise AttributeError(item)
91
+
92
+
93
+ # Typed sub-configs used by tests (minimal validation only)
94
+ @dataclass
95
+ class OutputConfig:
96
+ dir: Path | str
97
+
98
+ def __post_init__(self) -> None:
99
+ if isinstance(self.dir, str):
100
+ self.dir = Path(self.dir)
101
+
102
+
103
+ @dataclass
104
+ class DatasetConfig:
105
+ seq_len: int = 512
106
+ stride: int = 512
107
+ provider: str | None = None
108
+ split: str = "validation"
109
+ preview_n: int | None = None
110
+ final_n: int | None = None
111
+ seed: int | None = None
112
+
113
+ def __post_init__(self) -> None:
114
+ if self.stride > self.seq_len:
115
+ raise ValueError("stride must be <= seq_len")
116
+
117
+
118
@dataclass
class EvalBootstrapConfig:
    """Bootstrap settings for evaluation.

    ``replicates`` must be positive and ``alpha`` (compared as a float) must
    lie strictly between 0 and 1. ``ci_band`` is stored without validation.
    """

    replicates: int = 1000
    alpha: float = 0.05
    ci_band: float = 0.10

    def __post_init__(self) -> None:
        if self.replicates <= 0:
            raise ValueError("replicates must be > 0")
        alpha_value = float(self.alpha)
        if not 0.0 < alpha_value < 1.0:
            raise ValueError("alpha must be in (0,1)")
129
+
130
+
131
+ @dataclass
132
+ class SpectralGuardConfig:
133
+ sigma_quantile: float | None = None
134
+ contraction: float | None = None
135
+ family_caps: dict[str, Any] = field(default_factory=dict)
136
+
137
+ def __post_init__(self) -> None:
138
+ # contraction is an alias for sigma_quantile
139
+ if self.contraction is not None and self.sigma_quantile is None:
140
+ self.sigma_quantile = float(self.contraction)
141
+ self.contraction = None
142
+ # normalize family_caps: scalar → {"kappa": value}
143
+ caps = {}
144
+ for k, v in (self.family_caps or {}).items():
145
+ if isinstance(v, dict):
146
+ caps[k] = {"kappa": float(v.get("kappa", 0.0))}
147
+ else:
148
+ caps[k] = {"kappa": float(v)}
149
+ self.family_caps = caps
150
+
151
+
152
+ @dataclass
153
+ class RMTGuardConfig:
154
+ epsilon: dict[str, float] | float | None = None
155
+
156
+
157
+ @dataclass
158
+ class VarianceGuardConfig:
159
+ clamp: list[float] | None = None
160
+ mode: str | None = None
161
+ deadband: float | None = None
162
+ min_gain: float | None = None
163
+ min_rel_gain: float | None = None
164
+ min_abs_adjust: float | None = None
165
+ max_scale_step: float | None = None
166
+ min_effect_lognll: float | None = None
167
+ predictive_one_sided: bool | None = None
168
+ topk_backstop: int | None = None
169
+ max_adjusted_modules: int | None = None
170
+ predictive_gate: bool | None = None
171
+ target_modules: list[str] | None = None
172
+ scope: str | None = None
173
+ calibration: dict[str, Any] = field(default_factory=dict)
174
+ absolute_floor_ppl: float | None = None
175
+
176
+ def __post_init__(self) -> None:
177
+ if self.clamp is not None:
178
+ if not (isinstance(self.clamp, list) and len(self.clamp) == 2):
179
+ raise ValueError("clamp must be [low, high]")
180
+ low, high = float(self.clamp[0]), float(self.clamp[1])
181
+ if low >= high:
182
+ raise ValueError("clamp lower bound must be < upper bound")
183
+ if self.absolute_floor_ppl is None:
184
+ # Provide conservative default when not specified
185
+ self.absolute_floor_ppl = 0.05
186
+
187
+
188
@dataclass
class EditConfig:
    """An edit by ``name`` plus its free-form ``plan`` mapping."""

    name: str
    # Each instance gets its own fresh plan dict (never shared).
    plan: dict[str, Any] = field(default_factory=dict)
192
+
193
+
194
@dataclass
class AutoConfig:
    """Validated ``probes`` / ``target_pm_ratio`` settings.

    ``probes`` (after ``int()``) must lie in ``[0, 10]``.
    ``target_pm_ratio`` (after ``float()``) must be at least 1.0.
    The original values are stored unconverted.
    """

    probes: int = 0
    target_pm_ratio: float = 1.0

    def __post_init__(self) -> None:
        probe_count = int(self.probes)
        if probe_count < 0 or probe_count > 10:
            raise ValueError("probes must be between 0 and 10")
        ratio = float(self.target_pm_ratio)
        if ratio < 1.0:
            raise ValueError("target_pm_ratio must be >= 1.0")
204
+
205
+
206
def _create_loader(base_dir: Path, _include_chain: tuple[Path, ...] = ()):
    """Build a ``yaml.SafeLoader`` subclass that understands ``!include``.

    ``!include <path>`` is resolved against ``base_dir`` and replaced by the
    parsed content of that file. Nested includes resolve against the directory
    of the file that contains them.

    Args:
        base_dir: Directory used to resolve relative include paths.
        _include_chain: Internal. Resolved paths of the included files
            currently being parsed above this loader; used to detect cycles.

    Raises (during ``yaml.load``):
        ValueError: if an ``!include`` would re-enter a file that is already
            being included. Without this check a cycle recurses until
            RecursionError.
    """

    class Loader(yaml.SafeLoader):
        pass

    # Stored on this per-call subclass so each loader resolves its own dir.
    Loader._base_dir = Path(base_dir)

    def _construct_include(loader: yaml.SafeLoader, node: yaml.Node):
        rel = loader.construct_scalar(node)
        # NOTE(review): include targets are not confined to base_dir; absolute
        # and ``..`` paths are honoured, so only load trusted config files.
        path = (loader._base_dir / rel).resolve()
        if path in _include_chain:
            raise ValueError(f"Circular !include detected: {path}")
        with path.open(encoding="utf-8") as fh:
            inc_loader = _create_loader(path.parent, (*_include_chain, path))
            return yaml.load(fh, Loader=inc_loader)

    Loader.add_constructor("!include", _construct_include)
    return Loader
221
+
222
+
223
def load_config(path: str | Path) -> InvarLockConfig:
    """Load a YAML config file (with ``!include`` support) into a config object.

    A top-level ``defaults`` mapping, when present, is removed and the rest of
    the document is deep-merged over it. A ``guards.variance`` mapping is
    replaced by a ``VarianceGuardConfig``. Only its recognised keys with
    non-None values are used, and ``mode`` defaults to ``"ci"``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the document is not a mapping, or ``defaults`` is present
            but not a mapping.
    """
    cfg_path = Path(path)
    if not cfg_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {cfg_path}")

    yaml_loader = _create_loader(cfg_path.parent)
    with cfg_path.open(encoding="utf-8") as fh:
        raw = yaml.load(fh, Loader=yaml_loader)
    if not isinstance(raw, dict):
        raise ValueError("Top-level config must be a mapping")

    defaults = raw.pop("defaults", None)
    if defaults is not None:
        if not isinstance(defaults, dict):
            raise ValueError("defaults must be a mapping when present")
        raw = _deep_merge(defaults, raw)

    # Coerce known guard configs for friendlier attribute access
    guards = raw.get("guards")
    variance = guards.get("variance") if isinstance(guards, dict) else None
    if isinstance(variance, dict):
        recognised = (
            "clamp",
            "mode",
            "deadband",
            "min_gain",
            "min_rel_gain",
            "min_abs_adjust",
            "max_scale_step",
            "min_effect_lognll",
            "predictive_one_sided",
            "topk_backstop",
            "max_adjusted_modules",
            "predictive_gate",
            "target_modules",
            "scope",
            "calibration",
            "absolute_floor_ppl",
        )
        picked: dict[str, Any] = {}
        for key in recognised:
            value = variance.get(key)
            if value is not None:
                picked[key] = value
        picked.setdefault("mode", "ci")
        guards["variance"] = VarianceGuardConfig(**picked)
    return InvarLockConfig(raw)
270
+
271
+
272
def _load_runtime_yaml(*rel_parts: str) -> dict[str, Any] | None:
    """Load a YAML mapping from the runtime config locations.

    Search order:
      1) ``$INVARLOCK_CONFIG_ROOT/runtime/<rel_parts...>``, if the variable is
         set and that path is a regular file.
      2) The ``invarlock._data.runtime`` package resources.

    Returns:
        The parsed mapping (an empty file yields ``{}``), or ``None`` if the
        file exists in neither location.

    Raises:
        ValueError: if a file is found but its top-level YAML value is not a
            mapping, in either location.
        yaml.YAMLError: if a file is found but is not valid YAML.
    """

    def _as_mapping(parsed: Any) -> dict[str, Any]:
        data = parsed or {}
        if not isinstance(data, dict):
            raise ValueError("Runtime YAML must be a mapping")
        return data

    # 1) Environment override
    root = os.getenv("INVARLOCK_CONFIG_ROOT")
    if root:
        candidate = Path(root, "runtime", *rel_parts)
        if candidate.is_file():
            with candidate.open(encoding="utf-8") as fh:
                return _as_mapping(yaml.safe_load(fh))

    # 2) Package data. Only *locating/reading* the resource is best-effort,
    # because importlib resources may be unavailable in some environments.
    # Parse and shape errors are deliberately not swallowed, so a malformed
    # packaged file surfaces instead of masquerading as "not found".
    try:
        resource = _ires.files("invarlock._data.runtime")
        for part in rel_parts:
            resource = resource.joinpath(part)
        if not resource.is_file():
            return None
        text = resource.read_text(encoding="utf-8")
    except (ImportError, OSError, AttributeError, TypeError):
        return None
    return _as_mapping(yaml.safe_load(text))
313
+
314
+
315
def load_tiers() -> dict[str, Any]:
    """Return the tier policy mapping read from ``tiers.yaml``.

    The file is looked up with ``_load_runtime_yaml``.

    Raises:
        FileNotFoundError: if ``tiers.yaml`` cannot be found.
    """
    tiers = _load_runtime_yaml("tiers.yaml")
    if tiers is None:
        raise FileNotFoundError(
            "tiers.yaml not found in package runtime (and no INVARLOCK_CONFIG_ROOT override)"
        )
    return tiers
323
+
324
+
325
def apply_profile(cfg: InvarLockConfig, profile: str) -> InvarLockConfig:
    """Return a new config with the named profile's overrides deep-merged in.

    Overrides come from ``profiles/<profile>.yaml`` in the runtime config
    locations. If that file is absent and ``profile`` is ``"ci"`` (any case),
    built-in CI defaults are used instead. Their preview/final counts come
    from ``INVARLOCK_CI_PREVIEW`` / ``INVARLOCK_CI_FINAL``; unset or
    non-integer values fall back to 200.

    ``cfg`` is not modified (its ``model_dump()`` is a deep copy).

    Raises:
        ValueError: if no profile file exists and ``profile`` is not ``"ci"``.
    """

    def _env_int(name: str, default: int) -> int:
        # int() on a str can only fail with ValueError; catching anything
        # broader would hide unrelated bugs.
        try:
            return int(os.getenv(name, str(default)))
        except ValueError:
            return default

    # First, try packaged/runtime profiles
    overrides: dict[str, Any] | None = _load_runtime_yaml("profiles", f"{profile}.yaml")

    if overrides is None:
        # Provide sensible CI defaults when 'ci' profile file is absent
        if profile.lower() != "ci":
            raise ValueError(f"Unknown profile: {profile}")
        overrides = {
            "dataset": {
                "preview_n": _env_int("INVARLOCK_CI_PREVIEW", 200),
                "final_n": _env_int("INVARLOCK_CI_FINAL", 200),
            },
            "eval": {"bootstrap": {"replicates": 1200, "alpha": 0.05}},
        }
    return InvarLockConfig(_deep_merge(cfg.model_dump(), overrides))
347
+
348
+
349
def resolve_edit_kind(kind: str) -> str:
    """Map a user-supplied edit kind to a canonical edit name.

    Input is lower-cased and stripped. Resolution order:
      1) built-in aliases (``prune``/``quant`` -> ``quant_rtn``,
         ``mixed`` -> ``orchestrator``);
      2) a name the edits registry knows (skipped if it cannot be imported);
      3) the always-known names ``quant_rtn`` and ``noop``.

    Raises:
        ValueError: if nothing matches; the message shows the normalised kind.
    """
    normalized = kind.lower().strip()

    aliases = {"prune": "quant_rtn", "quant": "quant_rtn", "mixed": "orchestrator"}
    if normalized in aliases:
        return aliases[normalized]

    # Registered edit plugins (e.g. "noop", "quant_rtn"); the registry is
    # optional here, so an ImportError simply skips this step.
    try:
        from invarlock.edits.registry import get_registry

        is_registered = get_registry().get_plugin(normalized) is not None
    except ImportError:
        is_registered = False
    if is_registered:
        return normalized

    if normalized in ("quant_rtn", "noop"):
        return normalized
    raise ValueError(f"Unknown edit kind: {normalized}")
374
+
375
+
376
def apply_edit_override(cfg: InvarLockConfig, kind: str) -> InvarLockConfig:
    """Return a copy of ``cfg`` whose ``edit`` section targets ``kind``.

    ``edit.name`` is set to ``resolve_edit_kind(kind)`` and ``edit.kind``
    records ``kind`` exactly as given. Other keys already present in a dict
    ``edit`` section are kept. ``cfg`` is not modified (``model_dump()``
    returns a deep copy).

    Raises:
        ValueError: propagated from ``resolve_edit_kind`` for unknown kinds.
    """
    cfgd = cfg.model_dump()
    resolved = resolve_edit_kind(kind)
    edit_section = cfgd.get("edit")
    if not isinstance(edit_section, dict):
        # A bare ``edit:`` in YAML loads as None. setdefault would hand back
        # that None and item assignment would crash, so start a fresh section.
        edit_section = {}
        cfgd["edit"] = edit_section
    edit_section["name"] = resolved
    edit_section["kind"] = kind
    return InvarLockConfig(cfgd)
383
+
384
+
385
+ # Backward-compat helper name expected by tests
386
def _deep_merge_dicts(a: dict, b: dict) -> dict:  # pragma: no cover - trivial alias
    """Alias of ``_deep_merge``: a new dict with ``b`` merged over ``a``."""
    merged = _deep_merge(a, b)
    return merged
388
+
389
+
390
def create_example_config() -> InvarLockConfig:  # pragma: no cover - test helper
    """Build a small example config: GPT-2 model, ``quant_rtn`` edit, wikitext2 data."""
    sections = {
        "model": {"id": "gpt2", "adapter": "hf_gpt2", "device": "auto"},
        "edit": {"name": "quant_rtn", "plan": {}},
        "dataset": {"provider": "wikitext2", "seq_len": 512, "stride": 512},
        "output": {"dir": "runs"},
    }
    return InvarLockConfig(**sections)
@@ -0,0 +1,68 @@
1
+ """CLI constants shared across commands to keep outputs consistent."""
2
+
3
+ from __future__ import annotations
4
+
5
# Human-readable, versioned format identifiers for JSON outputs.
# Keep in sync with tests under tests/cli/*_json_*.py
DOCTOR_FORMAT_VERSION = "doctor-v1"
PLUGINS_FORMAT_VERSION = "plugins-v1"
VERIFY_FORMAT_VERSION = "verify-v1"

# Free-text note per dataset provider.
PROVIDER_NOTES: dict[str, str] = dict(
    # WikiText-2 is loaded via datasets; works offline if cached.
    wikitext2="requires network or local cache",
    # Synthetic corpus used for quick smokes and CI; fully offline.
    synthetic="offline; deterministic",
    # Hugging Face text datasets (via datasets.load_dataset)
    hf_text="requires network",
    # Local providers (offline)
    local_jsonl="local files; offline",
    local_jsonl_pairs="paired prompts/responses (JSONL); offline",
    # Seq2Seq providers
    seq2seq="toy seq2seq dataset; offline",
    hf_seq2seq="requires network",
)

# Optional structured metadata for richer CLI tables: parameter hints
# ("-" means the provider takes none).
PROVIDER_PARAMS: dict[str, str] = dict(
    wikitext2="-",
    synthetic="-",
    hf_text="dataset_name[, split, text_field]",
    hf_seq2seq="dataset_name[, split, input_field, target_field]",
    local_jsonl="path[, text_field]",
    local_jsonl_pairs="path[, input_field, target_field]",
    seq2seq="-",
)

# Stable network classification ('no' | 'cache' | 'yes') so UI code does
# not depend on the note strings above.
PROVIDER_NETWORK: dict[str, str] = dict(
    wikitext2="cache",
    synthetic="no",
    hf_text="yes",
    local_jsonl="no",
    local_jsonl_pairs="no",
    seq2seq="no",
    hf_seq2seq="yes",
)

# Simple kind classification for presentation.
PROVIDER_KIND: dict[str, str] = dict(
    wikitext2="text",
    synthetic="text",
    hf_text="text",
    local_jsonl="text",
    local_jsonl_pairs="pairs",
    seq2seq="seq2seq",
    hf_seq2seq="seq2seq",
)

__all__ = [
    "DOCTOR_FORMAT_VERSION",
    "PLUGINS_FORMAT_VERSION",
    "VERIFY_FORMAT_VERSION",
    "PROVIDER_NOTES",
    "PROVIDER_PARAMS",
    "PROVIDER_NETWORK",
    "PROVIDER_KIND",
]
@@ -0,0 +1,92 @@
1
+ """Minimal device helpers for the CLI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+
8
def resolve_device(requested: str | None) -> str:
    """Turn a requested device string into a concrete device name.

    ``None``/empty and ``"auto"`` (case-insensitive) choose the first
    available of CUDA (returned as ``"cuda:0"``), MPS, then CPU. Any other
    value is lower-cased and returned if ``is_device_available`` accepts it.

    Raises:
        RuntimeError: if an explicitly requested device is not available.
    """
    choice = (requested or "auto").lower()
    if choice == "auto":
        # Preference order: CUDA -> MPS -> CPU.
        for probe, resolved in (("cuda", "cuda:0"), ("mps", "mps")):
            if is_device_available(probe):
                return resolved
        return "cpu"
    if is_device_available(choice):
        return choice
    raise RuntimeError(f"Device '{choice}' is not available")
21
+
22
+
23
def is_device_available(device: str) -> bool:
    """Report whether ``device`` is usable in this process (case-insensitive).

    - ``"cpu"`` (and ``None``/empty, which mean CPU) is always available.
    - ``"cuda"`` requires ``torch.cuda.is_available()``.
    - ``"cuda:N"`` additionally requires ``N < torch.cuda.device_count()``.
    - ``"mps"`` requires ``torch.backends.mps.is_available()``.

    Anything else is reported unavailable, including malformed CUDA strings
    such as ``"cuda:x"`` or ``"cudafoo"``. Never raises.
    """
    name = (device or "cpu").lower()
    if name == "cpu":
        return True

    cuda_index: int | None = None
    if name.startswith("cuda:"):
        # Validate the index instead of collapsing every "cuda*" string to
        # "cuda": "cuda:7" on a single-GPU host, or "cuda:x", must not count
        # as available.
        index_text = name[len("cuda:"):]
        if not index_text.isdecimal():
            return False
        cuda_index = int(index_text)
        name = "cuda"

    if name not in ("cuda", "mps"):
        return False

    try:
        import torch

        if name == "cuda":
            if not (hasattr(torch, "cuda") and torch.cuda.is_available()):
                return False
            return cuda_index is None or cuda_index < torch.cuda.device_count()
        return bool(hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
    except Exception:
        # Best-effort probe: a missing or broken torch install simply means
        # "not available".
        return False
44
+
45
+
46
+ def validate_device_for_config(
47
+ device: str, config_requirements: dict[str, Any] | None = None
48
+ ) -> tuple[bool, str]:
49
+ # Simple validation stub; extend with model/profile specific checks as needed
50
+ valid = {"cpu", "cuda", "cuda:0", "mps"}
51
+ if device not in valid:
52
+ return False, f"Unsupported device '{device}'"
53
+ if config_requirements and config_requirements.get("required_device"):
54
+ req = str(config_requirements.get("required_device")).lower()
55
+ if device != req:
56
+ return (
57
+ False,
58
+ f"Configuration requires device '{req}' but '{device}' was selected",
59
+ )
60
+ return True, ""
61
+
62
+
63
def get_device_info() -> dict[str, dict]:
    """Return a structured snapshot of device availability.

    Keys ``"cpu"``, ``"cuda"`` and ``"mps"`` map to dicts with ``available``
    (bool) and ``info`` (str). When CUDA is available, its entry also gets
    ``device_count``, ``device_name`` (of device 0) and ``memory_total``
    (e.g. ``"8.0 GB"``). ``"auto_selected"`` holds ``resolve_device("auto")``.

    Torch probing is best-effort. Any error stops probing, and entries not yet
    updated keep their "Not available" defaults.
    """
    info: dict[str, dict] = {
        "cpu": {"available": True, "info": "Available"},
        "cuda": {"available": False, "info": "Not available"},
        "mps": {"available": False, "info": "Not available"},
    }
    auto_choice = resolve_device("auto")
    try:
        import torch  # noqa: F401

        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():  # type: ignore[attr-defined]
            info["mps"].update(available=True, info="Available")
        if hasattr(torch, "cuda") and torch.cuda.is_available():  # type: ignore[attr-defined]
            props = torch.cuda.get_device_properties(0)
            device_name = getattr(props, "name", "CUDA")
            total_bytes = getattr(props, "total_memory", 0)
            cuda_entry = info["cuda"]
            cuda_entry["available"] = True
            cuda_entry["info"] = "Available"
            cuda_entry["device_count"] = torch.cuda.device_count()
            cuda_entry["device_name"] = device_name
            cuda_entry["memory_total"] = f"{total_bytes / 1e9:.1f} GB"
    except Exception:
        pass
    info["auto_selected"] = auto_choice
    return info
@@ -0,0 +1,74 @@
1
+ from __future__ import annotations
2
+
3
+ import importlib
4
+ import platform as _platform
5
+ from typing import Any
6
+
7
+
8
def _module_available(name: str) -> bool:
    """Return True if the import system can locate module ``name``.

    ``importlib.util.find_spec`` imports the parent package of a dotted name.
    When that parent is missing (e.g. ``"optimum.onnxruntime"`` without
    ``optimum``), it raises ``ModuleNotFoundError`` instead of returning
    ``None``. It raises ``ValueError`` for a loaded module whose ``__spec__``
    is None. Both cases count as "not available".
    """
    import importlib.util  # ``import importlib`` alone does not load ``util``

    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


def get_adapter_rows() -> list[dict[str, Any]]:
    """Build adapter rows similar to doctor output for testing.

    Returns one row per adapter from the core registry, with keys ``name``,
    ``origin`` (``core``/``plugin``), ``mode`` (``auto-matcher``/``adapter``),
    ``backend``, ``version`` (always None here), ``status`` and ``enable``
    (a hint). Platform and extra checks:

    - ``hf_gptq`` / ``hf_awq``: "unsupported" when not on Linux.
    - ``hf_bnb``: "unsupported" when CUDA is unavailable.
    - ``hf_onnx``: "needs_extra" unless ``optimum.onnxruntime`` or
      ``onnxruntime`` is importable, even if registered as a core adapter,
      so missing extras are surfaced.
    """
    from invarlock.core.registry import get_registry

    try:
        import torch as _t

        has_cuda = bool(getattr(_t, "cuda", None) and _t.cuda.is_available())
    except Exception:
        # Best-effort probe: a missing or broken torch means "no CUDA".
        has_cuda = False

    registry = get_registry()
    is_linux = _platform.system().lower() == "linux"

    rows: list[dict[str, Any]] = []
    for name in registry.list_adapters():
        info = registry.get_plugin_info(name, "adapters")
        module = str(info.get("module") or "")
        # Adapters shipped under invarlock.adapters are core; the two *_auto
        # names among them are auto-matchers.
        support = (
            "auto"
            if module.startswith("invarlock.adapters")
            and name in {"hf_causal_auto", "hf_mlm_auto"}
            else ("core" if module.startswith("invarlock.adapters") else "optional")
        )
        backend, status, enable = None, "ready", ""

        if name in {"hf_gpt2", "hf_bert", "hf_llama", "hf_causal_auto", "hf_mlm_auto"}:
            backend = "transformers"
        elif name == "hf_gptq":
            backend = "auto-gptq"
            if not is_linux:
                status, enable = "unsupported", "Linux-only"
        elif name == "hf_awq":
            backend = "autoawq"
            if not is_linux:
                status, enable = "unsupported", "Linux-only"
        elif name == "hf_bnb":
            backend = "bitsandbytes"
            if not has_cuda:
                status, enable = "unsupported", "Requires CUDA"
        elif name == "hf_onnx":
            backend = "onnxruntime"
            present = _module_available("optimum.onnxruntime") or _module_available(
                "onnxruntime"
            )
            if not present:
                status = "needs_extra"
                enable = "pip install 'invarlock[onnx]'"

        rows.append(
            {
                "name": name,
                "origin": "core" if support in {"core", "auto"} else "plugin",
                "mode": "auto-matcher" if support == "auto" else "adapter",
                "backend": backend,
                "version": None,
                "status": status,
                "enable": enable,
            }
        )

    return rows
@@ -0,0 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+ # Back-compat shim: re-export the core InvarlockError for CLI imports
4
+ from invarlock.core.exceptions import InvarlockError
5
+
6
+ __all__ = ["InvarlockError"]
@@ -0,0 +1,60 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Any
5
+
6
+
7
+ def _extract_pm_snapshot_for_overhead(
8
+ src: object, *, kind: str
9
+ ) -> dict[str, Any] | None:
10
+ """Extract or compute a primary-metric snapshot from diverse report shapes.
11
+
12
+ Accepts either:
13
+ - CoreRunner RunReport-like objects (dataclasses) with `.metrics`/`.evaluation_windows`
14
+ - Dict reports with `evaluation_windows` or `metrics.primary_metric`
15
+
16
+ Returns a dict suitable for `metrics.primary_metric` or None if unavailable.
17
+ """
18
+ # 1) Prefer existing primary_metric on object metrics
19
+ try:
20
+ metrics = getattr(src, "metrics", None)
21
+ if isinstance(metrics, dict):
22
+ pm = metrics.get("primary_metric")
23
+ if isinstance(pm, dict):
24
+ fin = pm.get("final")
25
+ if isinstance(fin, int | float) and math.isfinite(float(fin)):
26
+ return pm # already a valid snapshot
27
+ except Exception:
28
+ pass
29
+
30
+ # 2) If dict-shaped report provided, try computing from it directly
31
+ try:
32
+ if isinstance(src, dict):
33
+ from invarlock.eval.primary_metric import compute_primary_metric_from_report
34
+
35
+ pm2 = compute_primary_metric_from_report(src, kind=kind)
36
+ fin2 = pm2.get("final") if isinstance(pm2, dict) else None
37
+ if isinstance(fin2, int | float) and math.isfinite(float(fin2)):
38
+ return pm2
39
+ except Exception:
40
+ pass
41
+
42
+ # 3) Compute from evaluation_windows attribute on CoreRunner reports
43
+ try:
44
+ ew = getattr(src, "evaluation_windows", None)
45
+ if isinstance(ew, dict) and ew:
46
+ from invarlock.eval.primary_metric import compute_primary_metric_from_report
47
+
48
+ pm3 = compute_primary_metric_from_report(
49
+ {"evaluation_windows": ew}, kind=kind
50
+ )
51
+ fin3 = pm3.get("final") if isinstance(pm3, dict) else None
52
+ if isinstance(fin3, int | float) and math.isfinite(float(fin3)):
53
+ return pm3
54
+ except Exception:
55
+ pass
56
+
57
+ return None
58
+
59
+
60
+ __all__ = ["_extract_pm_snapshot_for_overhead"]