invarlock-0.3.3-py3-none-any.whl → invarlock-0.3.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -117,6 +117,18 @@ def _validate_pairing(certificate: dict[str, Any]) -> list[str]:
 
     match_fraction = stats.get("window_match_fraction")
     overlap_fraction = stats.get("window_overlap_fraction")
+    pairing_reason = stats.get("window_pairing_reason")
+    paired_windows = _coerce_int(stats.get("paired_windows"))
+
+    if pairing_reason is not None:
+        errors.append(
+            "window_pairing_reason must be null/None for paired certificates "
+            f"(found {pairing_reason!r})."
+        )
+    if paired_windows is None:
+        errors.append("Certificate missing paired_windows metric.")
+    elif paired_windows == 0:
+        errors.append("paired_windows must be > 0 for paired certificates (found 0).")
 
     if match_fraction is None:
         errors.append("Certificate missing window_match_fraction metric.")
invarlock/core/runner.py CHANGED
@@ -1528,7 +1528,7 @@ class CoreRunner:
             pairing_reason = "duplicate_windows"
         elif count_mismatch:
             pairing_reason = "count_mismatch"
-        elif not pairing_context:
+        else:
             pairing_reason = preview_pair_stats.get(
                 "reason"
             ) or final_pair_stats.get("reason")
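Switching the final branch from `elif not pairing_context:` to `else:` means a pairing reason is now resolved for every remaining case, preferring the preview stats and falling back to the final stats. A rough illustration of that fallback (the dict contents are hypothetical):

    preview_pair_stats = {"reason": None}
    final_pair_stats = {"reason": "preview_unavailable"}  # hypothetical value
    # Same expression as the hunk above: the first non-empty reason wins.
    pairing_reason = preview_pair_stats.get("reason") or final_pair_stats.get("reason")
    assert pairing_reason == "preview_unavailable"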
@@ -2079,24 +2079,49 @@ class CoreRunner:
             # Perform rollback if checkpoint available
             if self.checkpoint_manager and "initial_checkpoint" in report.meta:
                 checkpoint_id = report.meta["initial_checkpoint"]
-                self.checkpoint_manager.restore_checkpoint(
-                    model, adapter, checkpoint_id
-                )
-                # Match test expectation: only include checkpoint and reason
-                self._log_event(
-                    "finalize",
-                    "rollback",
-                    LogLevel.WARNING,
-                    {
-                        "checkpoint": checkpoint_id,
-                        "reason": rollback_reason,
-                    },
-                )
+                restored = False
+                restore_error: str | None = None
+                try:
+                    restored = bool(
+                        self.checkpoint_manager.restore_checkpoint(
+                            model, adapter, checkpoint_id
+                        )
+                    )
+                except Exception as exc:
+                    restored = False
+                    restore_error = str(exc)
+
+                if restored:
+                    # Match test expectation: only include checkpoint and reason
+                    self._log_event(
+                        "finalize",
+                        "rollback",
+                        LogLevel.WARNING,
+                        {
+                            "checkpoint": checkpoint_id,
+                            "reason": rollback_reason,
+                        },
+                    )
+                else:
+                    self._log_event(
+                        "finalize",
+                        "rollback_failed",
+                        LogLevel.CRITICAL,
+                        {
+                            "mode": "finalize",
+                            "checkpoint": checkpoint_id,
+                            "reason": rollback_reason,
+                            "error": restore_error or "restore_failed",
+                        },
+                    )
 
                 # Store rollback metadata in report
                 report.meta["rollback_reason"] = rollback_reason
                 report.meta["rollback_checkpoint"] = checkpoint_id
-                report.meta["guard_recovered"] = True
+                report.meta["guard_recovered"] = bool(restored)
+                report.meta["rollback_failed"] = not bool(restored)
+                if not restored:
+                    report.meta["rollback_error"] = restore_error or "restore_failed"
 
             else:
                 # Match test expectation: log without additional data payload
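With the guarded restore above, a failed rollback is no longer reported as a recovery: `guard_recovered`, `rollback_failed`, and `rollback_error` now disambiguate the outcome in `report.meta`. A small sketch of how a caller might read those fields (the `summarize_rollback` helper is hypothetical; the keys come from the hunk above):

    def summarize_rollback(meta: dict) -> str:
        """Summarize the rollback outcome from report metadata (illustrative helper)."""
        if meta.get("rollback_failed"):
            return (
                f"rollback to {meta.get('rollback_checkpoint')} failed: "
                f"{meta.get('rollback_error', 'restore_failed')}"
            )
        if meta.get("guard_recovered"):
            return (
                f"rolled back to {meta.get('rollback_checkpoint')} "
                f"({meta.get('rollback_reason')})"
            )
        return "no rollback performed"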
@@ -7,6 +7,6 @@ from __future__ import annotations
 # matching entry to `CHANGELOG.md`.
 
 BENCH_GOLDEN_ID = "bench-golden-2025-12-13"
-BENCH_GOLDEN_SHA256 = "0d9ff3274d29dad16ad580b4a0cf37b4f89e4f7c2e4345ce3d30a39f146ff5a7"
+BENCH_GOLDEN_SHA256 = "2627b8872cd6bfc37bda31fbc11b78ed814751cbf2a9ad1396e173f1f4e5383a"
 
 __all__ = ["BENCH_GOLDEN_ID", "BENCH_GOLDEN_SHA256"]
invarlock/eval/data.py CHANGED
@@ -7,7 +7,6 @@ Pluggable data loading system with deterministic windowing for reproducible eval
 
 from __future__ import annotations
 
-import atexit
 import hashlib
 import json
 import math
@@ -51,7 +50,6 @@ except ImportError:
 
 try:
     import torch
-    import torch.nn.functional as F
 
     HAS_TORCH = True
 except ImportError:
@@ -160,9 +158,9 @@ class WikiText2Provider:
     """
 
     name = "wikitext2"
-    _MODEL_CACHE: Any | None | bool = None
-    _MODEL_DEVICE: Any | None = None
-    _CLEANUP_REGISTERED: bool = False
+    _BYTE_NGRAM_ORDER = 4
+    _BYTE_NGRAM_PAD = 256
+    _BYTE_NGRAM_ALPHA = 1.0
 
     def __init__(
         self,
@@ -178,13 +176,9 @@ class WikiText2Provider:
         """
         self.cache_dir = cache_dir
         self._validate_dependencies()
-        self._register_cleanup()
-        self._difficulty_model = self.__class__._MODEL_CACHE
-        self._difficulty_device = self.__class__._MODEL_DEVICE
         self._last_stratification_stats: dict[str, Any] | None = None
         self._last_batch_size_used: int = 0
         self._last_scorer_profile: dict[str, Any] | None = None
-        self._scorer_warmed: bool = False
         # In-process cache for loaded/filtered texts to avoid repeated
         # load_dataset() calls across stratification retries.
         self._texts_cache: dict[str, list[str]] = {}
@@ -192,43 +186,6 @@ class WikiText2Provider:
         normalized_hint = (device_hint or "").strip().lower()
         self._device_hint: str | None = normalized_hint or None
 
-    @classmethod
-    def _register_cleanup(cls) -> None:
-        """Register an atexit hook once per process to release cached models."""
-        if cls._CLEANUP_REGISTERED or not HAS_TORCH:
-            return
-
-        def _cleanup() -> None:
-            cls._cleanup_model_cache()
-
-        atexit.register(_cleanup)
-        cls._CLEANUP_REGISTERED = True
-
-    @classmethod
-    def _cleanup_model_cache(cls) -> None:
-        """Release cached models to avoid leaking multiprocessing semaphores."""
-        cache = cls._MODEL_CACHE
-        if cache is not None and cache is not False and HAS_TORCH:
-            try:
-                cache.to("cpu")
-            except Exception:
-                pass
-        cls._MODEL_CACHE = None
-        cls._MODEL_DEVICE = None
-
-    @staticmethod
-    def _pick_default_scorer_device() -> torch.device:
-        """
-        Choose a default device for the difficulty scorer model.
-
-        Prefers CUDA → MPS → CPU when available.
-        """
-        if torch.cuda.is_available():
-            return torch.device("cuda")
-        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return torch.device("mps")
-        return torch.device("cpu")
-
     def _validate_dependencies(self) -> None:
         """Check that required dependencies are available."""
         if not HAS_DATASETS:
@@ -513,9 +470,11 @@ class WikiText2Provider:
             candidates.append(
                 {
                     "dataset_index": idx,
+                    "text": texts[idx],
                     "input_ids": input_ids_list,
                     "attention_mask": attention_mask_list,
                     "token_count": real_tokens,
+                    "seq_len": len(input_ids_list),
                 }
            )
 
@@ -531,32 +490,7 @@ class WikiText2Provider:
                 details={"needed": int(total_required), "got": int(len(candidates))},
             )
 
-        if not self._score_candidates_with_model(candidates):
-            token_counter: Counter[int] = Counter()
-            for candidate in candidates:
-                for token_id, mask in zip(
-                    candidate["input_ids"], candidate["attention_mask"], strict=False
-                ):
-                    if mask:
-                        token_counter[int(token_id)] += 1
-
-            total_tokens = sum(token_counter.values()) or 1
-            vocab_size = max(len(token_counter), 1)
-
-            for candidate in candidates:
-                difficulty = 0.0
-                real_tokens = 0
-                for token_id, mask in zip(
-                    candidate["input_ids"], candidate["attention_mask"], strict=False
-                ):
-                    if not mask:
-                        continue
-                    freq = (token_counter[int(token_id)] + 1.0) / (
-                        total_tokens + vocab_size
-                    )
-                    difficulty -= math.log(freq)
-                    real_tokens += 1
-                candidate["difficulty"] = difficulty / max(real_tokens, 1)
+        self._score_candidates_byte_ngram(candidates)
 
         sorted_candidates = sorted(
             candidates, key=lambda item: (item["difficulty"], item["dataset_index"])
@@ -843,182 +777,63 @@ class WikiText2Provider:
 
         return results
 
-    def _score_candidates_with_model(self, candidates: list[dict[str, Any]]) -> bool:
-        """Score candidate windows using a pretrained GPT-2 model if available."""
-        if not HAS_TORCH:
-            return False
-
-        if self._difficulty_model is False:
-            return False
-
-        try:
-            eval_device_override = os.environ.get("INVARLOCK_EVAL_DEVICE")
-            device_hint = getattr(self, "_device_hint", None)
-
-            def _is_device_usable(device: torch.device) -> bool:
-                try:
-                    _ = torch.zeros((1, 1), dtype=torch.long, device=device)
-                    return True
-                except Exception:
-                    return False
-
-            if self._difficulty_model is None:
-                from transformers import GPT2LMHeadModel
-
-                model = GPT2LMHeadModel.from_pretrained("gpt2")
-                model.eval()
-                # Decide initial scorer device: env override → provider hint → heuristic
-                if eval_device_override:
-                    try:
-                        device = torch.device(eval_device_override)
-                    except Exception:
-                        device = self._pick_default_scorer_device()
-                elif device_hint and device_hint != "auto":
-                    try:
-                        device = torch.device(device_hint)
-                    except Exception:
-                        device = self._pick_default_scorer_device()
-                else:
-                    device = self._pick_default_scorer_device()
-
-                if device.type != "cpu" and not _is_device_usable(device):
-                    warnings.warn(
-                        f"Difficulty scorer device {device} unavailable; falling back to CPU",
-                        stacklevel=2,
-                    )
-                    device = torch.device("cpu")
-
-                model.to(device)
-                self._difficulty_model = model
-                self._difficulty_device = device
-                self.__class__._MODEL_CACHE = model
-                self.__class__._MODEL_DEVICE = device
-
-            assert self._difficulty_model is not None
-            model = self._difficulty_model
-            device = self._difficulty_device or torch.device("cpu")
-
-            # If a new override/hint is provided, move the cached model if needed.
-            desired_device = device
-            if eval_device_override:
-                try:
-                    desired_device = torch.device(eval_device_override)
-                except Exception:
-                    desired_device = device
-            elif device_hint and device_hint != "auto":
-                try:
-                    desired_device = torch.device(device_hint)
-                except Exception:
-                    desired_device = device
-
-            if desired_device != device:
-                if desired_device.type != "cpu" and not _is_device_usable(
-                    desired_device
-                ):
-                    warnings.warn(
-                        f"Difficulty scorer device {desired_device} unavailable; keeping {device}",
-                        stacklevel=2,
-                    )
-                else:
-                    try:
-                        model.to(desired_device)
-                        device = desired_device
-                        self._difficulty_device = desired_device
-                        self.__class__._MODEL_DEVICE = desired_device
-                    except Exception as exc:
-                        warnings.warn(
-                            f"Failed to move GPT-2 difficulty scorer to {desired_device}: {exc}",
-                            stacklevel=2,
-                        )
-
-            if not self._scorer_warmed:
-                with torch.no_grad():
-                    dummy_input = torch.zeros((1, 8), dtype=torch.long, device=device)
-                    dummy_attention = torch.ones_like(dummy_input)
-                    model(dummy_input, attention_mask=dummy_attention)
-                self._scorer_warmed = True
-
-            batch_override = os.environ.get("INVARLOCK_SCORES_BATCH_SIZE")
-            override_size = None
-            if batch_override:
-                try:
-                    override_size = max(1, int(batch_override))
-                except ValueError:
-                    override_size = None
-
-            batch_size = min(32, max(4, len(candidates)))
-            if override_size is not None:
-                batch_size = max(1, min(override_size, len(candidates)))
-
-            input_batch: list[list[int]] = []
-            attention_batch: list[list[int]] = []
-            candidate_batch: list[dict[str, Any]] = []
-            total_tokens = 0
-            start_time = time.perf_counter()
-
-            with torch.no_grad():
-                for candidate in candidates:
-                    input_batch.append(candidate["input_ids"])
-                    attention_batch.append(candidate["attention_mask"])
-                    candidate_batch.append(candidate)
-
-                    if len(input_batch) == batch_size or candidate is candidates[-1]:
-                        input_tensor = torch.tensor(
-                            input_batch, dtype=torch.long, device=device
-                        )
-                        attention_tensor = torch.tensor(
-                            attention_batch, dtype=torch.long, device=device
-                        )
-
-                        outputs = model(input_tensor, attention_mask=attention_tensor)
-                        shift_logits = outputs.logits[:, :-1, :].contiguous()
-                        shift_labels = input_tensor[:, 1:].contiguous()
-                        shift_mask = attention_tensor[:, 1:].contiguous()
-                        shift_labels = shift_labels.masked_fill(shift_mask == 0, 0)
-
-                        vocab_size = shift_logits.size(-1)
-                        losses = F.cross_entropy(
-                            shift_logits.view(-1, vocab_size),
-                            shift_labels.view(-1),
-                            reduction="none",
-                        )
-                        losses = losses.view(shift_labels.size()) * shift_mask
-                        token_counts = shift_mask.sum(dim=1).clamp(min=1)
-                        loss_per_example = (
-                            (losses.sum(dim=1) / token_counts).cpu().tolist()
-                        )
-
-                        for cand_obj, loss_value in zip(
-                            candidate_batch, loss_per_example, strict=False
-                        ):
-                            cand_obj["difficulty"] = float(loss_value)
-                        total_tokens += int(token_counts.sum().item())
-
-                        input_batch.clear()
-                        attention_batch.clear()
-                        candidate_batch.clear()
-            self._last_batch_size_used = batch_size
-            elapsed = max(time.perf_counter() - start_time, 1e-9)
-            tokens_per_sec = total_tokens / elapsed if total_tokens else 0.0
-            self._last_scorer_profile = {
-                "batch_size": batch_size,
-                "tokens_processed": total_tokens,
-                "elapsed_seconds": elapsed,
-                "tokens_per_second": tokens_per_sec,
-            }
-            return True
-        except Exception as exc:  # pragma: no cover - defensive
-            warnings.warn(
-                f"Failed to compute GPT-2 difficulty scores: {exc}", stacklevel=2
-            )
-            self._difficulty_model = False
-            self._difficulty_device = None
-            self.__class__._MODEL_CACHE = False
-            self.__class__._MODEL_DEVICE = None
+    def _score_candidates_byte_ngram(self, candidates: list[dict[str, Any]]) -> bool:
+        if not candidates:
             self._last_batch_size_used = 0
             self._last_scorer_profile = None
             return False
 
+        order = max(1, int(self._BYTE_NGRAM_ORDER))
+        pad_token = int(self._BYTE_NGRAM_PAD)
+        alpha = float(self._BYTE_NGRAM_ALPHA)
+        vocab_size = pad_token + 1
+
+        context_counts: Counter[tuple[int, ...]] = Counter()
+        ngram_counts: Counter[tuple[int, ...]] = Counter()
+        sequences: list[list[int]] = []
+        start_time = time.perf_counter()
+
+        for candidate in candidates:
+            text = candidate.get("text")
+            if not isinstance(text, str):
+                text = ""
+            byte_values = list(text.encode("utf-8", errors="replace"))
+            tokens = ([pad_token] * (order - 1)) + byte_values
+            sequences.append(tokens)
+            for idx in range(order - 1, len(tokens)):
+                context = tuple(tokens[idx - order + 1 : idx])
+                ngram = context + (tokens[idx],)
+                context_counts[context] += 1
+                ngram_counts[ngram] += 1
+
+        total_tokens = 0
+        for candidate, tokens in zip(candidates, sequences, strict=False):
+            loss_sum = 0.0
+            token_count = 0
+            for idx in range(order - 1, len(tokens)):
+                context = tuple(tokens[idx - order + 1 : idx])
+                ngram = context + (tokens[idx],)
+                context_count = context_counts.get(context, 0)
+                ngram_count = ngram_counts.get(ngram, 0)
+                prob = (ngram_count + alpha) / (context_count + alpha * vocab_size)
+                loss_sum += -math.log(prob)
+                token_count += 1
+            candidate["difficulty"] = loss_sum / max(token_count, 1)
+            total_tokens += token_count
+
+        self._last_batch_size_used = len(candidates)
+        elapsed = max(time.perf_counter() - start_time, 1e-9)
+        tokens_per_sec = total_tokens / elapsed if total_tokens else 0.0
+        self._last_scorer_profile = {
+            "mode": "byte_ngram",
+            "order": order,
+            "vocab_size": vocab_size,
+            "tokens_processed": total_tokens,
+            "elapsed_seconds": elapsed,
+            "tokens_per_second": tokens_per_sec,
+        }
+        return True
+
     def _tokenize_samples(
         self,
         texts: list[str],
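The GPT-2 difficulty scorer is replaced by a self-contained, dependency-free byte n-gram scorer: every window is scored by its mean negative log-probability under an add-alpha smoothed byte 4-gram model, P(b | context) = (count(context, b) + alpha) / (count(context) + alpha * V), with V = 257 (256 byte values plus one pad symbol). A standalone sketch of the same scoring rule, outside the provider class and with simplified bookkeeping:

    import math
    from collections import Counter

    def byte_ngram_difficulty(texts: list[str], order: int = 4, alpha: float = 1.0) -> list[float]:
        """Mean negative log-probability per byte under an add-alpha smoothed n-gram model."""
        pad, vocab = 256, 257
        context_counts: Counter[tuple[int, ...]] = Counter()
        ngram_counts: Counter[tuple[int, ...]] = Counter()
        sequences = [
            [pad] * (order - 1) + list(t.encode("utf-8", errors="replace")) for t in texts
        ]
        # First pass: collect context and n-gram counts over all windows.
        for tokens in sequences:
            for i in range(order - 1, len(tokens)):
                ctx = tuple(tokens[i - order + 1 : i])
                context_counts[ctx] += 1
                ngram_counts[ctx + (tokens[i],)] += 1
        # Second pass: mean negative log-probability per window.
        scores = []
        for tokens in sequences:
            nll, n = 0.0, 0
            for i in range(order - 1, len(tokens)):
                ctx = tuple(tokens[i - order + 1 : i])
                p = (ngram_counts[ctx + (tokens[i],)] + alpha) / (context_counts[ctx] + alpha * vocab)
                nll -= math.log(p)
                n += 1
            scores.append(nll / max(n, 1))
        return scores

This per-window score is the `difficulty` value the provider sorts on when stratifying candidates.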
invarlock/eval/metrics.py CHANGED
@@ -1379,6 +1379,88 @@ def _resolve_eval_device(
     return resolved
 
 
+def _infer_model_vocab_size(model: nn.Module) -> int | None:
+    """Best-effort vocab size for guarding against invalid token IDs.
+
+    Prefer the actual embedding size (more reliable than config.vocab_size when
+    tokenizers have added tokens), and fall back to config when embeddings are
+    unavailable (e.g., stub models in tests).
+    """
+    try:
+        get_emb = getattr(model, "get_input_embeddings", None)
+        if callable(get_emb):
+            emb = get_emb()
+            weight = getattr(emb, "weight", None)
+            if weight is not None and hasattr(weight, "shape"):
+                size = int(weight.shape[0])
+                if size > 0:
+                    return size
+    except Exception:
+        pass
+
+    # Fallback for lightweight/stub models: pick the largest nn.Embedding module.
+    # This is not guaranteed to be the token embedding, but is a good heuristic
+    # when get_input_embeddings/config.vocab_size are unavailable.
+    try:
+        max_embeddings = 0
+        for module in model.modules():
+            if isinstance(module, nn.Embedding):
+                max_embeddings = max(max_embeddings, int(module.num_embeddings))
+        if max_embeddings > 0:
+            return max_embeddings
+    except Exception:
+        pass
+
+    config = getattr(model, "config", None)
+    vocab_size = getattr(config, "vocab_size", None)
+    if isinstance(vocab_size, int) and vocab_size > 0:
+        return vocab_size
+    return None
+
+
+def _resolve_pad_token_id(model: nn.Module, vocab_size: int | None) -> int:
+    """Pick a safe pad token id for sanitizing invalid token IDs."""
+    config = getattr(model, "config", None)
+    pad_token_id = getattr(config, "pad_token_id", None)
+    if isinstance(pad_token_id, int) and pad_token_id >= 0:
+        if vocab_size is None or pad_token_id < vocab_size:
+            return pad_token_id
+    return 0
+
+
+def _sanitize_token_ids_for_model(
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    labels: torch.Tensor | None,
+    *,
+    vocab_size: int,
+    pad_token_id: int,
+) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    """Prevent device-side asserts from out-of-range token IDs.
+
+    Out-of-range token IDs can trigger CUDA device-side asserts in embedding and
+    gather kernels, poisoning the CUDA context for the entire process. Instead,
+    mask them out as padding and ignore them in labels.
+    """
+    if vocab_size <= 0:
+        return input_ids, attention_mask, labels
+
+    invalid_inputs = (input_ids < 0) | (input_ids >= vocab_size)
+    if invalid_inputs.any():
+        input_ids = input_ids.masked_fill(invalid_inputs, pad_token_id)
+        if attention_mask is not None:
+            attention_mask = attention_mask.masked_fill(invalid_inputs, 0)
+        if labels is not None:
+            labels = labels.masked_fill(invalid_inputs, -100)
+
+    if labels is not None:
+        invalid_labels = (labels != -100) & ((labels < 0) | (labels >= vocab_size))
+        if invalid_labels.any():
+            labels = labels.masked_fill(invalid_labels, -100)
+
+    return input_ids, attention_mask, labels
+
+
 # ── Perplexity calculation ─────────────────────────────────────────────────
 @torch.no_grad()
 def calculate_perplexity(
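The helpers above clamp out-of-range token IDs before they can reach embedding or gather kernels, where they would otherwise trigger CUDA device-side asserts. A toy check of the masking behaviour, assuming the private helper keeps the signature shown in the hunk:

    import torch
    from invarlock.eval.metrics import _sanitize_token_ids_for_model  # private helper; import shown for illustration only

    input_ids = torch.tensor([[3, 7, 42, -1]])   # 42 and -1 are out of range for vocab_size=10
    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()

    input_ids, attention_mask, labels = _sanitize_token_ids_for_model(
        input_ids, attention_mask, labels, vocab_size=10, pad_token_id=0
    )
    # input_ids      -> [[3, 7, 0, 0]]         invalid IDs replaced by the pad token
    # attention_mask -> [[1, 1, 0, 0]]         invalid positions masked out
    # labels         -> [[3, 7, -100, -100]]   invalid positions ignored by the loss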
@@ -1415,6 +1497,8 @@ def compute_perplexity_strict(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
     nll_sum = 0.0
     tok_count = 0
 
@@ -1453,6 +1537,15 @@ def compute_perplexity_strict(
         else:
             labels = labels.to(device)
 
+        if model_vocab_size is not None:
+            input_ids, attn, labels = _sanitize_token_ids_for_model(
+                input_ids,
+                attn,
+                labels,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip if sequence too short
         if input_ids.size(1) < 2:
             continue
@@ -1507,7 +1600,11 @@ def compute_perplexity_strict(
             continue
 
         log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
         nll = -log_probs.gather(-1, tgt).squeeze(-1)  # [B,T-1]
 
         nll_sum += nll[valid].sum().item()
@@ -1552,6 +1649,8 @@ def compute_perplexity(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
     nll_sum = 0.0
     tok_count = 0
     batch_count = 0
@@ -1589,6 +1688,15 @@ def compute_perplexity(
         else:
             labels = labels.to(device)
 
+        if model_vocab_size is not None:
+            input_ids, attn, labels = _sanitize_token_ids_for_model(
+                input_ids,
+                attn,
+                labels,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip if sequence too short
         if input_ids.size(1) < 2:
             continue
@@ -1620,7 +1728,11 @@ def compute_perplexity(
 
         # Compute negative log-likelihood
         log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
 
         # MPS workaround: gather operation can fail on MPS, use CPU fallback
         if str(device).startswith("mps"):
@@ -1694,6 +1806,8 @@ def compute_ppl(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
     nll_sum = 0.0
     tok_count = 0
 
@@ -1712,6 +1826,15 @@ def compute_ppl(
             torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device)
         )
 
+        if model_vocab_size is not None:
+            input_ids_tensor, attention_mask_tensor, _ = _sanitize_token_ids_for_model(
+                input_ids_tensor,
+                attention_mask_tensor,
+                labels=None,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip sequences that are too short
         if input_ids_tensor.size(1) < 2:
             continue
@@ -1747,7 +1870,11 @@ def compute_ppl(
 
         # Compute negative log-likelihood
         log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-        tgt = shift_labels.clamp_min(0).unsqueeze(-1)  # [B,T-1,1]
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
 
         # Handle MPS device issues with gather
         if str(device).startswith("mps"):
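All three perplexity paths now apply the same guard before the gather: labels are validated against the logits' vocabulary dimension, batches with no valid targets are skipped, and the remaining targets are clamped so the gather can never index out of range. A standalone toy version of that step (the tensors are made up):

    import torch

    vocab_size = 5
    shift_logits = torch.randn(1, 4, vocab_size)
    shift_labels = torch.tensor([[1, 4, 9, -100]])   # 9 is out of range, -100 means "ignore"
    valid = shift_labels != -100

    log_probs = shift_logits.log_softmax(dim=-1)                      # [B, T-1, V]
    valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)  # drop out-of-range targets
    tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # safe indices for gather
    nll = -log_probs.gather(-1, tgt).squeeze(-1)                       # gather is always in range
    nll_sum = nll[valid].sum().item()                                  # only positions 0 and 1 contribute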