invarlock-0.3.4-py3-none-any.whl → invarlock-0.3.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +1 -1
- invarlock/_data/runtime/tiers.yaml +57 -30
- invarlock/adapters/__init__.py +1 -1
- invarlock/calibration/spectral_null.py +15 -10
- invarlock/calibration/variance_ve.py +0 -2
- invarlock/cli/commands/calibrate.py +6 -2
- invarlock/cli/commands/certify.py +58 -39
- invarlock/cli/commands/doctor.py +3 -1
- invarlock/cli/commands/explain_gates.py +57 -8
- invarlock/cli/commands/report.py +1 -1
- invarlock/cli/commands/run.py +159 -61
- invarlock/cli/commands/verify.py +78 -4
- invarlock/cli/config.py +21 -5
- invarlock/core/api.py +45 -5
- invarlock/core/auto_tuning.py +65 -20
- invarlock/core/contracts.py +7 -1
- invarlock/core/registry.py +2 -2
- invarlock/core/runner.py +314 -50
- invarlock/eval/bench.py +0 -13
- invarlock/eval/data.py +73 -283
- invarlock/eval/metrics.py +134 -4
- invarlock/eval/primary_metric.py +23 -0
- invarlock/eval/tail_stats.py +230 -0
- invarlock/guards/_estimators.py +154 -0
- invarlock/guards/policies.py +16 -6
- invarlock/guards/rmt.py +625 -544
- invarlock/guards/spectral.py +348 -110
- invarlock/guards/tier_config.py +32 -30
- invarlock/guards/variance.py +5 -29
- invarlock/guards_ref/rmt_ref.py +23 -23
- invarlock/model_profile.py +42 -15
- invarlock/reporting/certificate.py +225 -46
- invarlock/reporting/certificate_schema.py +2 -1
- invarlock/reporting/dataset_hashing.py +15 -2
- invarlock/reporting/guards_analysis.py +197 -274
- invarlock/reporting/normalizer.py +6 -0
- invarlock/reporting/policy_utils.py +38 -36
- invarlock/reporting/primary_metric_utils.py +71 -17
- invarlock/reporting/render.py +61 -0
- invarlock/reporting/report.py +1 -1
- invarlock/reporting/report_types.py +5 -2
- invarlock/reporting/validate.py +1 -18
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/METADATA +6 -6
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/RECORD +48 -46
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/WHEEL +0 -0
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/entry_points.txt +0 -0
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.4.dist-info → invarlock-0.3.6.dist-info}/top_level.txt +0 -0
invarlock/eval/data.py
CHANGED
@@ -7,7 +7,6 @@ Pluggable data loading system with deterministic windowing for reproducible eval
 
 from __future__ import annotations
 
-import atexit
 import hashlib
 import json
 import math
@@ -51,7 +50,6 @@ except ImportError:
 
 try:
     import torch
-    import torch.nn.functional as F
 
     HAS_TORCH = True
 except ImportError:
@@ -160,9 +158,9 @@ class WikiText2Provider:
     """
 
     name = "wikitext2"
-
-
-
+    _BYTE_NGRAM_ORDER = 4
+    _BYTE_NGRAM_PAD = 256
+    _BYTE_NGRAM_ALPHA = 1.0
 
     def __init__(
         self,
@@ -178,13 +176,9 @@ class WikiText2Provider:
         """
         self.cache_dir = cache_dir
         self._validate_dependencies()
-        self._register_cleanup()
-        self._difficulty_model = self.__class__._MODEL_CACHE
-        self._difficulty_device = self.__class__._MODEL_DEVICE
         self._last_stratification_stats: dict[str, Any] | None = None
         self._last_batch_size_used: int = 0
         self._last_scorer_profile: dict[str, Any] | None = None
-        self._scorer_warmed: bool = False
         # In-process cache for loaded/filtered texts to avoid repeated
         # load_dataset() calls across stratification retries.
         self._texts_cache: dict[str, list[str]] = {}
@@ -192,48 +186,9 @@ class WikiText2Provider:
         normalized_hint = (device_hint or "").strip().lower()
         self._device_hint: str | None = normalized_hint or None
 
-    @classmethod
-    def _register_cleanup(cls) -> None:
-        """Register an atexit hook once per process to release cached models."""
-        if cls._CLEANUP_REGISTERED or not HAS_TORCH:
-            return
-
-        def _cleanup() -> None:
-            cls._cleanup_model_cache()
-
-        atexit.register(_cleanup)
-        cls._CLEANUP_REGISTERED = True
-
-    @classmethod
-    def _cleanup_model_cache(cls) -> None:
-        """Release cached models to avoid leaking multiprocessing semaphores."""
-        cache = cls._MODEL_CACHE
-        if cache is not None and cache is not False and HAS_TORCH:
-            try:
-                cache.to("cpu")
-            except Exception:
-                pass
-        cls._MODEL_CACHE = None
-        cls._MODEL_DEVICE = None
-
-    @staticmethod
-    def _pick_default_scorer_device() -> torch.device:
-        """
-        Choose a default device for the difficulty scorer model.
-
-        Prefers CUDA → MPS → CPU when available.
-        """
-        if torch.cuda.is_available():
-            return torch.device("cuda")
-        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return torch.device("mps")
-        return torch.device("cpu")
-
     def _validate_dependencies(self) -> None:
         """Check that required dependencies are available."""
         if not HAS_DATASETS:
-            if _LIGHT_IMPORT:
-                return
             raise _DepErr(
                 code="E301",
                 message=(
@@ -371,13 +326,6 @@ class WikiText2Provider:
         if cached is not None and len(cached) >= max_samples:
             return cached[:max_samples]
 
-        if not HAS_DATASETS and _LIGHT_IMPORT:
-            texts = ["hello world", "invarlock synthetic text"] * max(
-                1, max_samples // 2
-            )
-            self._texts_cache[split] = texts
-            return texts[:max_samples]
-
         # Load dataset with size limit for efficiency
         dataset_slice = f"{split}[:{max_samples}]" if max_samples > 0 else split
         dataset = load_dataset(
@@ -513,9 +461,11 @@ class WikiText2Provider:
             candidates.append(
                 {
                     "dataset_index": idx,
+                    "text": texts[idx],
                     "input_ids": input_ids_list,
                     "attention_mask": attention_mask_list,
                     "token_count": real_tokens,
+                    "seq_len": len(input_ids_list),
                 }
             )
 
@@ -531,32 +481,7 @@ class WikiText2Provider:
                 details={"needed": int(total_required), "got": int(len(candidates))},
             )
 
-
-        token_counter: Counter[int] = Counter()
-        for candidate in candidates:
-            for token_id, mask in zip(
-                candidate["input_ids"], candidate["attention_mask"], strict=False
-            ):
-                if mask:
-                    token_counter[int(token_id)] += 1
-
-        total_tokens = sum(token_counter.values()) or 1
-        vocab_size = max(len(token_counter), 1)
-
-        for candidate in candidates:
-            difficulty = 0.0
-            real_tokens = 0
-            for token_id, mask in zip(
-                candidate["input_ids"], candidate["attention_mask"], strict=False
-            ):
-                if not mask:
-                    continue
-                freq = (token_counter[int(token_id)] + 1.0) / (
-                    total_tokens + vocab_size
-                )
-                difficulty -= math.log(freq)
-                real_tokens += 1
-            candidate["difficulty"] = difficulty / max(real_tokens, 1)
+        self._score_candidates_byte_ngram(candidates)
 
         sorted_candidates = sorted(
             candidates, key=lambda item: (item["difficulty"], item["dataset_index"])
@@ -843,193 +768,63 @@ class WikiText2Provider:
 
         return results
 
-    def
-
-        if not HAS_TORCH:
-            return False
-
-        if self._difficulty_model is False:
-            return False
-
-        try:
-            eval_device_override = os.environ.get("INVARLOCK_EVAL_DEVICE")
-            device_hint = getattr(self, "_device_hint", None)
-
-            def _is_device_usable(device: torch.device) -> bool:
-                try:
-                    _ = torch.zeros((1, 1), dtype=torch.long, device=device)
-                    return True
-                except Exception:
-                    return False
-
-            if self._difficulty_model is None:
-                from transformers import GPT2LMHeadModel
-
-                model = GPT2LMHeadModel.from_pretrained("gpt2")
-                model.eval()
-                # Decide initial scorer device: env override → provider hint → heuristic
-                if eval_device_override:
-                    try:
-                        device = torch.device(eval_device_override)
-                    except Exception:
-                        device = self._pick_default_scorer_device()
-                elif device_hint and device_hint != "auto":
-                    try:
-                        device = torch.device(device_hint)
-                    except Exception:
-                        device = self._pick_default_scorer_device()
-                else:
-                    device = self._pick_default_scorer_device()
-
-                if device.type != "cpu" and not _is_device_usable(device):
-                    warnings.warn(
-                        f"Difficulty scorer device {device} unavailable; falling back to CPU",
-                        stacklevel=2,
-                    )
-                    device = torch.device("cpu")
-
-                model.to(device)
-                self._difficulty_model = model
-                self._difficulty_device = device
-                self.__class__._MODEL_CACHE = model
-                self.__class__._MODEL_DEVICE = device
-
-            assert self._difficulty_model is not None
-            model = self._difficulty_model
-            device = self._difficulty_device or torch.device("cpu")
-
-            # If a new override/hint is provided, move the cached model if needed.
-            desired_device = device
-            if eval_device_override:
-                try:
-                    desired_device = torch.device(eval_device_override)
-                except Exception:
-                    desired_device = device
-            elif device_hint and device_hint != "auto":
-                try:
-                    desired_device = torch.device(device_hint)
-                except Exception:
-                    desired_device = device
-
-            if desired_device != device:
-                if desired_device.type != "cpu" and not _is_device_usable(
-                    desired_device
-                ):
-                    warnings.warn(
-                        f"Difficulty scorer device {desired_device} unavailable; keeping {device}",
-                        stacklevel=2,
-                    )
-                else:
-                    try:
-                        model.to(desired_device)
-                        device = desired_device
-                        self._difficulty_device = desired_device
-                        self.__class__._MODEL_DEVICE = desired_device
-                    except Exception as exc:
-                        warnings.warn(
-                            f"Failed to move GPT-2 difficulty scorer to {desired_device}: {exc}",
-                            stacklevel=2,
-                        )
-
-            if not self._scorer_warmed:
-                with torch.no_grad():
-                    dummy_input = torch.zeros((1, 8), dtype=torch.long, device=device)
-                    dummy_attention = torch.ones_like(dummy_input)
-                    model(dummy_input, attention_mask=dummy_attention)
-                self._scorer_warmed = True
-
-            batch_override = os.environ.get("INVARLOCK_SCORES_BATCH_SIZE")
-            override_size = None
-            if batch_override:
-                try:
-                    override_size = max(1, int(batch_override))
-                except ValueError:
-                    override_size = None
-
-            batch_size = min(32, max(4, len(candidates)))
-            if override_size is not None:
-                batch_size = max(1, min(override_size, len(candidates)))
-
-            config = getattr(model, "config", None)
-            scorer_vocab_size = getattr(config, "vocab_size", None)
-
-            input_batch: list[list[int]] = []
-            attention_batch: list[list[int]] = []
-            candidate_batch: list[dict[str, Any]] = []
-            total_tokens = 0
-            start_time = time.perf_counter()
-
-            with torch.no_grad():
-                for candidate in candidates:
-                    input_batch.append(candidate["input_ids"])
-                    attention_batch.append(candidate["attention_mask"])
-                    candidate_batch.append(candidate)
-
-                    if len(input_batch) == batch_size or candidate is candidates[-1]:
-                        input_tensor = torch.tensor(
-                            input_batch, dtype=torch.long, device=device
-                        )
-                        attention_tensor = torch.tensor(
-                            attention_batch, dtype=torch.long, device=device
-                        )
-
-                        # Guard against out-of-range token IDs when scoring with GPT-2.
-                        # Some model tokenizers emit IDs beyond GPT-2 vocab, which can
-                        # trigger device-side asserts in embedding/gather kernels.
-                        if scorer_vocab_size and scorer_vocab_size > 0:
-                            input_tensor = input_tensor.clamp(
-                                min=0, max=scorer_vocab_size - 1
-                            )
-
-                        outputs = model(input_tensor, attention_mask=attention_tensor)
-                        shift_logits = outputs.logits[:, :-1, :].contiguous()
-                        shift_labels = input_tensor[:, 1:].contiguous()
-                        shift_mask = attention_tensor[:, 1:].contiguous()
-                        shift_labels = shift_labels.masked_fill(shift_mask == 0, 0)
-
-                        vocab_size = shift_logits.size(-1)
-                        losses = F.cross_entropy(
-                            shift_logits.view(-1, vocab_size),
-                            shift_labels.view(-1),
-                            reduction="none",
-                        )
-                        losses = losses.view(shift_labels.size()) * shift_mask
-                        token_counts = shift_mask.sum(dim=1).clamp(min=1)
-                        loss_per_example = (
-                            (losses.sum(dim=1) / token_counts).cpu().tolist()
-                        )
-
-                        for cand_obj, loss_value in zip(
-                            candidate_batch, loss_per_example, strict=False
-                        ):
-                            cand_obj["difficulty"] = float(loss_value)
-                        total_tokens += int(token_counts.sum().item())
-
-                        input_batch.clear()
-                        attention_batch.clear()
-                        candidate_batch.clear()
-            self._last_batch_size_used = batch_size
-            elapsed = max(time.perf_counter() - start_time, 1e-9)
-            tokens_per_sec = total_tokens / elapsed if total_tokens else 0.0
-            self._last_scorer_profile = {
-                "batch_size": batch_size,
-                "tokens_processed": total_tokens,
-                "elapsed_seconds": elapsed,
-                "tokens_per_second": tokens_per_sec,
-            }
-            return True
-        except Exception as exc:  # pragma: no cover - defensive
-            warnings.warn(
-                f"Failed to compute GPT-2 difficulty scores: {exc}", stacklevel=2
-            )
-            self._difficulty_model = False
-            self._difficulty_device = None
-            self.__class__._MODEL_CACHE = False
-            self.__class__._MODEL_DEVICE = None
+    def _score_candidates_byte_ngram(self, candidates: list[dict[str, Any]]) -> bool:
+        if not candidates:
             self._last_batch_size_used = 0
             self._last_scorer_profile = None
             return False
 
+        order = max(1, int(self._BYTE_NGRAM_ORDER))
+        pad_token = int(self._BYTE_NGRAM_PAD)
+        alpha = float(self._BYTE_NGRAM_ALPHA)
+        vocab_size = pad_token + 1
+
+        context_counts: Counter[tuple[int, ...]] = Counter()
+        ngram_counts: Counter[tuple[int, ...]] = Counter()
+        sequences: list[list[int]] = []
+        start_time = time.perf_counter()
+
+        for candidate in candidates:
+            text = candidate.get("text")
+            if not isinstance(text, str):
+                text = ""
+            byte_values = list(text.encode("utf-8", errors="replace"))
+            tokens = ([pad_token] * (order - 1)) + byte_values
+            sequences.append(tokens)
+            for idx in range(order - 1, len(tokens)):
+                context = tuple(tokens[idx - order + 1 : idx])
+                ngram = context + (tokens[idx],)
+                context_counts[context] += 1
+                ngram_counts[ngram] += 1
+
+        total_tokens = 0
+        for candidate, tokens in zip(candidates, sequences, strict=False):
+            loss_sum = 0.0
+            token_count = 0
+            for idx in range(order - 1, len(tokens)):
+                context = tuple(tokens[idx - order + 1 : idx])
+                ngram = context + (tokens[idx],)
+                context_count = context_counts.get(context, 0)
+                ngram_count = ngram_counts.get(ngram, 0)
+                prob = (ngram_count + alpha) / (context_count + alpha * vocab_size)
+                loss_sum += -math.log(prob)
+                token_count += 1
+            candidate["difficulty"] = loss_sum / max(token_count, 1)
+            total_tokens += token_count
+
+        self._last_batch_size_used = len(candidates)
+        elapsed = max(time.perf_counter() - start_time, 1e-9)
+        tokens_per_sec = total_tokens / elapsed if total_tokens else 0.0
+        self._last_scorer_profile = {
+            "mode": "byte_ngram",
+            "order": order,
+            "vocab_size": vocab_size,
+            "tokens_processed": total_tokens,
+            "elapsed_seconds": elapsed,
+            "tokens_per_second": tokens_per_sec,
+        }
+        return True
+
     def _tokenize_samples(
         self,
         texts: list[str],
@@ -1258,14 +1053,13 @@ class HFTextProvider:
         max_samples: int = 2000,
     ):
         if not HAS_DATASETS:
-
-
-
-
-
-
-
-            )
+            raise _DepErr(
+                code="E301",
+                message=(
+                    "DEPENDENCY-MISSING: datasets library required for hf_text provider"
+                ),
+                details={"dependency": "datasets"},
+            )
         self.dataset_name = dataset_name or "wikitext"
         self.config_name = config_name or None
         self.text_field = text_field
@@ -1273,9 +1067,6 @@
         self.max_samples = int(max_samples)
 
     def load(self, split: str = "validation", **kwargs) -> list[str]:
-        if not HAS_DATASETS and _LIGHT_IMPORT:
-            return ["synthetic dataset text"] * int(self.max_samples or 1)
-
         ds = load_dataset(
             path=self.dataset_name,
            name=self.config_name,
@@ -1400,14 +1191,13 @@ class HFSeq2SeqProvider:
         max_samples: int = 2000,
     ) -> None:
         if not HAS_DATASETS:
-
-
-
-
-
-
-
-            )
+            raise _DepErr(
+                code="E301",
+                message=(
+                    "DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
+                ),
+                details={"dependency": "datasets"},
+            )
         self.dataset_name = dataset_name
         self.config_name = config_name
         self.src_field = src_field
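Note on the change above: 0.3.6 drops the cached GPT-2 difficulty scorer and instead ranks stratification candidates with an order-4 byte n-gram model using add-α smoothing over a 257-symbol alphabet (256 byte values plus one pad symbol). Below is a minimal standalone sketch of that scoring idea; the function name and sample strings are illustrative only, not part of the invarlock API.

import math
from collections import Counter

ORDER, PAD, ALPHA = 4, 256, 1.0
VOCAB = PAD + 1  # 257 symbols: 256 byte values plus the pad marker

def byte_ngram_difficulty(texts):
    # First pass: pool n-gram and context counts across the whole candidate set.
    context_counts = Counter()
    ngram_counts = Counter()
    sequences = []
    for text in texts:
        tokens = [PAD] * (ORDER - 1) + list(text.encode("utf-8", errors="replace"))
        sequences.append(tokens)
        for i in range(ORDER - 1, len(tokens)):
            ctx = tuple(tokens[i - ORDER + 1 : i])
            context_counts[ctx] += 1
            ngram_counts[ctx + (tokens[i],)] += 1

    # Second pass: mean negative log-likelihood per byte under the smoothed model.
    scores = []
    for tokens in sequences:
        nll, n = 0.0, 0
        for i in range(ORDER - 1, len(tokens)):
            ctx = tuple(tokens[i - ORDER + 1 : i])
            # Add-alpha smoothing: P = (count(ngram) + a) / (count(ctx) + a * V)
            p = (ngram_counts[ctx + (tokens[i],)] + ALPHA) / (
                context_counts[ctx] + ALPHA * VOCAB
            )
            nll -= math.log(p)
            n += 1
        scores.append(nll / max(n, 1))
    return scores

# Text with repeated byte patterns scores lower (easier) than novel text.
print(byte_ngram_difficulty(["the cat sat on the mat", "zqxv jkwp"]))

Because the counts come only from the candidates themselves, the score needs no pretrained model, no tokenizer, and no GPU, which is what lets the diff delete the atexit/model-cache machinery.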
invarlock/eval/metrics.py
CHANGED
@@ -723,7 +723,10 @@ def calculate_lens_metrics_for_model(
     except Exception as e:
         logger.error(f"Metrics calculation failed: {e}")
         if config.strict_validation:
-            raise MetricsError(
+            raise MetricsError(
+                code="E401",
+                message=f"METRICS-COMPUTE-FAILED: {e}",
+            ) from e
 
     finally:
         resource_manager.cleanup()
@@ -1379,6 +1382,88 @@ def _resolve_eval_device(
     return resolved
 
 
+def _infer_model_vocab_size(model: nn.Module) -> int | None:
+    """Best-effort vocab size for guarding against invalid token IDs.
+
+    Prefer the actual embedding size (more reliable than config.vocab_size when
+    tokenizers have added tokens), and fall back to config when embeddings are
+    unavailable (e.g., stub models in tests).
+    """
+    try:
+        get_emb = getattr(model, "get_input_embeddings", None)
+        if callable(get_emb):
+            emb = get_emb()
+            weight = getattr(emb, "weight", None)
+            if weight is not None and hasattr(weight, "shape"):
+                size = int(weight.shape[0])
+                if size > 0:
+                    return size
+    except Exception:
+        pass
+
+    # Fallback for lightweight/stub models: pick the largest nn.Embedding module.
+    # This is not guaranteed to be the token embedding, but is a good heuristic
+    # when get_input_embeddings/config.vocab_size are unavailable.
+    try:
+        max_embeddings = 0
+        for module in model.modules():
+            if isinstance(module, nn.Embedding):
+                max_embeddings = max(max_embeddings, int(module.num_embeddings))
+        if max_embeddings > 0:
+            return max_embeddings
+    except Exception:
+        pass
+
+    config = getattr(model, "config", None)
+    vocab_size = getattr(config, "vocab_size", None)
+    if isinstance(vocab_size, int) and vocab_size > 0:
+        return vocab_size
+    return None
+
+
+def _resolve_pad_token_id(model: nn.Module, vocab_size: int | None) -> int:
+    """Pick a safe pad token id for sanitizing invalid token IDs."""
+    config = getattr(model, "config", None)
+    pad_token_id = getattr(config, "pad_token_id", None)
+    if isinstance(pad_token_id, int) and pad_token_id >= 0:
+        if vocab_size is None or pad_token_id < vocab_size:
+            return pad_token_id
+    return 0
+
+
+def _sanitize_token_ids_for_model(
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor | None,
+    labels: torch.Tensor | None,
+    *,
+    vocab_size: int,
+    pad_token_id: int,
+) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    """Prevent device-side asserts from out-of-range token IDs.
+
+    Out-of-range token IDs can trigger CUDA device-side asserts in embedding and
+    gather kernels, poisoning the CUDA context for the entire process. Instead,
+    mask them out as padding and ignore them in labels.
+    """
+    if vocab_size <= 0:
+        return input_ids, attention_mask, labels
+
+    invalid_inputs = (input_ids < 0) | (input_ids >= vocab_size)
+    if invalid_inputs.any():
+        input_ids = input_ids.masked_fill(invalid_inputs, pad_token_id)
+        if attention_mask is not None:
+            attention_mask = attention_mask.masked_fill(invalid_inputs, 0)
+        if labels is not None:
+            labels = labels.masked_fill(invalid_inputs, -100)
+
+    if labels is not None:
+        invalid_labels = (labels != -100) & ((labels < 0) | (labels >= vocab_size))
+        if invalid_labels.any():
+            labels = labels.masked_fill(invalid_labels, -100)
+
+    return input_ids, attention_mask, labels
+
+
 # ── Perplexity calculation ─────────────────────────────────────────────────
 @torch.no_grad()
 def calculate_perplexity(
@@ -1415,6 +1500,8 @@ compute_perplexity_strict(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
     nll_sum = 0.0
     tok_count = 0
 
@@ -1453,6 +1540,15 @@ compute_perplexity_strict(
         else:
             labels = labels.to(device)
 
+        if model_vocab_size is not None:
+            input_ids, attn, labels = _sanitize_token_ids_for_model(
+                input_ids,
+                attn,
+                labels,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip if sequence too short
         if input_ids.size(1) < 2:
             continue
@@ -1507,7 +1603,11 @@ compute_perplexity_strict(
             continue
 
         log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
         nll = -log_probs.gather(-1, tgt).squeeze(-1)  # [B,T-1]
 
         nll_sum += nll[valid].sum().item()
@@ -1552,6 +1652,8 @@ compute_perplexity(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
     nll_sum = 0.0
     tok_count = 0
     batch_count = 0
@@ -1589,6 +1691,15 @@ compute_perplexity(
         else:
             labels = labels.to(device)
 
+        if model_vocab_size is not None:
+            input_ids, attn, labels = _sanitize_token_ids_for_model(
+                input_ids,
+                attn,
+                labels,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip if sequence too short
         if input_ids.size(1) < 2:
             continue
@@ -1620,7 +1731,11 @@ compute_perplexity(
 
         # Compute negative log-likelihood
        log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
 
         # MPS workaround: gather operation can fail on MPS, use CPU fallback
         if str(device).startswith("mps"):
@@ -1694,6 +1809,8 @@ compute_ppl(
     device = _resolve_eval_device(model, device)
 
     model.eval()
+    model_vocab_size = _infer_model_vocab_size(model)
+    pad_token_id = _resolve_pad_token_id(model, model_vocab_size)
    nll_sum = 0.0
     tok_count = 0
 
@@ -1712,6 +1829,15 @@ compute_ppl(
             torch.tensor(attention_mask, dtype=torch.long).unsqueeze(0).to(device)
         )
 
+        if model_vocab_size is not None:
+            input_ids_tensor, attention_mask_tensor, _ = _sanitize_token_ids_for_model(
+                input_ids_tensor,
+                attention_mask_tensor,
+                labels=None,
+                vocab_size=model_vocab_size,
+                pad_token_id=pad_token_id,
+            )
+
         # Skip sequences that are too short
         if input_ids_tensor.size(1) < 2:
             continue
@@ -1747,7 +1873,11 @@ compute_ppl(
 
         # Compute negative log-likelihood
         log_probs = shift_logits.log_softmax(dim=-1)  # [B,T-1,V]
-
+        vocab_size = int(shift_logits.size(-1))
+        valid = valid & (shift_labels >= 0) & (shift_labels < vocab_size)
+        if not valid.any():
+            continue
+        tgt = shift_labels.clamp(min=0, max=vocab_size - 1).unsqueeze(-1)  # [B,T-1,1]
 
         # Handle MPS device issues with gather
         if str(device).startswith("mps"):