invarlock-0.3.3-py3-none-any.whl → invarlock-0.3.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invarlock/__init__.py CHANGED
@@ -12,7 +12,7 @@ For torch-dependent functionality, see subpackages under `invarlock.*`:
 - `invarlock.eval`: Metrics, guard-overhead checks, and certification
 """

-__version__ = "0.3.3"
+__version__ = "0.3.5"

 # Core exports - torch-independent
 from .config import CFG, Defaults, get_default_config
@@ -209,18 +209,18 @@ class _DelegatingAdapter(ModelAdapter):
 class HF_Causal_Auto_Adapter(_DelegatingAdapter):
     name = "hf_causal_auto"

-    def load_model(self, model_id: str, device: str = "auto") -> Any:
+    def load_model(self, model_id: str, device: str = "auto", **kwargs: Any) -> Any:
         delegate = self._ensure_delegate_from_id(model_id)
-        return delegate.load_model(model_id, device=device)
+        return delegate.load_model(model_id, device=device, **kwargs)


 class HF_MLM_Auto_Adapter(_DelegatingAdapter):
     name = "hf_mlm_auto"

-    def load_model(self, model_id: str, device: str = "auto") -> Any:
+    def load_model(self, model_id: str, device: str = "auto", **kwargs: Any) -> Any:
         # Force BERT-like adapter for MLM families
         HF_BERT_Adapter = _importlib.import_module(
             ".hf_bert", __package__
         ).HF_BERT_Adapter
         self._delegate = HF_BERT_Adapter()
-        return self._delegate.load_model(model_id, device=device)
+        return self._delegate.load_model(model_id, device=device, **kwargs)
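In 0.3.5 the auto adapters accept arbitrary loader keyword arguments and forward them to the delegate, which passes them on to `transformers`' `from_pretrained`. A minimal usage sketch, with the import path inferred from the relative `.hf_bert` import above (not confirmed by this diff); `torch_dtype` and `low_cpu_mem_usage` are standard `transformers` loader options, not invarlock-specific:

    import torch
    from invarlock.adapters.hf_auto import HF_Causal_Auto_Adapter  # path assumed

    adapter = HF_Causal_Auto_Adapter()
    model = adapter.load_model(
        "gpt2",                     # any causal-LM model id
        device="auto",
        torch_dtype=torch.float16,  # forwarded verbatim via **kwargs
        low_cpu_mem_usage=True,     # forwarded verbatim via **kwargs
    )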
@@ -41,7 +41,9 @@ class HF_BERT_Adapter(HFAdapterMixin, ModelAdapter):

     name = "hf_bert"

-    def load_model(self, model_id: str, device: str = "auto") -> ModuleType | Any:
+    def load_model(
+        self, model_id: str, device: str = "auto", **kwargs: Any
+    ) -> ModuleType | Any:
         """
         Load a HuggingFace BERT model.

@@ -68,7 +70,7 @@ class HF_BERT_Adapter(HFAdapterMixin, ModelAdapter):
                 "MODEL-LOAD-FAILED: transformers AutoModelForMaskedLM",
                 lambda e: {"model_id": model_id},
             ):
-                model = AutoModelForMaskedLM.from_pretrained(model_id)
+                model = AutoModelForMaskedLM.from_pretrained(model_id, **kwargs)
         except Exception:
             with wrap_errors(
                 ModelLoadError,
@@ -76,10 +78,9 @@ class HF_BERT_Adapter(HFAdapterMixin, ModelAdapter):
                 "MODEL-LOAD-FAILED: transformers AutoModel",
                 lambda e: {"model_id": model_id},
             ):
-                model = AutoModel.from_pretrained(model_id)
+                model = AutoModel.from_pretrained(model_id, **kwargs)

-        target_device = self._resolve_device(device)
-        return model.to(target_device)
+        return self._safe_to_device(model, device)

     def can_handle(self, model: ModuleType | Any) -> bool:
         """
@@ -48,7 +48,9 @@ class HF_GPT2_Adapter(HFAdapterMixin, ModelAdapter):

     name = "hf_gpt2"

-    def load_model(self, model_id: str, device: str = "auto") -> ModuleType | Any:
+    def load_model(
+        self, model_id: str, device: str = "auto", **kwargs: Any
+    ) -> ModuleType | Any:
         """
         Load a HuggingFace GPT-2 model.

@@ -75,10 +77,9 @@ class HF_GPT2_Adapter(HFAdapterMixin, ModelAdapter):
                 "MODEL-LOAD-FAILED: transformers AutoModelForCausalLM",
                 lambda e: {"model_id": model_id},
             ):
-                model = AutoModelForCausalLM.from_pretrained(model_id)
+                model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)

-            target_device = self._resolve_device(device)
-            return model.to(target_device)
+            return self._safe_to_device(model, device)
         except DependencyError:
             if LIGHT_IMPORT:
                 # Minimal stand-in that satisfies downstream interface requirements
@@ -41,7 +41,9 @@ class HF_LLaMA_Adapter(HFAdapterMixin, ModelAdapter):

     name = "hf_llama"

-    def load_model(self, model_id: str, device: str = "auto") -> ModuleType | Any:
+    def load_model(
+        self, model_id: str, device: str = "auto", **kwargs: Any
+    ) -> ModuleType | Any:
         """
         Load a HuggingFace LLaMA model.

@@ -67,7 +69,7 @@ class HF_LLaMA_Adapter(HFAdapterMixin, ModelAdapter):
                 "MODEL-LOAD-FAILED: transformers AutoModelForCausalLM",
                 lambda e: {"model_id": model_id},
             ):
-                model = AutoModelForCausalLM.from_pretrained(model_id)
+                model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)

         # Use safe device movement that respects quantization constraints
         return self._safe_to_device(model, device)
@@ -119,6 +119,10 @@ class HFAdapterMixin:
         """
         target_device = self._resolve_device(device)

+        # If transformers already sharded/placed the model, skip explicit .to().
+        if getattr(model, "hf_device_map", None):
+            return model
+
         # Auto-detect capabilities if not provided
         if capabilities is None:
             capabilities = self._detect_capabilities(model)
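This guard matters for sharded loads: when a model is loaded with `device_map="auto"`, accelerate records the per-module placement in `model.hf_device_map`, and a blanket `.to(device)` afterwards can raise or quietly undo that placement. A short sketch of the case the new check handles (model id illustrative; requires `accelerate` to be installed):

    from transformers import AutoModelForCausalLM

    # accelerate shards the model and sets model.hf_device_map
    model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

    # _safe_to_device now returns such a model unchanged instead of
    # forcing model.to(target_device)
    assert getattr(model, "hf_device_map", None)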
@@ -334,7 +338,9 @@ class HFAdapterMixin:
             if hasattr(model, "config")
             else {},
             "params": {},
+            "params_meta": {},
             "buffers": {},
+            "buffers_meta": {},
             "device_map": {},
             "weight_tying": self._extract_weight_tying_info(model),
         }
@@ -344,6 +350,10 @@ class HFAdapterMixin:
             file_path = snapshot_dir / filename
             torch.save(param.detach().cpu(), file_path)
             manifest["params"][name] = filename
+            manifest["params_meta"][name] = {
+                "shape": [int(x) for x in param.shape],
+                "dtype": str(param.dtype),
+            }
             manifest["device_map"][name] = str(param.device)

         for name, buffer in model.named_buffers():
@@ -351,6 +361,10 @@ class HFAdapterMixin:
             file_path = snapshot_dir / filename
             torch.save(buffer.detach().cpu(), file_path)
             manifest["buffers"][name] = filename
+            manifest["buffers_meta"][name] = {
+                "shape": [int(x) for x in buffer.shape],
+                "dtype": str(buffer.dtype),
+            }
             manifest["device_map"][f"buffer::{name}"] = str(buffer.device)

         manifest_path = snapshot_dir / "manifest.json"
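With the `params_meta`/`buffers_meta` tables added above, the snapshot manifest now records the expected shape and dtype of every saved tensor. An illustrative manifest for a small BERT-like model, written as a Python literal (the tensor names and `param_0.pt`-style filenames are hypothetical; the filename scheme is defined outside this hunk):

    manifest = {
        "config": {},  # model.config.to_dict() when available
        "params": {"embeddings.word_embeddings.weight": "param_0.pt"},
        "params_meta": {
            "embeddings.word_embeddings.weight": {
                "shape": [30522, 768],
                "dtype": "torch.float32",
            },
        },
        "buffers": {"embeddings.position_ids": "buffer_0.pt"},
        "buffers_meta": {
            "embeddings.position_ids": {"shape": [1, 512], "dtype": "torch.int64"},
        },
        "device_map": {
            "embeddings.word_embeddings.weight": "cpu",
            "buffer::embeddings.position_ids": "cpu",
        },
        "weight_tying": {},
    }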
@@ -377,24 +391,89 @@ class HFAdapterMixin:

         device_map = manifest.get("device_map", {})

-        for name, filename in manifest.get("params", {}).items():
+        params_manifest = manifest.get("params", {})
+        if not isinstance(params_manifest, dict):
+            raise TypeError("Invalid snapshot manifest: params must be a mapping")
+        buffers_manifest = manifest.get("buffers", {})
+        if not isinstance(buffers_manifest, dict):
+            raise TypeError("Invalid snapshot manifest: buffers must be a mapping")
+        params_meta = manifest.get("params_meta", {})
+        buffers_meta = manifest.get("buffers_meta", {})
+
+        # Preflight: ensure manifest/model agreement and tensor readability before copying.
+        for name, filename in params_manifest.items():
+            if name not in param_map:
+                raise KeyError(f"Snapshot parameter missing in target model: {name}")
+            if not isinstance(filename, str) or not filename:
+                raise TypeError(f"Invalid snapshot manifest filename for param: {name}")
             file_path = snapshot_dir / filename
-            if name not in param_map or not file_path.exists():
-                continue
+            if not file_path.exists():
+                raise FileNotFoundError(
+                    f"Missing snapshot tensor for param: {file_path}"
+                )
+            tensor = torch.load(file_path, map_location="cpu")
+            if not isinstance(tensor, torch.Tensor):
+                raise TypeError(f"Invalid snapshot tensor payload for param: {name}")
+            meta = params_meta.get(name) if isinstance(params_meta, dict) else None
+            if isinstance(meta, dict):
+                expected_shape = meta.get("shape")
+                expected_dtype = meta.get("dtype")
+                if isinstance(expected_shape, list) and list(tensor.shape) != list(
+                    expected_shape
+                ):
+                    raise ValueError(
+                        f"Snapshot tensor shape mismatch for param: {name}"
+                    )
+                if isinstance(expected_dtype, str) and expected_dtype:
+                    if str(tensor.dtype) != expected_dtype:
+                        raise ValueError(
+                            f"Snapshot tensor dtype mismatch for param: {name}"
+                        )
+
+        for name, filename in buffers_manifest.items():
+            if name not in buffer_map:
+                raise KeyError(f"Snapshot buffer missing in target model: {name}")
+            if not isinstance(filename, str) or not filename:
+                raise TypeError(
+                    f"Invalid snapshot manifest filename for buffer: {name}"
+                )
+            file_path = snapshot_dir / filename
+            if not file_path.exists():
+                raise FileNotFoundError(
+                    f"Missing snapshot tensor for buffer: {file_path}"
+                )
+            tensor = torch.load(file_path, map_location="cpu")
+            if not isinstance(tensor, torch.Tensor):
+                raise TypeError(f"Invalid snapshot tensor payload for buffer: {name}")
+            meta = buffers_meta.get(name) if isinstance(buffers_meta, dict) else None
+            if isinstance(meta, dict):
+                expected_shape = meta.get("shape")
+                expected_dtype = meta.get("dtype")
+                if isinstance(expected_shape, list) and list(tensor.shape) != list(
+                    expected_shape
+                ):
+                    raise ValueError(
+                        f"Snapshot tensor shape mismatch for buffer: {name}"
+                    )
+                if isinstance(expected_dtype, str) and expected_dtype:
+                    if str(tensor.dtype) != expected_dtype:
+                        raise ValueError(
+                            f"Snapshot tensor dtype mismatch for buffer: {name}"
+                        )
+
+        # Restore parameters/buffers (second pass) after successful preflight.
+        for name, filename in params_manifest.items():
             target = param_map[name]
             target_device = torch.device(device_map.get(name, str(target.device)))
-            tensor = torch.load(file_path, map_location="cpu")
+            tensor = torch.load(snapshot_dir / filename, map_location="cpu")
             with torch.no_grad():
                 target.copy_(tensor.to(target_device))

-        for name, filename in manifest.get("buffers", {}).items():
-            file_path = snapshot_dir / filename
-            if name not in buffer_map or not file_path.exists():
-                continue
+        for name, filename in buffers_manifest.items():
             target = buffer_map[name]
             key = f"buffer::{name}"
             target_device = torch.device(device_map.get(key, str(target.device)))
-            tensor = torch.load(file_path, map_location="cpu")
+            tensor = torch.load(snapshot_dir / filename, map_location="cpu")
             target.copy_(tensor.to(target_device))

         original_tying = manifest.get("weight_tying", {})
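Where 0.3.3 silently `continue`d past missing or unreadable entries, restore is now two-phase: the preflight pass validates every manifest entry (presence in the target model, file existence, tensor payload type, and the recorded shape/dtype) and raises before any `copy_()` runs, so a bad snapshot can no longer leave a model half-restored. Each tensor is loaded twice, once to validate and once to copy, trading extra I/O for that atomicity. The caller-side effect, sketched with an assumed method name since the enclosing signature is not shown in this hunk:

    try:
        adapter.restore_snapshot(model, snapshot_dir)  # method name assumed
    except (KeyError, TypeError, FileNotFoundError, ValueError) as exc:
        # Preflight rejected the snapshot before any weight was mutated.
        print(f"snapshot rejected: {exc}")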
@@ -30,7 +30,9 @@ class HF_T5_Adapter(HFAdapterMixin, ModelAdapter):

     name = "hf_t5"

-    def load_model(self, model_id: str, device: str = "auto") -> ModuleType | Any:  # type: ignore[override]
+    def load_model(  # type: ignore[override]
+        self, model_id: str, device: str = "auto", **kwargs: Any
+    ) -> ModuleType | Any:
         with wrap_errors(
             DependencyError,
             "E203",
@@ -45,8 +47,8 @@ class HF_T5_Adapter(HFAdapterMixin, ModelAdapter):
             "MODEL-LOAD-FAILED: transformers AutoModelForSeq2SeqLM",
             lambda e: {"model_id": model_id},
         ):
-            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-            return model.to(self._resolve_device(device))
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_id, **kwargs)
+            return self._safe_to_device(model, device)

     def can_handle(self, model: ModuleType | Any) -> bool:  # type: ignore[override]
         cfg = getattr(model, "config", None)
@@ -378,7 +378,7 @@ def certify_command(
             f"Context: device={device}, adapter={adapter_name}, edit={edit_name}. "
             "Baseline ok; edited failed to compute ppl. "
             "Try: use an accelerator (mps/cuda), force float32, reduce max_modules, "
-            "or lower batch size (INVARLOCK_SCORES_BATCH_SIZE)."
+            "or lower the evaluation batch size."
         ),
         details={
             "device": device,