invarlock-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +33 -0
- invarlock/__main__.py +10 -0
- invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
- invarlock/_data/runtime/profiles/release.yaml +23 -0
- invarlock/_data/runtime/tiers.yaml +76 -0
- invarlock/adapters/__init__.py +102 -0
- invarlock/adapters/_capabilities.py +45 -0
- invarlock/adapters/auto.py +99 -0
- invarlock/adapters/base.py +530 -0
- invarlock/adapters/base_types.py +85 -0
- invarlock/adapters/hf_bert.py +852 -0
- invarlock/adapters/hf_gpt2.py +403 -0
- invarlock/adapters/hf_llama.py +485 -0
- invarlock/adapters/hf_mixin.py +383 -0
- invarlock/adapters/hf_onnx.py +112 -0
- invarlock/adapters/hf_t5.py +137 -0
- invarlock/adapters/py.typed +1 -0
- invarlock/assurance/__init__.py +43 -0
- invarlock/cli/__init__.py +8 -0
- invarlock/cli/__main__.py +8 -0
- invarlock/cli/_evidence.py +25 -0
- invarlock/cli/_json.py +75 -0
- invarlock/cli/adapter_auto.py +162 -0
- invarlock/cli/app.py +287 -0
- invarlock/cli/commands/__init__.py +26 -0
- invarlock/cli/commands/certify.py +403 -0
- invarlock/cli/commands/doctor.py +1358 -0
- invarlock/cli/commands/explain_gates.py +151 -0
- invarlock/cli/commands/export_html.py +100 -0
- invarlock/cli/commands/plugins.py +1331 -0
- invarlock/cli/commands/report.py +354 -0
- invarlock/cli/commands/run.py +4146 -0
- invarlock/cli/commands/verify.py +1040 -0
- invarlock/cli/config.py +396 -0
- invarlock/cli/constants.py +68 -0
- invarlock/cli/device.py +92 -0
- invarlock/cli/doctor_helpers.py +74 -0
- invarlock/cli/errors.py +6 -0
- invarlock/cli/overhead_utils.py +60 -0
- invarlock/cli/provenance.py +66 -0
- invarlock/cli/utils.py +41 -0
- invarlock/config.py +56 -0
- invarlock/core/__init__.py +62 -0
- invarlock/core/abi.py +15 -0
- invarlock/core/api.py +274 -0
- invarlock/core/auto_tuning.py +317 -0
- invarlock/core/bootstrap.py +226 -0
- invarlock/core/checkpoint.py +221 -0
- invarlock/core/contracts.py +73 -0
- invarlock/core/error_utils.py +64 -0
- invarlock/core/events.py +298 -0
- invarlock/core/exceptions.py +95 -0
- invarlock/core/registry.py +481 -0
- invarlock/core/retry.py +146 -0
- invarlock/core/runner.py +2041 -0
- invarlock/core/types.py +154 -0
- invarlock/edits/__init__.py +12 -0
- invarlock/edits/_edit_utils.py +249 -0
- invarlock/edits/_external_utils.py +268 -0
- invarlock/edits/noop.py +47 -0
- invarlock/edits/py.typed +1 -0
- invarlock/edits/quant_rtn.py +801 -0
- invarlock/edits/registry.py +166 -0
- invarlock/eval/__init__.py +23 -0
- invarlock/eval/bench.py +1207 -0
- invarlock/eval/bootstrap.py +50 -0
- invarlock/eval/data.py +2052 -0
- invarlock/eval/metrics.py +2167 -0
- invarlock/eval/primary_metric.py +767 -0
- invarlock/eval/probes/__init__.py +24 -0
- invarlock/eval/probes/fft.py +139 -0
- invarlock/eval/probes/mi.py +213 -0
- invarlock/eval/probes/post_attention.py +323 -0
- invarlock/eval/providers/base.py +67 -0
- invarlock/eval/providers/seq2seq.py +111 -0
- invarlock/eval/providers/text_lm.py +113 -0
- invarlock/eval/providers/vision_text.py +93 -0
- invarlock/eval/py.typed +1 -0
- invarlock/guards/__init__.py +18 -0
- invarlock/guards/_contracts.py +9 -0
- invarlock/guards/invariants.py +640 -0
- invarlock/guards/policies.py +805 -0
- invarlock/guards/py.typed +1 -0
- invarlock/guards/rmt.py +2097 -0
- invarlock/guards/spectral.py +1419 -0
- invarlock/guards/tier_config.py +354 -0
- invarlock/guards/variance.py +3298 -0
- invarlock/guards_ref/__init__.py +15 -0
- invarlock/guards_ref/rmt_ref.py +40 -0
- invarlock/guards_ref/spectral_ref.py +135 -0
- invarlock/guards_ref/variance_ref.py +60 -0
- invarlock/model_profile.py +353 -0
- invarlock/model_utils.py +221 -0
- invarlock/observability/__init__.py +10 -0
- invarlock/observability/alerting.py +535 -0
- invarlock/observability/core.py +546 -0
- invarlock/observability/exporters.py +565 -0
- invarlock/observability/health.py +588 -0
- invarlock/observability/metrics.py +457 -0
- invarlock/observability/py.typed +1 -0
- invarlock/observability/utils.py +553 -0
- invarlock/plugins/__init__.py +12 -0
- invarlock/plugins/hello_guard.py +33 -0
- invarlock/plugins/hf_awq_adapter.py +82 -0
- invarlock/plugins/hf_bnb_adapter.py +79 -0
- invarlock/plugins/hf_gptq_adapter.py +78 -0
- invarlock/plugins/py.typed +1 -0
- invarlock/py.typed +1 -0
- invarlock/reporting/__init__.py +7 -0
- invarlock/reporting/certificate.py +3221 -0
- invarlock/reporting/certificate_schema.py +244 -0
- invarlock/reporting/dataset_hashing.py +215 -0
- invarlock/reporting/guards_analysis.py +948 -0
- invarlock/reporting/html.py +32 -0
- invarlock/reporting/normalizer.py +235 -0
- invarlock/reporting/policy_utils.py +517 -0
- invarlock/reporting/primary_metric_utils.py +265 -0
- invarlock/reporting/render.py +1442 -0
- invarlock/reporting/report.py +903 -0
- invarlock/reporting/report_types.py +278 -0
- invarlock/reporting/utils.py +175 -0
- invarlock/reporting/validate.py +631 -0
- invarlock/security.py +176 -0
- invarlock/sparsity_utils.py +323 -0
- invarlock/utils/__init__.py +150 -0
- invarlock/utils/digest.py +45 -0
- invarlock-0.2.0.dist-info/METADATA +586 -0
- invarlock-0.2.0.dist-info/RECORD +132 -0
- invarlock-0.2.0.dist-info/WHEEL +5 -0
- invarlock-0.2.0.dist-info/entry_points.txt +20 -0
- invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
- invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/eval/data.py
ADDED
@@ -0,0 +1,2052 @@
"""
InvarLock Evaluation Data Loading
=================================

Pluggable data loading system with deterministic windowing for reproducible evaluation.
"""

from __future__ import annotations

import atexit
import hashlib
import json
import math
import os
import time
import warnings
from abc import abstractmethod
from collections import Counter
from collections.abc import Sequence
from pathlib import Path
from typing import Any, NamedTuple, Protocol

import numpy as np

from invarlock.core.exceptions import DataError as _DataErr
from invarlock.core.exceptions import DependencyError as _DepErr
from invarlock.core.exceptions import ValidationError as _ValErr

# NOTE: During the typed-only migration, avoid hybrid KeyError mixin

_LIGHT_IMPORT = os.getenv("INVARLOCK_LIGHT_IMPORT", "").strip().lower() in {
    "1",
    "true",
    "yes",
}

try:
    from datasets import load_dataset

    HAS_DATASETS = True
except ImportError:
    HAS_DATASETS = False

    def load_dataset(*args, **kwargs):  # type: ignore[no-redef]
        raise _DepErr(
            code="E301",
            message="DEPENDENCY-MISSING: datasets library required for dataset loading",
            details={"dependency": "datasets"},
        )


try:
    import torch
    import torch.nn.functional as F

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False


class EvaluationWindow(NamedTuple):
    """A window of tokenized samples for evaluation."""

    input_ids: list[list[int]]  # List of tokenized sequences
    attention_masks: list[list[int]]  # Attention masks (1=real token, 0=padding)
    indices: list[int]  # Original dataset indices

    def __len__(self) -> int:
        return len(self.input_ids)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "input_ids": self.input_ids,
            "attention_masks": self.attention_masks,
            "indices": self.indices,
            "length": len(self.input_ids),
        }
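
# Illustrative usage sketch (editor's example, not part of the packaged file):
# EvaluationWindow is a plain NamedTuple, so a window can be built directly from
# pre-tokenized sequences. The token ids below are made up for illustration.
#
#     window = EvaluationWindow(
#         input_ids=[[101, 2023, 102], [101, 2009, 102]],
#         attention_masks=[[1, 1, 1], [1, 1, 1]],
#         indices=[0, 1],
#     )
#     assert len(window) == 2
#     assert window.to_dict()["length"] == 2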

class DatasetProvider(Protocol):
    """
    Protocol for pluggable dataset providers.

    Enables extensible dataset support while maintaining deterministic evaluation.
    """

    name: str

    @abstractmethod
    def load(self, split: str = "validation", **kwargs) -> list[str]:
        """
        Load raw text samples from the dataset.

        Args:
            split: Dataset split to load ("validation", "test", "train")
            **kwargs: Provider-specific parameters

        Returns:
            List of text strings
        """
        ...

    @abstractmethod
    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        """
        Create deterministic preview and final evaluation windows.

        Args:
            tokenizer: Tokenizer to use for text encoding
            seq_len: Maximum sequence length
            stride: Stride for overlapping windows (unused in current impl)
            preview_n: Number of preview samples
            final_n: Number of final samples
            seed: Random seed for deterministic sampling
            split: Dataset split to use

        Returns:
            Tuple of (preview_window, final_window)
        """
        ...

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        """
        Estimate number of non-overlapping, deduplicated windows available for evaluation.

        Returns metadata describing the available capacity (total tokens, usable windows, dedupe rate).
        """
        ...

    def info(self) -> dict[str, Any]:
        """Get information about this dataset provider."""
        return {"name": self.name, "type": "dataset_provider"}
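
# Illustrative sketch (editor's example, not part of the packaged file): a
# minimal in-memory provider satisfying the DatasetProvider protocol above.
# The class name and behavior are hypothetical; only the signatures come from
# the protocol.
#
#     class StaticProvider:
#         name = "static"
#
#         def __init__(self, texts: list[str]) -> None:
#             self._texts = texts
#
#         def load(self, split: str = "validation", **kwargs) -> list[str]:
#             return list(self._texts)
#
#         def windows(self, tokenizer, *, seq_len=128, stride=64, preview_n=100,
#                     final_n=100, seed=42, split="validation"):
#             ids = [tokenizer.encode(t, truncation=True, max_length=seq_len)
#                    for t in self.load(split=split)]
#             masks = [[1] * len(x) for x in ids]
#             half = len(ids) // 2
#             return (
#                 EvaluationWindow(ids[:half], masks[:half], list(range(half))),
#                 EvaluationWindow(ids[half:], masks[half:],
#                                  list(range(half, len(ids)))),
#             )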

class WikiText2Provider:
    """
    WikiText-2 dataset provider with deterministic windowing.

    Implements the canonical WT-2 evaluation setup with fixed 100+100 preview/final samples.
    """

    name = "wikitext2"
    _MODEL_CACHE: Any | None | bool = None
    _MODEL_DEVICE: Any | None = None
    _CLEANUP_REGISTERED: bool = False

    def __init__(
        self,
        cache_dir: Path | None = None,
        device_hint: str | None = None,
        **_: Any,
    ):
        """
        Initialize WikiText-2 provider.

        Args:
            cache_dir: Optional cache directory for dataset storage
        """
        self.cache_dir = cache_dir
        self._validate_dependencies()
        self._register_cleanup()
        self._difficulty_model = self.__class__._MODEL_CACHE
        self._difficulty_device = self.__class__._MODEL_DEVICE
        self._last_stratification_stats: dict[str, Any] | None = None
        self._last_batch_size_used: int = 0
        self._last_scorer_profile: dict[str, Any] | None = None
        self._scorer_warmed: bool = False
        # In-process cache for loaded/filtered texts to avoid repeated
        # load_dataset() calls across stratification retries.
        self._texts_cache: dict[str, list[str]] = {}
        # Optional device hint from CLI/resolved run device (e.g. "cpu", "cuda", "mps", "auto")
        normalized_hint = (device_hint or "").strip().lower()
        self._device_hint: str | None = normalized_hint or None

    @classmethod
    def _register_cleanup(cls) -> None:
        """Register an atexit hook once per process to release cached models."""
        if cls._CLEANUP_REGISTERED or not HAS_TORCH:
            return

        def _cleanup() -> None:
            cls._cleanup_model_cache()

        atexit.register(_cleanup)
        cls._CLEANUP_REGISTERED = True

    @classmethod
    def _cleanup_model_cache(cls) -> None:
        """Release cached models to avoid leaking multiprocessing semaphores."""
        cache = cls._MODEL_CACHE
        if cache is not None and cache is not False and HAS_TORCH:
            try:
                cache.to("cpu")
            except Exception:
                pass
        cls._MODEL_CACHE = None
        cls._MODEL_DEVICE = None

    @staticmethod
    def _pick_default_scorer_device() -> torch.device:
        """
        Choose a default device for the difficulty scorer model.

        Prefers CUDA → MPS → CPU when available.
        """
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _validate_dependencies(self) -> None:
        """Check that required dependencies are available."""
        if not HAS_DATASETS:
            if _LIGHT_IMPORT:
                return
            raise _DepErr(
                code="E301",
                message=(
                    "DEPENDENCY-MISSING: datasets library required for WikiText-2 loading"
                ),
                details={"dependency": "datasets"},
            )

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        """Estimate available non-overlapping windows for evaluation."""
        texts = self.load(split=split, max_samples=2000)
        if not texts:
            return {
                "total_tokens": 0,
                "available_nonoverlap": 0,
                "available_unique": 0,
                "dedupe_rate": 0.0,
                "stride": stride,
                "seq_len": seq_len,
                "candidate_unique": 0,
                "candidate_limit": 0,
            }

        env_fast = os.environ.get("INVARLOCK_CAPACITY_FAST", "")
        env_fast_flag = isinstance(env_fast, str) and env_fast.strip().lower() in {
            "1",
            "true",
            "yes",
            "on",
        }
        use_fast = bool(fast_mode) or env_fast_flag
        if use_fast:
            base_available = len(texts)
            target_total = int(target_total or 0)
            approx_available = base_available
            if target_total > 0:
                approx_available = max(base_available, target_total)
            total_tokens = int(max(approx_available, 0) * seq_len)
            approx_available = int(max(approx_available, 0))
            return {
                "total_tokens": total_tokens,
                "available_nonoverlap": approx_available,
                "available_unique": approx_available,
                "dedupe_rate": 0.0,
                "stride": int(stride),
                "seq_len": int(seq_len),
                "candidate_unique": approx_available,
                "candidate_limit": approx_available,
            }

        tokenized = self._collect_tokenized_samples(
            texts, list(range(len(texts))), tokenizer, seq_len
        )

        total_tokens = sum(item[3] for item in tokenized)
        available_nonoverlap = len(tokenized)

        unique_sequences: set[tuple[int, ...]] = set()
        for _, input_ids, attention_mask, _ in tokenized:
            seq = tuple(
                int(tok_id)
                for tok_id, mask in zip(input_ids, attention_mask, strict=False)
                if mask
            )
            unique_sequences.add(seq)

        available_unique = len(unique_sequences)
        dedupe_rate = (
            0.0
            if available_nonoverlap == 0
            else max(
                0.0,
                1.0 - (available_unique / float(max(available_nonoverlap, 1))),
            )
        )

        candidate_unique = None
        candidate_limit = None
        if target_total is not None and target_total > 0:
            reserve_buffer = max(int(target_total * 0.2), 64)
            candidate_limit = min(len(texts), target_total + reserve_buffer)
            tokenized_subset = self._collect_tokenized_samples(
                texts, list(range(candidate_limit)), tokenizer, seq_len
            )
            subset_signatures = {
                tuple(
                    int(tok)
                    for tok, mask in zip(entry[1], entry[2], strict=False)
                    if mask
                )
                for entry in tokenized_subset
            }
            candidate_unique = len(subset_signatures)

        result = {
            "total_tokens": int(total_tokens),
            "available_nonoverlap": int(available_nonoverlap),
            "available_unique": int(available_unique),
            "dedupe_rate": float(dedupe_rate),
            "stride": int(stride),
            "seq_len": int(seq_len),
        }
        if candidate_unique is not None:
            result["candidate_unique"] = int(candidate_unique)
            result["candidate_limit"] = int(candidate_limit or 0)
        return result

    def load(
        self, split: str = "validation", max_samples: int = 2000, **kwargs
    ) -> list[str]:
        """
        Load WikiText-2 text samples.

        Args:
            split: Dataset split ("validation", "test", "train")
            max_samples: Maximum samples to load
            **kwargs: Additional parameters (ignored)

        Returns:
            List of filtered text strings
        """
        print(f"📚 Loading WikiText-2 {split} split...")

        # Serve from cache when possible (load the largest slice once)
        cached = self._texts_cache.get(split)
        if cached is not None and len(cached) >= max_samples:
            return cached[:max_samples]

        if not HAS_DATASETS and _LIGHT_IMPORT:
            texts = ["hello world", "invarlock synthetic text"] * max(
                1, max_samples // 2
            )
            self._texts_cache[split] = texts
            return texts[:max_samples]

        # Load dataset with size limit for efficiency
        dataset_slice = f"{split}[:{max_samples}]" if max_samples > 0 else split
        dataset = load_dataset(
            "wikitext",
            "wikitext-2-raw-v1",
            split=dataset_slice,
            cache_dir=str(self.cache_dir) if self.cache_dir else None,
        )

        # Filter out empty/short texts
        valid_texts: list[str] = []
        for item in dataset:
            text = str(item.get("text", "")).strip()
            # Keep texts with at least 20 characters and some alphabetic content
            if len(text) >= 20 and any(c.isalpha() for c in text):
                valid_texts.append(text)

        # Optional exact-text dedupe to reduce duplicate-token windows
        # Enable via INVARLOCK_DEDUP_TEXTS=1 (keeps first occurrence, preserves order)
        import os as _os

        if str(_os.environ.get("INVARLOCK_DEDUP_TEXTS", "")).strip().lower() in {
            "1",
            "true",
            "yes",
            "on",
        }:
            seen: set[str] = set()
            deduped: list[str] = []
            for t in valid_texts:
                if t not in seen:
                    seen.add(t)
                    deduped.append(t)
            valid_texts = deduped

        # Cache the largest slice we’ve seen for this split
        prev = self._texts_cache.get(split)
        if prev is None or len(valid_texts) > len(prev):
            self._texts_cache[split] = list(valid_texts)

        print(f" ✓ Loaded {len(valid_texts)} valid samples from {len(dataset)} total")
        return valid_texts

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        """
        Create deterministic preview and final evaluation windows.

        This implements the core deterministic evaluation requirement:
        - Fixed seed ensures reproducible sample selection
        - Non-overlapping preview and final samples
        - Consistent tokenization parameters

        Args:
            tokenizer: HuggingFace tokenizer for text encoding
            seq_len: Maximum sequence length for tokenization
            stride: Stride parameter (reserved for future use)
            preview_n: Number of preview samples (default: 100)
            final_n: Number of final samples (default: 100)
            seed: Random seed for reproducible sampling
            split: Dataset split to use

        Returns:
            Tuple of (preview_window, final_window) with deterministic samples
        """
        total_required = preview_n + final_n
        if total_required <= 0:
            raise _ValErr(
                code="E302", message="VALIDATION-FAILED: preview/final must be positive"
            )

        # Load text data with additional buffer to ensure enough valid samples for release windows.
        extra_pool = max(500, int(0.5 * total_required))
        max_samples = max(total_required + extra_pool, 2000)
        texts = self.load(split=split, max_samples=max_samples)

        rng = np.random.RandomState(seed)
        shuffled_indices = rng.permutation(len(texts)).tolist()

        reserve = max(16, int(0.1 * total_required))
        target_pool = min(len(texts), total_required + reserve * 2)

        if target_pool < total_required:
            raise _DataErr(
                code="E303",
                message=(
                    "CAPACITY-INSUFFICIENT: not enough valid samples for requested preview/final"
                ),
                details={
                    "have": int(len(texts)),
                    "preview": int(preview_n),
                    "final": int(final_n),
                },
            )

        candidates: list[dict[str, Any]] = []
        used_indices: set[int] = set()
        cursor = 0
        chunk_size = max(64, min(256, target_pool))

        print(" 📊 Creating evaluation windows:")
        print(f" Requested preview/final: {preview_n}/{final_n}")
        print(f" Sampling pool target: {target_pool} (reserve {reserve})")

        while len(candidates) < total_required + reserve and cursor < len(
            shuffled_indices
        ):
            batch = shuffled_indices[cursor : cursor + chunk_size]
            cursor += chunk_size

            tokenized_batch = self._collect_tokenized_samples(
                texts, batch, tokenizer, seq_len
            )

            for (
                idx,
                input_ids_list,
                attention_mask_list,
                real_tokens,
            ) in tokenized_batch:
                if idx in used_indices:
                    continue
                used_indices.add(idx)
                candidates.append(
                    {
                        "dataset_index": idx,
                        "input_ids": input_ids_list,
                        "attention_mask": attention_mask_list,
                        "token_count": real_tokens,
                    }
                )

            if cursor >= len(shuffled_indices) and len(candidates) < total_required:
                break

        if len(candidates) < total_required:
            raise _DataErr(
                code="E304",
                message=(
                    "TOKENIZE-INSUFFICIENT: failed to gather enough tokenized samples"
                ),
                details={"needed": int(total_required), "got": int(len(candidates))},
            )

        if not self._score_candidates_with_model(candidates):
            token_counter: Counter[int] = Counter()
            for candidate in candidates:
                for token_id, mask in zip(
                    candidate["input_ids"], candidate["attention_mask"], strict=False
                ):
                    if mask:
                        token_counter[int(token_id)] += 1

            total_tokens = sum(token_counter.values()) or 1
            vocab_size = max(len(token_counter), 1)

            for candidate in candidates:
                difficulty = 0.0
                real_tokens = 0
                for token_id, mask in zip(
                    candidate["input_ids"], candidate["attention_mask"], strict=False
                ):
                    if not mask:
                        continue
                    freq = (token_counter[int(token_id)] + 1.0) / (
                        total_tokens + vocab_size
                    )
                    difficulty -= math.log(freq)
                    real_tokens += 1
                candidate["difficulty"] = difficulty / max(real_tokens, 1)

        sorted_candidates = sorted(
            candidates, key=lambda item: (item["difficulty"], item["dataset_index"])
        )

        total_candidates = len(sorted_candidates)
        selection_count = total_required
        selected_positions: list[int] = []
        used_positions: set[int] = set()

        for k in range(selection_count):
            target_position = (k + 0.5) * total_candidates / selection_count
            base_idx = int(round(target_position))
            offset = 0
            chosen: int | None = None

            while offset < total_candidates:
                for candidate_idx in (base_idx + offset, base_idx - offset):
                    if (
                        0 <= candidate_idx < total_candidates
                        and candidate_idx not in used_positions
                    ):
                        chosen = candidate_idx
                        break
                if chosen is not None:
                    break
                offset += 1

            if chosen is not None:
                used_positions.add(chosen)
                selected_positions.append(chosen)

        if len(selected_positions) < selection_count:
            for candidate_idx in range(total_candidates):
                if candidate_idx not in used_positions:
                    used_positions.add(candidate_idx)
                    selected_positions.append(candidate_idx)
                    if len(selected_positions) == selection_count:
                        break

        if len(selected_positions) < selection_count:
            raise _DataErr(
                code="E305", message="STRATIFY-FAILED: candidate pool insufficient"
            )

        selected_candidates = [sorted_candidates[idx] for idx in selected_positions]
        selected_candidates.sort(
            key=lambda item: (item["difficulty"], item["dataset_index"])
        )

        preview_candidates: list[dict[str, Any]] = []
        final_candidates: list[dict[str, Any]] = []

        def assign_candidate(
            candidate: dict[str, Any],
            primary: list[dict[str, Any]],
            secondary: list[dict[str, Any]],
            primary_capacity: int,
            secondary_capacity: int,
        ) -> None:
            if len(primary) < primary_capacity:
                primary.append(candidate)
            elif len(secondary) < secondary_capacity:
                secondary.append(candidate)

        for pair_start in range(0, len(selected_candidates), 2):
            pair = selected_candidates[pair_start : pair_start + 2]
            if not pair:
                continue
            if len(pair) == 2:
                easy, hard = pair
                pair_index = pair_start // 2
                if pair_index % 2 == 0:
                    assign_candidate(
                        easy, preview_candidates, final_candidates, preview_n, final_n
                    )
                    assign_candidate(
                        hard, final_candidates, preview_candidates, final_n, preview_n
                    )
                else:
                    assign_candidate(
                        easy, final_candidates, preview_candidates, final_n, preview_n
                    )
                    assign_candidate(
                        hard, preview_candidates, final_candidates, preview_n, final_n
                    )
            else:
                lone_candidate = pair[0]
                assign_candidate(
                    lone_candidate,
                    preview_candidates,
                    final_candidates,
                    preview_n,
                    final_n,
                )

        assigned_ids = {
            id(candidate) for candidate in preview_candidates + final_candidates
        }
        remaining = [
            candidate
            for candidate in selected_candidates
            if id(candidate) not in assigned_ids
        ]
        for candidate in remaining:
            if len(preview_candidates) < preview_n:
                preview_candidates.append(candidate)
            elif len(final_candidates) < final_n:
                final_candidates.append(candidate)

        def _mean_difficulty(candidates: list[dict[str, Any]]) -> float:
            if not candidates:
                return 0.0
            return float(
                sum(candidate["difficulty"] for candidate in candidates)
                / len(candidates)
            )

        for _ in range(100):
            if not preview_candidates or not final_candidates:
                break
            diff = _mean_difficulty(preview_candidates) - _mean_difficulty(
                final_candidates
            )
            if abs(diff) <= 1e-4:
                break
            if diff < 0:
                preview_candidate = min(
                    preview_candidates, key=lambda c: c["difficulty"]
                )
                final_candidate = max(final_candidates, key=lambda c: c["difficulty"])
            else:
                preview_candidate = max(
                    preview_candidates, key=lambda c: c["difficulty"]
                )
                final_candidate = min(final_candidates, key=lambda c: c["difficulty"])

            if preview_candidate is final_candidate:
                break

            preview_candidates.remove(preview_candidate)
            final_candidates.remove(final_candidate)
            preview_candidates.append(final_candidate)
            final_candidates.append(preview_candidate)

            new_diff = _mean_difficulty(preview_candidates) - _mean_difficulty(
                final_candidates
            )
            if abs(new_diff) >= abs(diff) - 1e-6:
                # swap did not improve; revert and stop
                preview_candidates.remove(final_candidate)
                final_candidates.remove(preview_candidate)
                preview_candidates.append(preview_candidate)
                final_candidates.append(final_candidate)
                break

        if len(preview_candidates) != preview_n or len(final_candidates) != final_n:
            raise _DataErr(
                code="E305",
                message=(
                    "STRATIFY-FAILED: failed to allocate preview/final windows with equal counts"
                ),
                details={
                    "preview_target": int(preview_n),
                    "final_target": int(final_n),
                    "preview_got": int(len(preview_candidates)),
                    "final_got": int(len(final_candidates)),
                },
            )

        preview_candidates.sort(
            key=lambda item: (item["difficulty"], item["dataset_index"])
        )
        final_candidates.sort(
            key=lambda item: (item["difficulty"], item["dataset_index"])
        )

        preview_window = EvaluationWindow(
            input_ids=[c["input_ids"] for c in preview_candidates],
            attention_masks=[c["attention_mask"] for c in preview_candidates],
            indices=[c["dataset_index"] for c in preview_candidates],
        )

        final_window = EvaluationWindow(
            input_ids=[c["input_ids"] for c in final_candidates],
            attention_masks=[c["attention_mask"] for c in final_candidates],
            indices=[c["dataset_index"] for c in final_candidates],
        )

        if len(preview_window) != preview_n or len(final_window) != final_n:
            raise _DataErr(
                code="E305",
                message="STRATIFY-FAILED: window stratification mismatch",
                details={
                    "preview_target": int(preview_n),
                    "final_target": int(final_n),
                    "preview_got": int(len(preview_window)),
                    "final_got": int(len(final_window)),
                },
            )

        preview_difficulties = [c["difficulty"] for c in preview_candidates]
        final_difficulties = [c["difficulty"] for c in final_candidates]
        self._last_stratification_stats = {
            "pool_size": len(selected_candidates),
            "reserve": reserve,
            "batch_size_used": int(self._last_batch_size_used),
            "preview_mean_difficulty": float(np.mean(preview_difficulties))
            if preview_difficulties
            else 0.0,
            "final_mean_difficulty": float(np.mean(final_difficulties))
            if final_difficulties
            else 0.0,
            "preview_std_difficulty": float(np.std(preview_difficulties))
            if preview_difficulties
            else 0.0,
            "final_std_difficulty": float(np.std(final_difficulties))
            if final_difficulties
            else 0.0,
            "difficulty_gap": float(
                (np.mean(final_difficulties) - np.mean(preview_difficulties))
                if (preview_difficulties and final_difficulties)
                else 0.0
            ),
        }

        print(f" Seed: {seed}, Seq length: {seq_len}")
        print(f" Preview: {len(preview_window)} samples")
        print(f" Final: {len(final_window)} samples")

        return preview_window, final_window
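
    # Worked sketch (editor's note, not part of the packaged file): the unigram
    # fallback in windows() above scores a window by Laplace-smoothed negative
    # log-likelihood,
    #
    #     difficulty = -(1/n) * sum_t log((count(t) + 1) / (N + V)),
    #
    # where count(t) is the frequency of token t within the candidate pool, N is
    # the total token count, and V the observed vocabulary size. For a toy pool
    # where token 7 appears 3 times out of N=10 tokens with V=5, each occurrence
    # of token 7 contributes -log(4/15) ≈ 1.32 to the sum before averaging.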

    def _collect_tokenized_samples(
        self,
        texts: Sequence[str],
        indices: Sequence[int],
        tokenizer: Any,
        seq_len: int,
    ) -> list[tuple[int, list[int], list[int], int]]:
        """Tokenize samples and return raw sequences without logging."""
        results: list[tuple[int, list[int], list[int], int]] = []
        for idx in indices:
            if idx >= len(texts):
                continue

            text = texts[idx]

            try:
                tokens = tokenizer(
                    text,
                    truncation=True,
                    padding="max_length",
                    max_length=seq_len,
                    return_tensors="pt" if HAS_TORCH else None,
                )

                if HAS_TORCH and hasattr(tokens["input_ids"], "squeeze"):
                    input_ids = tokens["input_ids"].squeeze(0).tolist()
                    attention_mask = (
                        tokens.get(
                            "attention_mask", torch.ones_like(tokens["input_ids"])
                        )
                        .squeeze(0)
                        .tolist()
                    )
                else:
                    input_ids = tokens["input_ids"]
                    attention_mask = tokens.get("attention_mask", [1] * len(input_ids))

                real_tokens = int(sum(attention_mask))
                if real_tokens > 1:
                    results.append(
                        (
                            idx,
                            [int(token) for token in input_ids],
                            [int(mask) for mask in attention_mask],
                            real_tokens,
                        )
                    )

            except Exception as e:
                warnings.warn(f"Failed to tokenize sample {idx}: {e}", stacklevel=2)
                continue

        return results

    def _score_candidates_with_model(self, candidates: list[dict[str, Any]]) -> bool:
        """Score candidate windows using a pretrained GPT-2 model if available."""
        if not HAS_TORCH:
            return False

        if self._difficulty_model is False:
            return False

        try:
            eval_device_override = os.environ.get("INVARLOCK_EVAL_DEVICE")
            device_hint = getattr(self, "_device_hint", None)

            if self._difficulty_model is None:
                from transformers import GPT2LMHeadModel

                model = GPT2LMHeadModel.from_pretrained("gpt2")
                model.eval()
                # Decide initial scorer device: env override → provider hint → heuristic
                if eval_device_override:
                    try:
                        device = torch.device(eval_device_override)
                    except Exception:
                        device = self._pick_default_scorer_device()
                elif device_hint and device_hint != "auto":
                    try:
                        device = torch.device(device_hint)
                    except Exception:
                        device = self._pick_default_scorer_device()
                else:
                    device = self._pick_default_scorer_device()

                model.to(device)
                self._difficulty_model = model
                self._difficulty_device = device
                self.__class__._MODEL_CACHE = model
                self.__class__._MODEL_DEVICE = device

            assert self._difficulty_model is not None
            model = self._difficulty_model
            device = self._difficulty_device or torch.device("cpu")

            # If a new override/hint is provided, move the cached model if needed.
            desired_device = device
            if eval_device_override:
                try:
                    desired_device = torch.device(eval_device_override)
                except Exception:
                    desired_device = device
            elif device_hint and device_hint != "auto":
                try:
                    desired_device = torch.device(device_hint)
                except Exception:
                    desired_device = device

            if desired_device != device:
                try:
                    model.to(desired_device)
                    device = desired_device
                    self._difficulty_device = desired_device
                    self.__class__._MODEL_DEVICE = desired_device
                except Exception as exc:
                    warnings.warn(
                        f"Failed to move GPT-2 difficulty scorer to {desired_device}: {exc}",
                        stacklevel=2,
                    )

            if not self._scorer_warmed:
                with torch.no_grad():
                    dummy_input = torch.zeros((1, 8), dtype=torch.long, device=device)
                    dummy_attention = torch.ones_like(dummy_input)
                    model(dummy_input, attention_mask=dummy_attention)
                self._scorer_warmed = True

            batch_override = os.environ.get("INVARLOCK_SCORES_BATCH_SIZE")
            override_size = None
            if batch_override:
                try:
                    override_size = max(1, int(batch_override))
                except ValueError:
                    override_size = None

            batch_size = min(32, max(4, len(candidates)))
            if override_size is not None:
                batch_size = max(1, min(override_size, len(candidates)))

            input_batch: list[list[int]] = []
            attention_batch: list[list[int]] = []
            candidate_batch: list[dict[str, Any]] = []
            total_tokens = 0
            start_time = time.perf_counter()

            with torch.no_grad():
                for candidate in candidates:
                    input_batch.append(candidate["input_ids"])
                    attention_batch.append(candidate["attention_mask"])
                    candidate_batch.append(candidate)

                    if len(input_batch) == batch_size or candidate is candidates[-1]:
                        input_tensor = torch.tensor(
                            input_batch, dtype=torch.long, device=device
                        )
                        attention_tensor = torch.tensor(
                            attention_batch, dtype=torch.long, device=device
                        )

                        outputs = model(input_tensor, attention_mask=attention_tensor)
                        shift_logits = outputs.logits[:, :-1, :].contiguous()
                        shift_labels = input_tensor[:, 1:].contiguous()
                        shift_mask = attention_tensor[:, 1:].contiguous()
                        shift_labels = shift_labels.masked_fill(shift_mask == 0, 0)

                        vocab_size = shift_logits.size(-1)
                        losses = F.cross_entropy(
                            shift_logits.view(-1, vocab_size),
                            shift_labels.view(-1),
                            reduction="none",
                        )
                        losses = losses.view(shift_labels.size()) * shift_mask
                        token_counts = shift_mask.sum(dim=1).clamp(min=1)
                        loss_per_example = (
                            (losses.sum(dim=1) / token_counts).cpu().tolist()
                        )

                        for cand_obj, loss_value in zip(
                            candidate_batch, loss_per_example, strict=False
                        ):
                            cand_obj["difficulty"] = float(loss_value)
                        total_tokens += int(token_counts.sum().item())

                        input_batch.clear()
                        attention_batch.clear()
                        candidate_batch.clear()
            self._last_batch_size_used = batch_size
            elapsed = max(time.perf_counter() - start_time, 1e-9)
            tokens_per_sec = total_tokens / elapsed if total_tokens else 0.0
            self._last_scorer_profile = {
                "batch_size": batch_size,
                "tokens_processed": total_tokens,
                "elapsed_seconds": elapsed,
                "tokens_per_second": tokens_per_sec,
            }
            return True
        except Exception as exc:  # pragma: no cover - defensive
            warnings.warn(
                f"Failed to compute GPT-2 difficulty scores: {exc}", stacklevel=2
            )
            self._difficulty_model = False
            self._difficulty_device = None
            self.__class__._MODEL_CACHE = False
            self.__class__._MODEL_DEVICE = None
            self._last_batch_size_used = 0
            self._last_scorer_profile = None
            return False
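
    # Usage sketch (editor's note, not part of the packaged file): the scorer
    # reads two environment variables used above, both defined in this file.
    # Setting them before windows() runs controls device and batch size:
    #
    #     os.environ["INVARLOCK_EVAL_DEVICE"] = "cpu"      # force scorer device
    #     os.environ["INVARLOCK_SCORES_BATCH_SIZE"] = "8"  # cap scoring batch size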

    def _tokenize_samples(
        self,
        texts: list[str],
        indices: list[int],
        tokenizer: Any,
        seq_len: int,
        window_name: str,
    ) -> EvaluationWindow:
        """Tokenize a set of text samples with consistent parameters."""
        collected = self._collect_tokenized_samples(texts, indices, tokenizer, seq_len)

        input_ids_list = [entry[1] for entry in collected]
        attention_masks_list = [entry[2] for entry in collected]
        valid_indices = [entry[0] for entry in collected]

        print(
            f" ✓ {window_name}: {len(valid_indices)}/{len(indices)} samples tokenized successfully"
        )

        return EvaluationWindow(
            input_ids=input_ids_list,
            attention_masks=attention_masks_list,
            indices=valid_indices,
        )

    @property
    def stratification_stats(self) -> dict[str, Any] | None:
        """Return summary statistics for the most recent stratified split."""
        return self._last_stratification_stats

    @property
    def scorer_profile(self) -> dict[str, Any] | None:
        """Return performance statistics for the most recent scorer run."""
        return self._last_scorer_profile

    def info(self) -> dict[str, Any]:
        """Get information about WikiText-2 provider."""
        return {
            "name": self.name,
            "type": "dataset_provider",
            "dataset": "wikitext-2-raw-v1",
            "source": "huggingface/datasets",
            "deterministic": True,
            "default_split": "validation",
            "requires": ["datasets"],
        }
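
# Illustrative usage sketch (editor's example, not part of the packaged file):
# deterministic preview/final windows from the WikiText-2 provider. The choice
# of the "gpt2" tokenizer is an assumption for illustration; any HF tokenizer
# with a pad token works.
#
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
#     provider = WikiText2Provider(device_hint="cpu")
#     preview, final = provider.windows(
#         tok, seq_len=128, preview_n=100, final_n=100, seed=42
#     )
#     assert len(preview) == 100 and len(final) == 100
#     # Repeating the call with the same seed reproduces the same selection.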
|
1048
|
+
class SyntheticProvider:
|
|
1049
|
+
"""
|
|
1050
|
+
Synthetic text provider for testing and development.
|
|
1051
|
+
|
|
1052
|
+
Generates coherent text samples when WikiText-2 is not available.
|
|
1053
|
+
"""
|
|
1054
|
+
|
|
1055
|
+
name = "synthetic"
|
|
1056
|
+
|
|
1057
|
+
def __init__(self, base_samples: list[str] | None = None):
|
|
1058
|
+
"""Initialize with optional base text samples."""
|
|
1059
|
+
self.base_samples = base_samples or self._default_samples()
|
|
1060
|
+
|
|
1061
|
+
def _default_samples(self) -> list[str]:
|
|
1062
|
+
"""Generate default synthetic text samples."""
|
|
1063
|
+
return [
|
|
1064
|
+
"The weather today is quite pleasant with clear skies and gentle winds.",
|
|
1065
|
+
"Scientists have discovered a new species in the Amazon rainforest region.",
|
|
1066
|
+
"The stock market showed significant gains during this quarter's trading.",
|
|
1067
|
+
"Technology companies are investing heavily in artificial intelligence research.",
|
|
1068
|
+
"The new restaurant downtown serves excellent Mediterranean cuisine daily.",
|
|
1069
|
+
"Climate change continues to affect global weather patterns significantly.",
|
|
1070
|
+
"The university announced new programs in data science and engineering.",
|
|
1071
|
+
"Renewable energy sources are becoming more cost-effective than fossil fuels.",
|
|
1072
|
+
"The museum exhibition features artwork from the Renaissance period.",
|
|
1073
|
+
"Public transportation systems are being upgraded in major cities worldwide.",
|
|
1074
|
+
"Medical researchers published breakthrough findings about genetic therapy.",
|
|
1075
|
+
"The concert hall will host a performance by the symphony orchestra.",
|
|
1076
|
+
"Local farmers are adopting sustainable agricultural practices this season.",
|
|
1077
|
+
"The new software update includes enhanced security features and performance.",
|
|
1078
|
+
"International trade agreements are being renegotiated between countries.",
|
|
1079
|
+
]
|
|
1080
|
+
|
|
1081
|
+
def estimate_capacity(
|
|
1082
|
+
self,
|
|
1083
|
+
tokenizer: Any,
|
|
1084
|
+
*,
|
|
1085
|
+
seq_len: int,
|
|
1086
|
+
stride: int,
|
|
1087
|
+
split: str = "validation",
|
|
1088
|
+
target_total: int | None = None,
|
|
1089
|
+
fast_mode: bool = False,
|
|
1090
|
+
) -> dict[str, Any]:
|
|
1091
|
+
"""Synthetic provider offers deterministic capacity based on base samples."""
|
|
1092
|
+
total_tokens = len(self.base_samples) * seq_len
|
|
1093
|
+
available = len(self.base_samples)
|
|
1094
|
+
return {
|
|
1095
|
+
"total_tokens": int(total_tokens),
|
|
1096
|
+
"available_nonoverlap": int(available),
|
|
1097
|
+
"available_unique": int(available),
|
|
1098
|
+
"dedupe_rate": 0.0,
|
|
1099
|
+
"stride": int(stride),
|
|
1100
|
+
"seq_len": int(seq_len),
|
|
1101
|
+
"candidate_unique": int(available),
|
|
1102
|
+
"candidate_limit": int(available),
|
|
1103
|
+
}
|
|
1104
|
+
|
|
1105
|
+
def load(
|
|
1106
|
+
self, split: str = "validation", max_samples: int = 500, **kwargs
|
|
1107
|
+
) -> list[str]:
|
|
1108
|
+
"""Generate synthetic text samples."""
|
|
1109
|
+
# Expand base samples to meet requirement
|
|
1110
|
+
expanded_samples: list[str] = []
|
|
1111
|
+
variations = [
|
|
1112
|
+
lambda s: s,
|
|
1113
|
+
lambda s: f"Recently, {s.lower()}",
|
|
1114
|
+
lambda s: f"According to reports, {s.lower()}",
|
|
1115
|
+
lambda s: f"It is notable that {s.lower()}",
|
|
1116
|
+
lambda s: f"Furthermore, {s.lower()}",
|
|
1117
|
+
lambda s: f"In addition, {s.lower()}",
|
|
1118
|
+
]
|
|
1119
|
+
|
|
1120
|
+
# Use a deterministic approach based on max_samples
|
|
1121
|
+
rng = np.random.RandomState(42) # Fixed seed for reproducibility
|
|
1122
|
+
|
|
1123
|
+
while len(expanded_samples) < max_samples:
|
|
1124
|
+
for base_text in self.base_samples:
|
|
1125
|
+
if len(expanded_samples) >= max_samples:
|
|
1126
|
+
break
|
|
1127
|
+
variation = rng.choice(variations)
|
|
1128
|
+
expanded_samples.append(variation(base_text))
|
|
1129
|
+
|
|
1130
|
+
return expanded_samples[:max_samples]
|
|
1131
|
+
|
|
1132
|
+
def windows(
|
|
1133
|
+
self,
|
|
1134
|
+
tokenizer: Any,
|
|
1135
|
+
*,
|
|
1136
|
+
seq_len: int = 128,
|
|
1137
|
+
stride: int = 64,
|
|
1138
|
+
preview_n: int = 100,
|
|
1139
|
+
final_n: int = 100,
|
|
1140
|
+
seed: int = 42,
|
|
1141
|
+
split: str = "validation",
|
|
1142
|
+
) -> tuple[EvaluationWindow, EvaluationWindow]:
|
|
1143
|
+
"""Create synthetic evaluation windows."""
|
|
1144
|
+
texts = self.load(split=split, max_samples=preview_n + final_n)
|
|
1145
|
+
|
|
1146
|
+
# Deterministic split
|
|
1147
|
+
preview_texts = texts[:preview_n]
|
|
1148
|
+
final_texts = texts[preview_n : preview_n + final_n]
|
|
1149
|
+
|
|
1150
|
+
# Create windows (simplified tokenization)
|
|
1151
|
+
preview_window = self._simple_tokenize(
|
|
1152
|
+
preview_texts, tokenizer, seq_len, list(range(preview_n))
|
|
1153
|
+
)
|
|
1154
|
+
final_window = self._simple_tokenize(
|
|
1155
|
+
final_texts, tokenizer, seq_len, list(range(preview_n, preview_n + final_n))
|
|
1156
|
+
)
|
|
1157
|
+
|
|
1158
|
+
return preview_window, final_window
|
|
1159
|
+
|
|
1160
|
+
def _simple_tokenize(
|
|
1161
|
+
self, texts: list[str], tokenizer: Any, seq_len: int, indices: list[int]
|
|
1162
|
+
) -> EvaluationWindow:
|
|
1163
|
+
"""Simple tokenization for synthetic samples."""
|
|
1164
|
+
input_ids_list = []
|
|
1165
|
+
attention_masks_list = []
|
|
1166
|
+
|
|
1167
|
+
for text in texts:
|
|
1168
|
+
# Simple tokenization fallback
|
|
1169
|
+
if hasattr(tokenizer, "encode"):
|
|
1170
|
+
input_ids = tokenizer.encode(
|
|
1171
|
+
text, max_length=seq_len, truncation=True, padding="max_length"
|
|
1172
|
+
)
|
|
1173
|
+
attention_mask = (
|
|
1174
|
+
[
|
|
1175
|
+
1 if token_id != tokenizer.pad_token_id else 0
|
|
1176
|
+
for token_id in input_ids
|
|
1177
|
+
]
|
|
1178
|
+
if hasattr(tokenizer, "pad_token_id")
|
|
1179
|
+
else [1] * len(input_ids)
|
|
1180
|
+
)
|
|
1181
|
+
else:
|
|
1182
|
+
# Fallback for test scenarios
|
|
1183
|
+
input_ids = list(range(1, min(seq_len + 1, 50))) + [0] * max(
|
|
1184
|
+
0, seq_len - 49
|
|
1185
|
+
)
|
|
1186
|
+
attention_mask = [1] * min(seq_len, 49) + [0] * max(0, seq_len - 49)
|
|
1187
|
+
|
|
1188
|
+
input_ids_list.append(input_ids)
|
|
1189
|
+
attention_masks_list.append(attention_mask)
|
|
1190
|
+
|
|
1191
|
+
return EvaluationWindow(
|
|
1192
|
+
input_ids=input_ids_list,
|
|
1193
|
+
attention_masks=attention_masks_list,
|
|
1194
|
+
indices=indices,
|
|
1195
|
+
)
|
|
1196
|
+
|
|
1197
|
+
def info(self) -> dict[str, Any]:
|
|
1198
|
+
"""Get information about synthetic provider."""
|
|
1199
|
+
return {
|
|
1200
|
+
"name": self.name,
|
|
1201
|
+
"type": "dataset_provider",
|
|
1202
|
+
"dataset": "synthetic",
|
|
1203
|
+
"source": "generated",
|
|
1204
|
+
"deterministic": True,
|
|
1205
|
+
"base_samples": len(self.base_samples),
|
|
1206
|
+
}
|
|
1207
|
+
|
|
1208
|
+
|
|
1209
|
+
class HFTextProvider:
|
|
1210
|
+
"""
|
|
1211
|
+
Generic HuggingFace datasets text provider.
|
|
1212
|
+
|
|
1213
|
+
Loads a text dataset by name/config and extracts a specified text field.
|
|
1214
|
+
Provides simple deterministic windowing suitable for CI/demo usage.
|
|
1215
|
+
"""
|
|
1216
|
+
|
|
1217
|
+
name = "hf_text"
|
|
1218
|
+
|
|
1219
|
+
def __init__(
|
|
1220
|
+
self,
|
|
1221
|
+
dataset_name: str | None = None,
|
|
1222
|
+
config_name: str | None = None,
|
|
1223
|
+
text_field: str = "text",
|
|
1224
|
+
cache_dir: str | None = None,
|
|
1225
|
+
max_samples: int = 2000,
|
|
1226
|
+
):
|
|
1227
|
+
if not HAS_DATASETS:
|
|
1228
|
+
if not _LIGHT_IMPORT:
|
|
1229
|
+
raise _DepErr(
|
|
1230
|
+
code="E301",
|
|
1231
|
+
message=(
|
|
1232
|
+
"DEPENDENCY-MISSING: datasets library required for hf_text provider"
|
|
1233
|
+
),
|
|
1234
|
+
details={"dependency": "datasets"},
|
|
1235
|
+
)
|
|
1236
|
+
self.dataset_name = dataset_name or "wikitext"
|
|
1237
|
+
self.config_name = config_name or None
|
|
1238
|
+
self.text_field = text_field
|
|
1239
|
+
self.cache_dir = cache_dir
|
|
1240
|
+
self.max_samples = int(max_samples)
|
|
1241
|
+
|
|
1242
|
+
def load(self, split: str = "validation", **kwargs) -> list[str]:
|
|
1243
|
+
if not HAS_DATASETS and _LIGHT_IMPORT:
|
|
1244
|
+
return ["synthetic dataset text"] * int(self.max_samples or 1)
|
|
1245
|
+
|
|
1246
|
+
ds = load_dataset(
|
|
1247
|
+
path=self.dataset_name,
|
|
1248
|
+
name=self.config_name,
|
|
1249
|
+
split=split,
|
|
1250
|
+
cache_dir=self.cache_dir,
|
|
1251
|
+
)
|
|
1252
|
+
texts: list[str] = []
|
|
1253
|
+
# Limit to max_samples for CI friendliness
|
|
1254
|
+
count = 0
|
|
1255
|
+
for row in ds:
|
|
1256
|
+
if self.text_field not in row:
|
|
1257
|
+
continue
|
|
1258
|
+
val = row[self.text_field]
|
|
1259
|
+
if isinstance(val, str) and val.strip():
|
|
1260
|
+
texts.append(val)
|
|
1261
|
+
count += 1
|
|
1262
|
+
if count >= self.max_samples:
|
|
1263
|
+
break
|
|
1264
|
+
return texts
|
|
1265
|
+
|
|
1266
|
+
def _simple_tokenize(
|
|
1267
|
+
self, texts: list[str], tokenizer: Any, seq_len: int, indices: list[int]
|
|
1268
|
+
) -> EvaluationWindow:
|
|
1269
|
+
input_ids_list: list[list[int]] = []
|
|
1270
|
+
attention_masks_list: list[list[int]] = []
|
|
1271
|
+
for text in texts:
|
|
1272
|
+
try:
|
|
1273
|
+
if hasattr(tokenizer, "encode"):
|
|
1274
|
+
input_ids = tokenizer.encode(
|
|
1275
|
+
text, truncation=True, max_length=seq_len
|
|
1276
|
+
)
|
|
1277
|
+
else:
|
|
1278
|
+
encoded = tokenizer(text, truncation=True, max_length=seq_len)
|
|
1279
|
+
input_ids = encoded["input_ids"]
|
|
1280
|
+
# Pad if needed
|
|
1281
|
+
pad_id = getattr(tokenizer, "pad_token_id", 0)
|
|
1282
|
+
input_ids = (input_ids + [pad_id] * (seq_len - len(input_ids)))[
|
|
1283
|
+
:seq_len
|
|
1284
|
+
]
|
|
1285
|
+
attn = [1 if tid != pad_id else 0 for tid in input_ids]
|
|
1286
|
+
input_ids_list.append(input_ids)
|
|
1287
|
+
attention_masks_list.append(attn)
|
|
1288
|
+
except Exception:
|
|
1289
|
+
# Skip bad rows
|
|
1290
|
+
continue
|
|
1291
|
+
return EvaluationWindow(
|
|
1292
|
+
input_ids_list, attention_masks_list, indices[: len(input_ids_list)]
|
|
1293
|
+
)
|
|
1294
|
+
|
|
1295
|
+

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        texts = self.load(split=split)
        total = len(texts)
        if total == 0:
            # Typed-only: no-samples is a DataError for consistency
            raise _DataErr(
                code="E306",
                message=(
                    "NO-SAMPLES: hf_text produced no samples; check dataset_name/config_name/text_field"
                ),
            )
        # Deterministic selection: first N for preview, next N for final
        preview_texts = texts[:preview_n]
        final_texts = texts[preview_n : preview_n + final_n]
        preview_window = self._simple_tokenize(
            preview_texts, tokenizer, seq_len, list(range(preview_n))
        )
        final_window = self._simple_tokenize(
            final_texts, tokenizer, seq_len, list(range(preview_n, preview_n + final_n))
        )
        return preview_window, final_window

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        texts = self.load(split=split)
        return {
            "total_tokens": 0,
            "available_nonoverlap": len(texts),
            "available_unique": len(texts),
            "dedupe_rate": 0.0,
            "stride": stride,
            "seq_len": seq_len,
            "candidate_unique": len(texts),
            "candidate_limit": min(len(texts), self.max_samples),
        }
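

# --- Usage sketch (editorial illustration; not part of the packaged module) --
# Shows how HFTextProvider's deterministic preview/final split behaves with a
# toy stand-in tokenizer. Only `encode` and `pad_token_id` are assumed by
# `_simple_tokenize`; the dataset/config names below are examples, and
# `windows()` will download the dataset unless it is already cached.
class _ToyTokenizer:
    pad_token_id = 0

    def encode(self, text: str, truncation: bool = True, max_length: int = 128) -> list[int]:
        # One deterministic positive id per whitespace token, capped at max_length.
        return [(i % 997) + 1 for i, _ in enumerate(text.split())][:max_length]


def _sketch_hf_text_windows() -> None:
    provider = HFTextProvider(dataset_name="wikitext", config_name="wikitext-2-raw-v1")
    preview, final = provider.windows(_ToyTokenizer(), seq_len=32, preview_n=4, final_n=4)
    # preview covers texts[0:4], final covers texts[4:8]; rerunning yields the
    # same windows because selection is positional, not sampled.
    assert all(len(seq) == 32 for seq in preview.input_ids + final.input_ids)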


class HFSeq2SeqProvider:
    """HuggingFace seq2seq provider with paired source/target fields.

    Loads a dataset with text pairs and exposes encoder input_ids/attention_masks.
    Decoder target token ids are exposed via last_preview_labels / last_final_labels
    for the runner to attach as labels.
    """

    name = "hf_seq2seq"

    def __init__(
        self,
        dataset_name: str,
        config_name: str | None = None,
        src_field: str = "source",
        tgt_field: str = "target",
        cache_dir: str | None = None,
        max_samples: int = 2000,
    ) -> None:
        if not HAS_DATASETS:
            if not _LIGHT_IMPORT:
                raise _DepErr(
                    code="E301",
                    message=(
                        "DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
                    ),
                    details={"dependency": "datasets"},
                )
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.src_field = src_field
        self.tgt_field = tgt_field
        self.cache_dir = cache_dir
        self.max_samples = int(max_samples)
        self.last_preview_labels: list[list[int]] | None = None
        self.last_final_labels: list[list[int]] | None = None

    def _load_pairs(self, split: str) -> list[tuple[str, str]]:
        ds = load_dataset(
            path=self.dataset_name,
            name=self.config_name,
            split=split,
            cache_dir=self.cache_dir,
        )
        out: list[tuple[str, str]] = []
        count = 0
        for row in ds:
            src = row.get(self.src_field)
            tgt = row.get(self.tgt_field)
            if (
                isinstance(src, str)
                and src.strip()
                and isinstance(tgt, str)
                and tgt.strip()
            ):
                out.append((src, tgt))
                count += 1
                if count >= self.max_samples:
                    break
        return out

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        pairs = self._load_pairs(split)
        if not pairs:
            raise _DataErr(
                code="E307",
                message=(
                    "NO-PAIRS: hf_seq2seq produced no pairs; check src_field/tgt_field"
                ),
            )
        # Deterministic slicing
        prev_pairs = pairs[:preview_n]
        fin_pairs = pairs[preview_n : preview_n + final_n]

        def _tok_src(src: str) -> list[int]:
            ids = (
                tokenizer.encode(src, truncation=True, max_length=seq_len)
                if hasattr(tokenizer, "encode")
                else tokenizer(src, truncation=True, max_length=seq_len)["input_ids"]
            )
            pad_id = getattr(tokenizer, "pad_token_id", 0)
            return (ids + [pad_id] * (seq_len - len(ids)))[:seq_len]

        def _tok_tgt(tgt: str) -> list[int]:
            ids = (
                tokenizer.encode(tgt, truncation=True, max_length=seq_len)
                if hasattr(tokenizer, "encode")
                else tokenizer(tgt, truncation=True, max_length=seq_len)["input_ids"]
            )
            # Use -100 for ignored positions to align with HF loss expectations
            return (ids + [-100] * (seq_len - len(ids)))[:seq_len]

        prev_ids = [_tok_src(s) for s, _ in prev_pairs]
        prev_masks = [
            [1 if t != getattr(tokenizer, "pad_token_id", 0) else 0 for t in seq]
            for seq in prev_ids
        ]
        fin_ids = [_tok_src(s) for s, _ in fin_pairs]
        fin_masks = [
            [1 if t != getattr(tokenizer, "pad_token_id", 0) else 0 for t in seq]
            for seq in fin_ids
        ]

        # Prepare labels
        self.last_preview_labels = [_tok_tgt(t) for _, t in prev_pairs]
        self.last_final_labels = [_tok_tgt(t) for _, t in fin_pairs]

        preview_window = EvaluationWindow(
            prev_ids, prev_masks, list(range(len(prev_ids)))
        )
        final_window = EvaluationWindow(
            fin_ids, fin_masks, list(range(preview_n, preview_n + len(fin_ids)))
        )
        return preview_window, final_window

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        pairs = self._load_pairs(split)
        n = len(pairs)
        return {
            "total_tokens": int(n * seq_len),
            "available_nonoverlap": n,
            "available_unique": n,
            "dedupe_rate": 0.0,
            "stride": stride,
            "seq_len": seq_len,
            "candidate_unique": n,
            "candidate_limit": n,
            "tokens_available": int(n * seq_len),
            "examples_available": n,
        }
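

# --- Sketch (editorial illustration): the -100 label convention --------------
# HF-style cross-entropy treats -100 as an ignore index, which is why _tok_tgt
# pads targets with -100 instead of pad_token_id: padded label positions then
# contribute nothing to the loss. Toy numbers only.
def _sketch_label_padding(seq_len: int = 8) -> list[int]:
    ids = [42, 7, 99]  # hypothetical target token ids
    # (ids + padding) -> [42, 7, 99, -100, -100, -100, -100, -100]
    return (ids + [-100] * (seq_len - len(ids)))[:seq_len]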


class LocalJSONLProvider:
    """
    Local JSONL provider for BYOD text datasets.

    Accepts a single `file`, a `path` (file or directory), or `data_files`
    (glob or list of paths). Extracts a `text_field` (defaults to "text").
    """

    name = "local_jsonl"

    def __init__(
        self,
        file: str | None = None,
        path: str | None = None,
        data_files: str | list[str] | None = None,
        text_field: str = "text",
        max_samples: int = 2000,
    ) -> None:
        self.file = file
        self.path = path
        self.data_files = data_files
        self.text_field = text_field or "text"
        self.max_samples = int(max_samples)

    def _resolve_files(self) -> list[Path]:
        files: list[Path] = []
        # Explicit file
        if isinstance(self.file, str) and self.file:
            p = Path(self.file)
            if p.exists() and p.is_file():
                files.append(p)
        # Path can be file or directory
        if isinstance(self.path, str) and self.path:
            p = Path(self.path)
            if p.is_file():
                files.append(p)
            elif p.is_dir():
                files.extend(sorted(p.glob("*.jsonl")))
        # data_files may be a glob or list
        if isinstance(self.data_files, str) and self.data_files:
            from glob import glob as _glob

            files.extend(Path(p) for p in _glob(self.data_files))
        elif isinstance(self.data_files, list):
            for item in self.data_files:
                try:
                    pp = Path(str(item))
                    if pp.exists() and pp.is_file():
                        files.append(pp)
                except Exception:
                    continue
        # Deduplicate while preserving order
        seen: set[str] = set()
        uniq: list[Path] = []
        for f in files:
            fp = f.resolve().as_posix()
            if fp not in seen:
                seen.add(fp)
                uniq.append(f)
        return uniq

    def load(self, split: str = "validation", **kwargs) -> list[str]:
        texts: list[str] = []
        count = 0
        for fp in self._resolve_files():
            try:
                with fp.open("r", encoding="utf-8") as handle:
                    for line in handle:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            obj = json.loads(line)
                        except Exception:
                            continue
                        val = obj.get(self.text_field)
                        if isinstance(val, str) and val.strip():
                            texts.append(val)
                            count += 1
                            if count >= self.max_samples:
                                return texts
            except Exception:
                continue
        return texts

    def _simple_tokenize(
        self, texts: list[str], tokenizer: Any, seq_len: int, indices: list[int]
    ) -> EvaluationWindow:
        input_ids_list: list[list[int]] = []
        attention_masks_list: list[list[int]] = []
        for text in texts:
            try:
                if hasattr(tokenizer, "encode"):
                    input_ids = tokenizer.encode(
                        text, truncation=True, max_length=seq_len
                    )
                else:
                    encoded = tokenizer(text, truncation=True, max_length=seq_len)
                    input_ids = encoded["input_ids"]
                pad_id = getattr(tokenizer, "pad_token_id", 0)
                input_ids = (input_ids + [pad_id] * (seq_len - len(input_ids)))[
                    :seq_len
                ]
                attn = [1 if tid != pad_id else 0 for tid in input_ids]
                input_ids_list.append(input_ids)
                attention_masks_list.append(attn)
            except Exception:
                continue
        return EvaluationWindow(
            input_ids_list, attention_masks_list, indices[: len(input_ids_list)]
        )

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        texts = self.load(split=split)
        if not texts:
            raise _DataErr(
                code="E306",
                message=(
                    "NO-SAMPLES: local_jsonl produced no samples; check file/path/data_files"
                ),
            )
        preview_texts = texts[:preview_n]
        final_texts = texts[preview_n : preview_n + final_n]
        preview_window = self._simple_tokenize(
            preview_texts, tokenizer, seq_len, list(range(preview_n))
        )
        final_window = self._simple_tokenize(
            final_texts,
            tokenizer,
            seq_len,
            list(range(preview_n, preview_n + final_n)),
        )
        return preview_window, final_window

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        texts = self.load(split=split)
        return {
            "total_tokens": 0,
            "available_nonoverlap": len(texts),
            "available_unique": len(texts),
            "dedupe_rate": 0.0,
            "stride": stride,
            "seq_len": seq_len,
            "candidate_unique": len(texts),
            "candidate_limit": len(texts),
        }
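

# --- Usage sketch (editorial illustration): minimal BYOD flow ----------------
# Writes a throwaway JSONL file and reads it back through LocalJSONLProvider.
# The file name and row contents are hypothetical.
def _sketch_local_jsonl() -> list[str]:
    import tempfile

    tmp = Path(tempfile.mkdtemp()) / "byod.jsonl"
    rows = [
        {"text": "first sample"},
        {"text": "second sample"},
        {"note": "skipped: no text field"},
    ]
    tmp.write_text("\n".join(json.dumps(r) for r in rows), encoding="utf-8")
    provider = LocalJSONLProvider(file=str(tmp), text_field="text", max_samples=10)
    return provider.load()  # -> ["first sample", "second sample"]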


class LocalJSONLPairsProvider:
    """Local JSONL pairs provider with source/target fields.

    Accepts a single `file`, a `path` (file or directory), or `data_files`
    (glob or list of paths). Extracts paired strings from `src_field`/`tgt_field`.
    """

    name = "local_jsonl_pairs"

    def __init__(
        self,
        file: str | None = None,
        path: str | None = None,
        data_files: str | list[str] | None = None,
        src_field: str = "source",
        tgt_field: str = "target",
        max_samples: int = 2000,
    ) -> None:
        self.file = file
        self.path = path
        self.data_files = data_files
        self.src_field = src_field or "source"
        self.tgt_field = tgt_field or "target"
        self.max_samples = int(max_samples)
        self.last_preview_labels: list[list[int]] | None = None
        self.last_final_labels: list[list[int]] | None = None

    def _resolve_files(self) -> list[Path]:
        files: list[Path] = []
        if isinstance(self.file, str) and self.file:
            p = Path(self.file)
            if p.exists() and p.is_file():
                files.append(p)
        if isinstance(self.path, str) and self.path:
            p = Path(self.path)
            if p.is_file():
                files.append(p)
            elif p.is_dir():
                files.extend(sorted(p.glob("*.jsonl")))
        if isinstance(self.data_files, str) and self.data_files:
            from glob import glob as _glob

            files.extend(Path(p) for p in _glob(self.data_files))
        elif isinstance(self.data_files, list):
            for item in self.data_files:
                try:
                    pp = Path(str(item))
                    if pp.exists() and pp.is_file():
                        files.append(pp)
                except Exception:
                    continue
        # Deduplicate
        seen: set[str] = set()
        uniq: list[Path] = []
        for f in files:
            fp = f.resolve().as_posix()
            if fp not in seen:
                seen.add(fp)
                uniq.append(f)
        return uniq

    def _load_pairs(self) -> list[tuple[str, str]]:
        pairs: list[tuple[str, str]] = []
        count = 0
        for fp in self._resolve_files():
            try:
                with fp.open("r", encoding="utf-8") as handle:
                    for line in handle:
                        line = line.strip()
                        if not line:
                            continue
                        try:
                            obj = json.loads(line)
                        except Exception:
                            continue
                        src = obj.get(self.src_field)
                        tgt = obj.get(self.tgt_field)
                        if (
                            isinstance(src, str)
                            and src.strip()
                            and isinstance(tgt, str)
                            and tgt.strip()
                        ):
                            pairs.append((src, tgt))
                            count += 1
                            if count >= self.max_samples:
                                return pairs
            except Exception:
                continue
        return pairs

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        pairs = self._load_pairs()
        if not pairs:
            raise ValueError(
                "local_jsonl_pairs produced no pairs; check src_field/tgt_field and files"
            )
        prev_pairs = pairs[:preview_n]
        fin_pairs = pairs[preview_n : preview_n + final_n]

        pad_id = getattr(tokenizer, "pad_token_id", 0)

        def _tok_src(src: str) -> list[int]:
            ids = (
                tokenizer.encode(src, truncation=True, max_length=seq_len)
                if hasattr(tokenizer, "encode")
                else tokenizer(src, truncation=True, max_length=seq_len)["input_ids"]
            )
            return (ids + [pad_id] * (seq_len - len(ids)))[:seq_len]

        def _tok_tgt(tgt: str) -> list[int]:
            ids = (
                tokenizer.encode(tgt, truncation=True, max_length=seq_len)
                if hasattr(tokenizer, "encode")
                else tokenizer(tgt, truncation=True, max_length=seq_len)["input_ids"]
            )
            return (ids + [-100] * (seq_len - len(ids)))[:seq_len]

        prev_ids = [_tok_src(s) for s, _ in prev_pairs]
        fin_ids = [_tok_src(s) for s, _ in fin_pairs]
        prev_masks = [[1 if t != pad_id else 0 for t in seq] for seq in prev_ids]
        fin_masks = [[1 if t != pad_id else 0 for t in seq] for seq in fin_ids]
        self.last_preview_labels = [_tok_tgt(t) for _, t in prev_pairs]
        self.last_final_labels = [_tok_tgt(t) for _, t in fin_pairs]

        preview_window = EvaluationWindow(
            prev_ids, prev_masks, list(range(len(prev_ids)))
        )
        final_window = EvaluationWindow(
            fin_ids, fin_masks, list(range(preview_n, preview_n + len(fin_ids)))
        )
        return preview_window, final_window

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        pairs = self._load_pairs()
        n = len(pairs)
        return {
            "total_tokens": int(n * seq_len),
            "available_nonoverlap": n,
            "available_unique": n,
            "dedupe_rate": 0.0,
            "stride": stride,
            "seq_len": seq_len,
            "candidate_unique": n,
            "candidate_limit": n,
            "tokens_available": int(n * seq_len),
            "examples_available": n,
        }
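

# --- Usage sketch (editorial illustration): expected pairs JSONL shape -------
# Each line is a JSON object carrying both src_field and tgt_field; rows
# missing either string are skipped. Calls the private _load_pairs purely for
# demonstration; file name and contents are hypothetical.
def _sketch_local_jsonl_pairs() -> list[tuple[str, str]]:
    import tempfile

    tmp = Path(tempfile.mkdtemp()) / "pairs.jsonl"
    rows = [
        {"source": "translate: hello", "target": "bonjour"},
        {"source": "translate: goodbye"},  # skipped: no target
    ]
    tmp.write_text("\n".join(json.dumps(r) for r in rows), encoding="utf-8")
    provider = LocalJSONLPairsProvider(file=str(tmp))
    return provider._load_pairs()  # -> [("translate: hello", "bonjour")]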

# (text-only helpers removed; LocalJSONLProvider implements text tokenization)


class Seq2SeqDataProvider:
    """Synthetic seq2seq provider wrapper to fit the DatasetProvider interface.

    Bridges invarlock.eval.providers.seq2seq.Seq2SeqProvider to the windowing
    protocol used by the CLI runner. Generates encoder input_ids from src_ids,
    attention_masks from src_mask, and allows the runner to derive labels.
    """

    name = "seq2seq"

    def __init__(self, **kwargs: Any) -> None:
        # Pass through kwargs to the underlying provider
        # (n, src_len, tgt_len, pad_id, bos_id, eos_id).
        from invarlock.eval.providers.seq2seq import Seq2SeqProvider as _S2S

        self._inner = _S2S(**kwargs)
        self.last_preview_labels: list[list[int]] | None = None
        self.last_final_labels: list[list[int]] | None = None

    def load(
        self, split: str = "validation", **kwargs
    ) -> list[str]:  # pragma: no cover - not used
        return []

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        # Generate exactly preview_n + final_n examples deterministically
        total = max(0, int(preview_n) + int(final_n))
        if total <= 0:
            total = 1
        # Build batches of size total
        # Ensure the inner generator produces at least `total` examples
        try:
            # Prefer reconfiguring 'n' if attribute present
            if getattr(self._inner, "_n", 0) < total:
                self._inner._n = int(total)
        except Exception:
            pass
        batches = list(self._inner.batches(seed=seed, batch_size=total))
        if not batches:
            raise ValueError("seq2seq provider produced no examples")
        batch = batches[0]
        # Extract source tokens/masks and target ids for labels
        src_ids_list = [list(x) for x in batch.get("src_ids", [])][:total]
        src_mask_list = [list(x) for x in batch.get("src_mask", [])][:total]
        tgt_ids_list = [list(x) for x in batch.get("tgt_ids", [])][:total]
        # Right-pad/truncate to seq_len for runner compatibility
        pad_id = getattr(tokenizer, "pad_token_id", 0)

        def _pad(seq: list[int]) -> list[int]:
            if len(seq) < seq_len:
                return (seq + [pad_id] * (seq_len - len(seq)))[:seq_len]
            return seq[:seq_len]

        input_ids = [_pad(s) for s in src_ids_list]
        attention_masks = []
        for i, s in enumerate(input_ids):
            # Prefer src_mask if lengths align; otherwise infer from pad_id
            if i < len(src_mask_list) and len(src_mask_list[i]) == len(src_ids_list[i]):
                # Adjust length to seq_len
                m = src_mask_list[i]
                if len(m) < seq_len:
                    m = m + [0] * (seq_len - len(m))
                attention_masks.append([int(v) for v in m[:seq_len]])
            else:
                attention_masks.append([1 if t != pad_id else 0 for t in s])

        # Split into preview/final windows
        prev_ids = input_ids[:preview_n]
        prev_mask = attention_masks[:preview_n]
        fin_ids = input_ids[preview_n : preview_n + final_n]
        fin_mask = attention_masks[preview_n : preview_n + final_n]

        # Prepare label sequences (decoder targets) padded to seq_len
        def _pad_label(seq: list[int]) -> list[int]:
            if len(seq) < seq_len:
                return (seq + [-100] * (seq_len - len(seq)))[:seq_len]
            return seq[:seq_len]

        prev_labels = [_pad_label(s) for s in tgt_ids_list[:preview_n]]
        fin_labels = [
            _pad_label(s) for s in tgt_ids_list[preview_n : preview_n + final_n]
        ]
        # Save for runner to attach
        self.last_preview_labels = prev_labels
        self.last_final_labels = fin_labels

        preview_window = EvaluationWindow(prev_ids, prev_mask, list(range(preview_n)))
        final_window = EvaluationWindow(
            fin_ids, fin_mask, list(range(preview_n, preview_n + final_n))
        )
        return preview_window, final_window

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        # Deterministic bounded synthetic examples; assume large enough for
        # CI/release smokes
        n = int(target_total or 800)
        return {
            "total_tokens": int(n * seq_len),
            "available_nonoverlap": n,
            "available_unique": n,
            "dedupe_rate": 0.0,
            "stride": stride,
            "seq_len": seq_len,
            "candidate_unique": n,
            "candidate_limit": n,
            "tokens_available": int(n * seq_len),
            "examples_available": n,
        }

    def info(self) -> dict[str, Any]:  # pragma: no cover - trivial
        return {"name": self.name, "type": "dataset_provider", "dataset": "seq2seq"}
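

# --- Sketch (editorial illustration): attention-mask fallback ----------------
# When a provided src_mask does not align with its src_ids, the wrapper above
# infers the mask from pad_id instead. Toy values with pad_id=0.
def _sketch_mask_fallback(seq: list[int], pad_id: int = 0) -> list[int]:
    # e.g. _sketch_mask_fallback([5, 6, 0, 0]) -> [1, 1, 0, 0]
    return [1 if t != pad_id else 0 for t in seq]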


# Registry for dataset providers
_PROVIDERS: dict[str, type] = {
    "wikitext2": WikiText2Provider,
    "synthetic": SyntheticProvider,
    "hf_text": HFTextProvider,
    "local_jsonl": LocalJSONLProvider,
    "seq2seq": Seq2SeqDataProvider,
    "hf_seq2seq": HFSeq2SeqProvider,
    "local_jsonl_pairs": LocalJSONLPairsProvider,
}


def get_provider(name: str, **kwargs) -> DatasetProvider:
    """
    Get a dataset provider by name.

    Args:
        name: Provider name (e.g. "wikitext2", "synthetic", "hf_text",
            "local_jsonl", "seq2seq", "hf_seq2seq", "local_jsonl_pairs")
        **kwargs: Provider-specific initialization parameters

    Returns:
        Initialized dataset provider

    Raises:
        ValidationError(E308): If the provider name is not registered
    """
    if name not in _PROVIDERS:
        available = ", ".join(_PROVIDERS.keys())
        # Typed-only error for provider lookup
        raise _ValErr(
            code="E308",
            message="PROVIDER-NOT-FOUND: unknown dataset provider",
            details={"provider": name, "available": available},
        )

    provider_class = _PROVIDERS[name]
    return provider_class(**kwargs)


def list_providers() -> list[str]:
    """List available dataset provider names."""
    return list(_PROVIDERS.keys())
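

# --- Usage sketch (editorial illustration): provider lookup ------------------
# Unknown names raise the typed E308 validation error (with the available
# provider names in its details) rather than a bare KeyError. Constructing
# "synthetic" without kwargs assumes its defaults suffice.
def _sketch_get_provider() -> None:
    assert "hf_text" in list_providers()
    assert get_provider("synthetic") is not None  # kwargs are provider-specific
    try:
        get_provider("no_such_provider")
    except Exception as err:  # _ValErr carrying code="E308" and available names
        print(err)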


def compute_window_hash(window: EvaluationWindow, include_data: bool = False) -> str:
    """
    Compute a deterministic hash of an evaluation window.

    Args:
        window: EvaluationWindow to hash
        include_data: Whether to include actual token data in the hash

    Returns:
        Hex digest string of the window hash
    """
    hasher = hashlib.sha256()

    # Always include structural information
    hasher.update(str(len(window)).encode())
    hasher.update(str(sorted(window.indices)).encode())

    if include_data:
        # Include actual token sequences for data integrity checking
        for input_ids, attention_mask in zip(
            window.input_ids, window.attention_masks, strict=False
        ):
            hasher.update(str(input_ids).encode())
            hasher.update(str(attention_mask).encode())

    return hasher.hexdigest()
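

# --- Sketch (editorial illustration): structural vs. data hashing ------------
# With include_data=False only the window length and sorted indices feed the
# digest, so same-shaped windows with different tokens hash identically; pass
# include_data=True when content integrity matters. EvaluationWindow is
# constructed positionally, as elsewhere in this module.
def _sketch_window_hash() -> None:
    a = EvaluationWindow([[1, 2]], [[1, 1]], [0])
    b = EvaluationWindow([[9, 9]], [[1, 1]], [0])
    assert compute_window_hash(a) == compute_window_hash(b)
    assert compute_window_hash(a, include_data=True) != compute_window_hash(
        b, include_data=True
    )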


# Export public API
__all__ = [
    "DatasetProvider",
    "EvaluationWindow",
    "WikiText2Provider",
    "SyntheticProvider",
    "HFTextProvider",
    "LocalJSONLProvider",
    "get_provider",
    "list_providers",
    "compute_window_hash",
]