invarlock 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. invarlock/__init__.py +1 -1
  2. invarlock/_data/runtime/tiers.yaml +57 -30
  3. invarlock/adapters/__init__.py +1 -1
  4. invarlock/calibration/spectral_null.py +15 -10
  5. invarlock/calibration/variance_ve.py +0 -2
  6. invarlock/cli/commands/calibrate.py +6 -2
  7. invarlock/cli/commands/certify.py +58 -39
  8. invarlock/cli/commands/doctor.py +3 -1
  9. invarlock/cli/commands/explain_gates.py +57 -8
  10. invarlock/cli/commands/report.py +1 -1
  11. invarlock/cli/commands/run.py +159 -61
  12. invarlock/cli/commands/verify.py +78 -4
  13. invarlock/cli/config.py +21 -5
  14. invarlock/core/api.py +45 -5
  15. invarlock/core/auto_tuning.py +65 -20
  16. invarlock/core/contracts.py +7 -1
  17. invarlock/core/registry.py +2 -2
  18. invarlock/core/runner.py +314 -50
  19. invarlock/eval/bench.py +0 -13
  20. invarlock/eval/data.py +73 -283
  21. invarlock/eval/metrics.py +134 -4
  22. invarlock/eval/primary_metric.py +23 -0
  23. invarlock/eval/tail_stats.py +230 -0
  24. invarlock/guards/_estimators.py +154 -0
  25. invarlock/guards/policies.py +16 -6
  26. invarlock/guards/rmt.py +625 -544
  27. invarlock/guards/spectral.py +348 -110
  28. invarlock/guards/tier_config.py +32 -30
  29. invarlock/guards/variance.py +5 -29
  30. invarlock/guards_ref/rmt_ref.py +23 -23
  31. invarlock/model_profile.py +42 -15
  32. invarlock/reporting/certificate.py +225 -46
  33. invarlock/reporting/certificate_schema.py +2 -1
  34. invarlock/reporting/dataset_hashing.py +15 -2
  35. invarlock/reporting/guards_analysis.py +197 -274
  36. invarlock/reporting/normalizer.py +6 -0
  37. invarlock/reporting/policy_utils.py +38 -36
  38. invarlock/reporting/primary_metric_utils.py +71 -17
  39. invarlock/reporting/render.py +61 -0
  40. invarlock/reporting/report.py +1 -1
  41. invarlock/reporting/report_types.py +5 -2
  42. invarlock/reporting/validate.py +1 -18
  43. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
  44. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
  45. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
  46. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
  47. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
  48. {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
invarlock/eval/data.py CHANGED
@@ -7,7 +7,6 @@ Pluggable data loading system with deterministic windowing for reproducible eval
 
 from __future__ import annotations
 
-import atexit
 import hashlib
 import json
 import math
@@ -51,7 +50,6 @@ except ImportError:
 
 try:
     import torch
-    import torch.nn.functional as F
 
     HAS_TORCH = True
 except ImportError:
@@ -160,9 +158,9 @@ class WikiText2Provider:
     """
 
     name = "wikitext2"
-    _MODEL_CACHE: Any | None | bool = None
-    _MODEL_DEVICE: Any | None = None
-    _CLEANUP_REGISTERED: bool = False
+    _BYTE_NGRAM_ORDER = 4
+    _BYTE_NGRAM_PAD = 256
+    _BYTE_NGRAM_ALPHA = 1.0
 
     def __init__(
         self,
@@ -178,13 +176,9 @@ class WikiText2Provider:
         """
         self.cache_dir = cache_dir
         self._validate_dependencies()
-        self._register_cleanup()
-        self._difficulty_model = self.__class__._MODEL_CACHE
-        self._difficulty_device = self.__class__._MODEL_DEVICE
         self._last_stratification_stats: dict[str, Any] | None = None
         self._last_batch_size_used: int = 0
         self._last_scorer_profile: dict[str, Any] | None = None
-        self._scorer_warmed: bool = False
         # In-process cache for loaded/filtered texts to avoid repeated
         # load_dataset() calls across stratification retries.
         self._texts_cache: dict[str, list[str]] = {}
@@ -192,48 +186,9 @@ class WikiText2Provider:
         normalized_hint = (device_hint or "").strip().lower()
         self._device_hint: str | None = normalized_hint or None
 
-    @classmethod
-    def _register_cleanup(cls) -> None:
-        """Register an atexit hook once per process to release cached models."""
-        if cls._CLEANUP_REGISTERED or not HAS_TORCH:
-            return
-
-        def _cleanup() -> None:
-            cls._cleanup_model_cache()
-
-        atexit.register(_cleanup)
-        cls._CLEANUP_REGISTERED = True
-
-    @classmethod
-    def _cleanup_model_cache(cls) -> None:
-        """Release cached models to avoid leaking multiprocessing semaphores."""
-        cache = cls._MODEL_CACHE
-        if cache is not None and cache is not False and HAS_TORCH:
-            try:
-                cache.to("cpu")
-            except Exception:
-                pass
-        cls._MODEL_CACHE = None
-        cls._MODEL_DEVICE = None
-
-    @staticmethod
-    def _pick_default_scorer_device() -> torch.device:
-        """
-        Choose a default device for the difficulty scorer model.
-
-        Prefers CUDA → MPS → CPU when available.
-        """
-        if torch.cuda.is_available():
-            return torch.device("cuda")
-        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return torch.device("mps")
-        return torch.device("cpu")
-
     def _validate_dependencies(self) -> None:
         """Check that required dependencies are available."""
         if not HAS_DATASETS:
-            if _LIGHT_IMPORT:
-                return
             raise _DepErr(
                 code="E301",
                 message=(
@@ -371,13 +326,6 @@ class WikiText2Provider:
         if cached is not None and len(cached) >= max_samples:
             return cached[:max_samples]
 
-        if not HAS_DATASETS and _LIGHT_IMPORT:
-            texts = ["hello world", "invarlock synthetic text"] * max(
-                1, max_samples // 2
-            )
-            self._texts_cache[split] = texts
-            return texts[:max_samples]
-
         # Load dataset with size limit for efficiency
         dataset_slice = f"{split}[:{max_samples}]" if max_samples > 0 else split
         dataset = load_dataset(
@@ -513,9 +461,11 @@ class WikiText2Provider:
             candidates.append(
                 {
                     "dataset_index": idx,
+                    "text": texts[idx],
                     "input_ids": input_ids_list,
                     "attention_mask": attention_mask_list,
                     "token_count": real_tokens,
+                    "seq_len": len(input_ids_list),
                 }
            )
 
@@ -531,32 +481,7 @@ class WikiText2Provider:
                 details={"needed": int(total_required), "got": int(len(candidates))},
             )
 
-        if not self._score_candidates_with_model(candidates):
-            token_counter: Counter[int] = Counter()
-            for candidate in candidates:
-                for token_id, mask in zip(
-                    candidate["input_ids"], candidate["attention_mask"], strict=False
-                ):
-                    if mask:
-                        token_counter[int(token_id)] += 1
-
-            total_tokens = sum(token_counter.values()) or 1
-            vocab_size = max(len(token_counter), 1)
-
-            for candidate in candidates:
-                difficulty = 0.0
-                real_tokens = 0
-                for token_id, mask in zip(
-                    candidate["input_ids"], candidate["attention_mask"], strict=False
-                ):
-                    if not mask:
-                        continue
-                    freq = (token_counter[int(token_id)] + 1.0) / (
-                        total_tokens + vocab_size
-                    )
-                    difficulty -= math.log(freq)
-                    real_tokens += 1
-                candidate["difficulty"] = difficulty / max(real_tokens, 1)
+        self._score_candidates_byte_ngram(candidates)
 
         sorted_candidates = sorted(
             candidates, key=lambda item: (item["difficulty"], item["dataset_index"])
@@ -843,193 +768,63 @@ class WikiText2Provider:
 
         return results
 
-    def _score_candidates_with_model(self, candidates: list[dict[str, Any]]) -> bool:
-        """Score candidate windows using a pretrained GPT-2 model if available."""
-        if not HAS_TORCH:
-            return False
-
-        if self._difficulty_model is False:
-            return False
-
-        try:
-            eval_device_override = os.environ.get("INVARLOCK_EVAL_DEVICE")
-            device_hint = getattr(self, "_device_hint", None)
-
-            def _is_device_usable(device: torch.device) -> bool:
-                try:
-                    _ = torch.zeros((1, 1), dtype=torch.long, device=device)
-                    return True
-                except Exception:
-                    return False
-
-            if self._difficulty_model is None:
-                from transformers import GPT2LMHeadModel
-
-                model = GPT2LMHeadModel.from_pretrained("gpt2")
-                model.eval()
-                # Decide initial scorer device: env override → provider hint → heuristic
-                if eval_device_override:
-                    try:
-                        device = torch.device(eval_device_override)
-                    except Exception:
-                        device = self._pick_default_scorer_device()
-                elif device_hint and device_hint != "auto":
-                    try:
-                        device = torch.device(device_hint)
-                    except Exception:
-                        device = self._pick_default_scorer_device()
-                else:
-                    device = self._pick_default_scorer_device()
-
-                if device.type != "cpu" and not _is_device_usable(device):
-                    warnings.warn(
-                        f"Difficulty scorer device {device} unavailable; falling back to CPU",
-                        stacklevel=2,
-                    )
-                    device = torch.device("cpu")
-
-                model.to(device)
-                self._difficulty_model = model
-                self._difficulty_device = device
-                self.__class__._MODEL_CACHE = model
-                self.__class__._MODEL_DEVICE = device
-
-            assert self._difficulty_model is not None
-            model = self._difficulty_model
-            device = self._difficulty_device or torch.device("cpu")
-
-            # If a new override/hint is provided, move the cached model if needed.
-            desired_device = device
-            if eval_device_override:
-                try:
-                    desired_device = torch.device(eval_device_override)
-                except Exception:
-                    desired_device = device
-            elif device_hint and device_hint != "auto":
-                try:
-                    desired_device = torch.device(device_hint)
-                except Exception:
-                    desired_device = device
-
-            if desired_device != device:
-                if desired_device.type != "cpu" and not _is_device_usable(
-                    desired_device
-                ):
-                    warnings.warn(
-                        f"Difficulty scorer device {desired_device} unavailable; keeping {device}",
-                        stacklevel=2,
-                    )
-                else:
-                    try:
-                        model.to(desired_device)
-                        device = desired_device
-                        self._difficulty_device = desired_device
-                        self.__class__._MODEL_DEVICE = desired_device
-                    except Exception as exc:
-                        warnings.warn(
-                            f"Failed to move GPT-2 difficulty scorer to {desired_device}: {exc}",
-                            stacklevel=2,
-                        )
-
-            if not self._scorer_warmed:
-                with torch.no_grad():
-                    dummy_input = torch.zeros((1, 8), dtype=torch.long, device=device)
-                    dummy_attention = torch.ones_like(dummy_input)
-                    model(dummy_input, attention_mask=dummy_attention)
-                self._scorer_warmed = True
-
-            batch_override = os.environ.get("INVARLOCK_SCORES_BATCH_SIZE")
-            override_size = None
-            if batch_override:
-                try:
-                    override_size = max(1, int(batch_override))
-                except ValueError:
-                    override_size = None
-
-            batch_size = min(32, max(4, len(candidates)))
-            if override_size is not None:
-                batch_size = max(1, min(override_size, len(candidates)))
-
-            config = getattr(model, "config", None)
-            scorer_vocab_size = getattr(config, "vocab_size", None)
-
-            input_batch: list[list[int]] = []
-            attention_batch: list[list[int]] = []
-            candidate_batch: list[dict[str, Any]] = []
-            total_tokens = 0
-            start_time = time.perf_counter()
-
-            with torch.no_grad():
-                for candidate in candidates:
-                    input_batch.append(candidate["input_ids"])
-                    attention_batch.append(candidate["attention_mask"])
-                    candidate_batch.append(candidate)
-
-                    if len(input_batch) == batch_size or candidate is candidates[-1]:
-                        input_tensor = torch.tensor(
-                            input_batch, dtype=torch.long, device=device
-                        )
-                        attention_tensor = torch.tensor(
-                            attention_batch, dtype=torch.long, device=device
-                        )
-
-                        # Guard against out-of-range token IDs when scoring with GPT-2.
-                        # Some model tokenizers emit IDs beyond GPT-2 vocab, which can
-                        # trigger device-side asserts in embedding/gather kernels.
-                        if scorer_vocab_size and scorer_vocab_size > 0:
-                            input_tensor = input_tensor.clamp(
-                                min=0, max=scorer_vocab_size - 1
-                            )
-
-                        outputs = model(input_tensor, attention_mask=attention_tensor)
-                        shift_logits = outputs.logits[:, :-1, :].contiguous()
-                        shift_labels = input_tensor[:, 1:].contiguous()
-                        shift_mask = attention_tensor[:, 1:].contiguous()
-                        shift_labels = shift_labels.masked_fill(shift_mask == 0, 0)
-
-                        vocab_size = shift_logits.size(-1)
-                        losses = F.cross_entropy(
-                            shift_logits.view(-1, vocab_size),
-                            shift_labels.view(-1),
-                            reduction="none",
-                        )
-                        losses = losses.view(shift_labels.size()) * shift_mask
-                        token_counts = shift_mask.sum(dim=1).clamp(min=1)
-                        loss_per_example = (
-                            (losses.sum(dim=1) / token_counts).cpu().tolist()
-                        )
-
-                        for cand_obj, loss_value in zip(
-                            candidate_batch, loss_per_example, strict=False
-                        ):
-                            cand_obj["difficulty"] = float(loss_value)
-                        total_tokens += int(token_counts.sum().item())
-
-                        input_batch.clear()
-                        attention_batch.clear()
-                        candidate_batch.clear()
-            self._last_batch_size_used = batch_size
-            elapsed = max(time.perf_counter() - start_time, 1e-9)
-            tokens_per_sec = total_tokens / elapsed if total_tokens else 0.0
-            self._last_scorer_profile = {
-                "batch_size": batch_size,
-                "tokens_processed": total_tokens,
-                "elapsed_seconds": elapsed,
-                "tokens_per_second": tokens_per_sec,
-            }
-            return True
-        except Exception as exc:  # pragma: no cover - defensive
-            warnings.warn(
-                f"Failed to compute GPT-2 difficulty scores: {exc}", stacklevel=2
-            )
-            self._difficulty_model = False
-            self._difficulty_device = None
-            self.__class__._MODEL_CACHE = False
-            self.__class__._MODEL_DEVICE = None
+    def _score_candidates_byte_ngram(self, candidates: list[dict[str, Any]]) -> bool:
+        if not candidates:
             self._last_batch_size_used = 0
             self._last_scorer_profile = None
             return False
 
+        order = max(1, int(self._BYTE_NGRAM_ORDER))
+        pad_token = int(self._BYTE_NGRAM_PAD)
+        alpha = float(self._BYTE_NGRAM_ALPHA)
+        vocab_size = pad_token + 1
+
+        context_counts: Counter[tuple[int, ...]] = Counter()
+        ngram_counts: Counter[tuple[int, ...]] = Counter()
+        sequences: list[list[int]] = []
+        start_time = time.perf_counter()
+
+        for candidate in candidates:
+            text = candidate.get("text")
+            if not isinstance(text, str):
+                text = ""
+            byte_values = list(text.encode("utf-8", errors="replace"))
+            tokens = ([pad_token] * (order - 1)) + byte_values
+            sequences.append(tokens)
+            for idx in range(order - 1, len(tokens)):
+                context = tuple(tokens[idx - order + 1 : idx])
+                ngram = context + (tokens[idx],)
+                context_counts[context] += 1
+                ngram_counts[ngram] += 1
+
+        total_tokens = 0
+        for candidate, tokens in zip(candidates, sequences, strict=False):
+            loss_sum = 0.0
+            token_count = 0
+            for idx in range(order - 1, len(tokens)):
+                context = tuple(tokens[idx - order + 1 : idx])
+                ngram = context + (tokens[idx],)
+                context_count = context_counts.get(context, 0)
+                ngram_count = ngram_counts.get(ngram, 0)
+                prob = (ngram_count + alpha) / (context_count + alpha * vocab_size)
+                loss_sum += -math.log(prob)
+                token_count += 1
+            candidate["difficulty"] = loss_sum / max(token_count, 1)
+            total_tokens += token_count
+
+        self._last_batch_size_used = len(candidates)
+        elapsed = max(time.perf_counter() - start_time, 1e-9)
+        tokens_per_sec = total_tokens / elapsed if total_tokens else 0.0
+        self._last_scorer_profile = {
+            "mode": "byte_ngram",
+            "order": order,
+            "vocab_size": vocab_size,
+            "tokens_processed": total_tokens,
+            "elapsed_seconds": elapsed,
+            "tokens_per_second": tokens_per_sec,
+        }
+        return True
+
     def _tokenize_samples(
         self,
         texts: list[str],
@@ -1258,14 +1053,13 @@ class HFTextProvider:
         max_samples: int = 2000,
     ):
         if not HAS_DATASETS:
-            if not _LIGHT_IMPORT:
-                raise _DepErr(
-                    code="E301",
-                    message=(
-                        "DEPENDENCY-MISSING: datasets library required for hf_text provider"
-                    ),
-                    details={"dependency": "datasets"},
-                )
+            raise _DepErr(
+                code="E301",
+                message=(
+                    "DEPENDENCY-MISSING: datasets library required for hf_text provider"
+                ),
+                details={"dependency": "datasets"},
+            )
         self.dataset_name = dataset_name or "wikitext"
         self.config_name = config_name or None
         self.text_field = text_field
@@ -1273,9 +1067,6 @@ class HFTextProvider:
         self.max_samples = int(max_samples)
 
    def load(self, split: str = "validation", **kwargs) -> list[str]:
-        if not HAS_DATASETS and _LIGHT_IMPORT:
-            return ["synthetic dataset text"] * int(self.max_samples or 1)
-
        ds = load_dataset(
            path=self.dataset_name,
            name=self.config_name,
@@ -1400,14 +1191,13 @@ class HFSeq2SeqProvider:
         max_samples: int = 2000,
     ) -> None:
         if not HAS_DATASETS:
-            if not _LIGHT_IMPORT:
-                raise _DepErr(
-                    code="E301",
-                    message=(
-                        "DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
-                    ),
-                    details={"dependency": "datasets"},
-                )
+            raise _DepErr(
+                code="E301",
+                message=(
+                    "DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
+                ),
+                details={"dependency": "datasets"},
+            )
         self.dataset_name = dataset_name
         self.config_name = config_name
         self.src_field = src_field
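
Net effect for eval/data.py: 0.3.6 removes the cached GPT-2 difficulty scorer (along with its atexit cleanup, device selection, and _LIGHT_IMPORT synthetic-text fallbacks) and instead ranks candidate windows with an add-α smoothed byte n-gram model built from the candidates themselves, so the difficulty-scoring step no longer depends on torch or transformers. Below is a minimal standalone sketch of that scoring idea; the function name and sample strings are illustrative only, while order=4, pad=256, and alpha=1.0 mirror the class constants added in the diff.

import math
from collections import Counter

def byte_ngram_difficulty(texts, order=4, pad=256, alpha=1.0):
    """Average per-byte negative log-likelihood under an add-alpha smoothed
    byte n-gram model fitted on the candidate texts themselves."""
    vocab = pad + 1  # 256 byte values plus one pad symbol
    ctx_counts, ngram_counts, seqs = Counter(), Counter(), []
    for text in texts:
        toks = [pad] * (order - 1) + list(text.encode("utf-8", errors="replace"))
        seqs.append(toks)
        for i in range(order - 1, len(toks)):
            ctx = tuple(toks[i - order + 1 : i])
            ctx_counts[ctx] += 1
            ngram_counts[ctx + (toks[i],)] += 1
    scores = []
    for toks in seqs:
        nll, n = 0.0, 0
        for i in range(order - 1, len(toks)):
            ctx = tuple(toks[i - order + 1 : i])
            prob = (ngram_counts[ctx + (toks[i],)] + alpha) / (
                ctx_counts[ctx] + alpha * vocab
            )
            nll -= math.log(prob)
            n += 1
        scores.append(nll / max(n, 1))
    return scores

# The repetitive string scores lower (easier) than the random-looking one:
print(byte_ngram_difficulty(["the the the the the", "zqxjkv wvlrnp"]))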
invarlock/eval/metrics.py CHANGED
@@ -723,7 +723,10 @@ def calculate_lens_metrics_for_model(
     except Exception as e:
         logger.error(f"Metrics calculation failed: {e}")
         if config.strict_validation:
-            raise MetricsError(f"Metrics calculation failed: {e}") from e
+            raise MetricsError(
+                code="E401",
+                message=f"METRICS-COMPUTE-FAILED: {e}",
+            ) from e
 
     finally:
         resource_manager.cleanup()
@@ -1379,6 +1382,88 @@ def _resolve_eval_device(
     return resolved
 
 
+def _infer_model_vocab_size(model: nn.Module) -> int | None:
+    """Best-effort vocab size for guarding against invalid token IDs.
+
+    Prefer the actual embedding size (more reliable than config.vocab_size when
+    tokenizers have added tokens), and fall back to config when embeddings are
+    unavailable (e.g., stub models in tests).
+    """
+    try:
+        get_emb = getattr(model, "get_input_embeddings", None)
+        if callable(get_emb):
+            emb = get_emb()
+            weight = getattr(emb, "weight", None)
+            if weight is not None and hasattr(weight, "shape"):
+                size = int(weight.shape[0])
+                if size > 0:
+                    return size
+    except Exception:
+        pass
+
+    # Fallback for lightweight/stub models: pick the largest nn.Embedding module.
+    # This is not guaranteed to be the token embedding, but is a good heuristic
+    # when get_input_embeddings/config.vocab_size are unavailable.
+    try:
+        max_embeddings = 0
+        for module in model.modules():
+            if isinstance(module, nn.Embedding):
+                max_embeddings = max(max_embeddings, int(module.num_embeddings))
+        if max_embeddings > 0:
+            return max_embeddings
+    except Exception:
+        pass
+
+    config = getattr(model, "config", None)
+    vocab_size = getattr(config, "vocab_size", None)
+    if isinstance(vocab_size, int) and vocab_size > 0:
+        return vocab_size
+    return None
+
+
+def _resolve_pad_token_id(model: nn.Module, vocab_size: int | None) -> int:
+    """Pick a safe pad token id for sanitizing invalid token IDs."""
+    config = getattr(model, "config", None)
+    pad_token_id = getattr(config, "pad_token_id", None)
+    if isinstance(pad_token_id, int) and pad_token_id >= 0:
+        if vocab_size is None or pad_token_id < vocab_size:
+            return pad_token_id
+    return 0
+
+
+def _sanitize_token_ids_for_model(
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    labels: torch.Tensor | None,
+    *,
+    vocab_size: int,
+    pad_token_id: int,
+) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    """Prevent device-side asserts from out-of-range token IDs.
+
+    Out-of-range token IDs can trigger CUDA device-side asserts in embedding and
+    gather kernels, poisoning the CUDA context for the entire process. Instead,
+    mask them out as padding and ignore them in labels.
+    """
+    if vocab_size <= 0:
+        return input_ids, attention_mask, labels
+
+    invalid_inputs = (input_ids < 0) | (input_ids >= vocab_size)
+    if invalid_inputs.any():
+        input_ids = input_ids.masked_fill(invalid_inputs, pad_token_id)
+        if attention_mask is not None:
+            attention_mask = attention_mask.masked_fill(invalid_inputs, 0)
+        if labels is not None:
+            labels = labels.masked_fill(invalid_inputs, -100)
+
+    if labels is not None:
+        invalid_labels = (labels != -100) & ((labels < 0) | (labels >= vocab_size))
+        if invalid_labels.any():
+            labels = labels.masked_fill(invalid_labels, -100)
+
+    return input_ids, attention_mask, labels
+
+
 # ── Perplexity calculation ─────────────────────────────────────────────────
 @torch.no_grad()
 def calculate_perplexity(
@@ -1415,6 +1500,8 @@ def compute_perplexity_strict(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
     nll_sum = 0.0
     tok_count = 0
 
@@ -1453,6 +1540,15 @@ def compute_perplexity_strict(
         else:
             labels = labels.to(device)
 
+        if model_vocab_size is not None:
+            input_ids, attn, labels = _sanitize_token_ids_for_model(
+                input_ids,
+                attn,
+                labels,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip if sequence too short
         if input_ids.size(1) < 2:
             continue
@@ -1507,7 +1603,11 @@ def compute_perplexity_strict(
             continue
 
         log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
         nll = -log_probs.gather(-1, tgt).squeeze(-1)  # [B,T-1]
 
         nll_sum += nll[valid].sum().item()
@@ -1552,6 +1652,8 @@ def compute_perplexity(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
     nll_sum = 0.0
     tok_count = 0
     batch_count = 0
@@ -1589,6 +1691,15 @@ def compute_perplexity(
         else:
             labels = labels.to(device)
 
+        if model_vocab_size is not None:
+            input_ids, attn, labels = _sanitize_token_ids_for_model(
+                input_ids,
+                attn,
+                labels,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip if sequence too short
         if input_ids.size(1) < 2:
             continue
@@ -1620,7 +1731,11 @@ def compute_perplexity(
 
         # Compute negative log-likelihood
         log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
 
         # MPS workaround: gather operation can fail on MPS, use CPU fallback
         if str(device).startswith("mps"):
@@ -1694,6 +1809,8 @@ def compute_ppl(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
     nll_sum = 0.0
     tok_count = 0
 
@@ -1712,6 +1829,15 @@ def compute_ppl(
             torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device)
         )
 
+        if model_vocab_size is not None:
+            input_ids_tensor, attention_mask_tensor, _ = _sanitize_token_ids_for_model(
+                input_ids_tensor,
+                attention_mask_tensor,
+                labels=None,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip sequences that are too short
         if input_ids_tensor.size(1) < 2:
             continue
@@ -1747,7 +1873,11 @@ def compute_ppl(
 
         # Compute negative log-likelihood
         log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
 
         # Handle MPS device issues with gather
         if str(device).startswith("mps"):
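
The eval/metrics.py changes share one theme: every perplexity path (compute_perplexity_strict, compute_perplexity, compute_ppl) now infers the model's vocabulary size, sanitizes out-of-range token IDs before the forward pass, and re-validates labels against the logits' vocab dimension before the log-prob gather, because a single out-of-range index can trigger a device-side assert that poisons the CUDA context for the rest of the process. A minimal sketch of that masking pattern, with illustrative tensor values rather than the package API:

import torch

def sanitize_ids(input_ids, attention_mask, labels, vocab_size, pad_id=0):
    # Out-of-range IDs become padding; the mask and labels ignore those positions.
    bad = (input_ids < 0) | (input_ids >= vocab_size)
    input_ids = input_ids.masked_fill(bad, pad_id)
    attention_mask = attention_mask.masked_fill(bad, 0)
    labels = labels.masked_fill(bad, -100)  # -100 marks positions to ignore
    return input_ids, attention_mask, labels

ids = torch.tensor([[5, 999, 7]])  # 999 is outside a vocab of size 100
mask = torch.ones_like(ids)
ids, mask, labels = sanitize_ids(ids, mask, ids.clone(), vocab_size=100)
print(ids, mask, labels)  # the invalid position is padded and excluded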