invarlock 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. invarlock/__init__.py +33 -0
  2. invarlock/__main__.py +10 -0
  3. invarlock/_data/runtime/profiles/ci_cpu.yaml +15 -0
  4. invarlock/_data/runtime/profiles/release.yaml +23 -0
  5. invarlock/_data/runtime/tiers.yaml +76 -0
  6. invarlock/adapters/__init__.py +102 -0
  7. invarlock/adapters/_capabilities.py +45 -0
  8. invarlock/adapters/auto.py +99 -0
  9. invarlock/adapters/base.py +530 -0
  10. invarlock/adapters/base_types.py +85 -0
  11. invarlock/adapters/hf_bert.py +852 -0
  12. invarlock/adapters/hf_gpt2.py +403 -0
  13. invarlock/adapters/hf_llama.py +485 -0
  14. invarlock/adapters/hf_mixin.py +383 -0
  15. invarlock/adapters/hf_onnx.py +112 -0
  16. invarlock/adapters/hf_t5.py +137 -0
  17. invarlock/adapters/py.typed +1 -0
  18. invarlock/assurance/__init__.py +43 -0
  19. invarlock/cli/__init__.py +8 -0
  20. invarlock/cli/__main__.py +8 -0
  21. invarlock/cli/_evidence.py +25 -0
  22. invarlock/cli/_json.py +75 -0
  23. invarlock/cli/adapter_auto.py +162 -0
  24. invarlock/cli/app.py +287 -0
  25. invarlock/cli/commands/__init__.py +26 -0
  26. invarlock/cli/commands/certify.py +403 -0
  27. invarlock/cli/commands/doctor.py +1358 -0
  28. invarlock/cli/commands/explain_gates.py +151 -0
  29. invarlock/cli/commands/export_html.py +100 -0
  30. invarlock/cli/commands/plugins.py +1331 -0
  31. invarlock/cli/commands/report.py +354 -0
  32. invarlock/cli/commands/run.py +4146 -0
  33. invarlock/cli/commands/verify.py +1040 -0
  34. invarlock/cli/config.py +396 -0
  35. invarlock/cli/constants.py +68 -0
  36. invarlock/cli/device.py +92 -0
  37. invarlock/cli/doctor_helpers.py +74 -0
  38. invarlock/cli/errors.py +6 -0
  39. invarlock/cli/overhead_utils.py +60 -0
  40. invarlock/cli/provenance.py +66 -0
  41. invarlock/cli/utils.py +41 -0
  42. invarlock/config.py +56 -0
  43. invarlock/core/__init__.py +62 -0
  44. invarlock/core/abi.py +15 -0
  45. invarlock/core/api.py +274 -0
  46. invarlock/core/auto_tuning.py +317 -0
  47. invarlock/core/bootstrap.py +226 -0
  48. invarlock/core/checkpoint.py +221 -0
  49. invarlock/core/contracts.py +73 -0
  50. invarlock/core/error_utils.py +64 -0
  51. invarlock/core/events.py +298 -0
  52. invarlock/core/exceptions.py +95 -0
  53. invarlock/core/registry.py +481 -0
  54. invarlock/core/retry.py +146 -0
  55. invarlock/core/runner.py +2041 -0
  56. invarlock/core/types.py +154 -0
  57. invarlock/edits/__init__.py +12 -0
  58. invarlock/edits/_edit_utils.py +249 -0
  59. invarlock/edits/_external_utils.py +268 -0
  60. invarlock/edits/noop.py +47 -0
  61. invarlock/edits/py.typed +1 -0
  62. invarlock/edits/quant_rtn.py +801 -0
  63. invarlock/edits/registry.py +166 -0
  64. invarlock/eval/__init__.py +23 -0
  65. invarlock/eval/bench.py +1207 -0
  66. invarlock/eval/bootstrap.py +50 -0
  67. invarlock/eval/data.py +2052 -0
  68. invarlock/eval/metrics.py +2167 -0
  69. invarlock/eval/primary_metric.py +767 -0
  70. invarlock/eval/probes/__init__.py +24 -0
  71. invarlock/eval/probes/fft.py +139 -0
  72. invarlock/eval/probes/mi.py +213 -0
  73. invarlock/eval/probes/post_attention.py +323 -0
  74. invarlock/eval/providers/base.py +67 -0
  75. invarlock/eval/providers/seq2seq.py +111 -0
  76. invarlock/eval/providers/text_lm.py +113 -0
  77. invarlock/eval/providers/vision_text.py +93 -0
  78. invarlock/eval/py.typed +1 -0
  79. invarlock/guards/__init__.py +18 -0
  80. invarlock/guards/_contracts.py +9 -0
  81. invarlock/guards/invariants.py +640 -0
  82. invarlock/guards/policies.py +805 -0
  83. invarlock/guards/py.typed +1 -0
  84. invarlock/guards/rmt.py +2097 -0
  85. invarlock/guards/spectral.py +1419 -0
  86. invarlock/guards/tier_config.py +354 -0
  87. invarlock/guards/variance.py +3298 -0
  88. invarlock/guards_ref/__init__.py +15 -0
  89. invarlock/guards_ref/rmt_ref.py +40 -0
  90. invarlock/guards_ref/spectral_ref.py +135 -0
  91. invarlock/guards_ref/variance_ref.py +60 -0
  92. invarlock/model_profile.py +353 -0
  93. invarlock/model_utils.py +221 -0
  94. invarlock/observability/__init__.py +10 -0
  95. invarlock/observability/alerting.py +535 -0
  96. invarlock/observability/core.py +546 -0
  97. invarlock/observability/exporters.py +565 -0
  98. invarlock/observability/health.py +588 -0
  99. invarlock/observability/metrics.py +457 -0
  100. invarlock/observability/py.typed +1 -0
  101. invarlock/observability/utils.py +553 -0
  102. invarlock/plugins/__init__.py +12 -0
  103. invarlock/plugins/hello_guard.py +33 -0
  104. invarlock/plugins/hf_awq_adapter.py +82 -0
  105. invarlock/plugins/hf_bnb_adapter.py +79 -0
  106. invarlock/plugins/hf_gptq_adapter.py +78 -0
  107. invarlock/plugins/py.typed +1 -0
  108. invarlock/py.typed +1 -0
  109. invarlock/reporting/__init__.py +7 -0
  110. invarlock/reporting/certificate.py +3221 -0
  111. invarlock/reporting/certificate_schema.py +244 -0
  112. invarlock/reporting/dataset_hashing.py +215 -0
  113. invarlock/reporting/guards_analysis.py +948 -0
  114. invarlock/reporting/html.py +32 -0
  115. invarlock/reporting/normalizer.py +235 -0
  116. invarlock/reporting/policy_utils.py +517 -0
  117. invarlock/reporting/primary_metric_utils.py +265 -0
  118. invarlock/reporting/render.py +1442 -0
  119. invarlock/reporting/report.py +903 -0
  120. invarlock/reporting/report_types.py +278 -0
  121. invarlock/reporting/utils.py +175 -0
  122. invarlock/reporting/validate.py +631 -0
  123. invarlock/security.py +176 -0
  124. invarlock/sparsity_utils.py +323 -0
  125. invarlock/utils/__init__.py +150 -0
  126. invarlock/utils/digest.py +45 -0
  127. invarlock-0.2.0.dist-info/METADATA +586 -0
  128. invarlock-0.2.0.dist-info/RECORD +132 -0
  129. invarlock-0.2.0.dist-info/WHEEL +5 -0
  130. invarlock-0.2.0.dist-info/entry_points.txt +20 -0
  131. invarlock-0.2.0.dist-info/licenses/LICENSE +201 -0
  132. invarlock-0.2.0.dist-info/top_level.txt +1 -0
invarlock/eval/data.py ADDED
@@ -0,0 +1,2052 @@
"""
InvarLock Evaluation Data Loading
=================================

Pluggable data loading system with deterministic windowing for reproducible evaluation.
"""

from __future__ import annotations

import atexit
import hashlib
import json
import math
import os
import time
import warnings
from abc import abstractmethod
from collections import Counter
from collections.abc import Sequence
from pathlib import Path
from typing import Any, NamedTuple, Protocol

import numpy as np

from invarlock.core.exceptions import DataError as _DataErr
from invarlock.core.exceptions import DependencyError as _DepErr
from invarlock.core.exceptions import ValidationError as _ValErr

# NOTE: During the typed-only migration, avoid hybrid KeyError mixin

_LIGHT_IMPORT = os.getenv("INVARLOCK_LIGHT_IMPORT", "").strip().lower() in {
    "1",
    "true",
    "yes",
}

try:
    from datasets import load_dataset

    HAS_DATASETS = True
except ImportError:
    HAS_DATASETS = False

    def load_dataset(*args, **kwargs):  # type: ignore[no-redef]
        raise _DepErr(
            code="E301",
            message="DEPENDENCY-MISSING: datasets library required for dataset loading",
            details={"dependency": "datasets"},
        )


try:
    import torch
    import torch.nn.functional as F

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
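
A minimal sketch of how the fallback import guard behaves when `datasets` is missing, assuming only that invarlock itself is installed (`HAS_DATASETS` and the `load_dataset` stub are the ones defined above):

    from invarlock.eval import data

    if not data.HAS_DATASETS:
        try:
            data.load_dataset("wikitext")  # stub raises DependencyError with code E301
        except Exception as err:
            print(f"datasets unavailable: {err}")
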
class EvaluationWindow(NamedTuple):
    """A window of tokenized samples for evaluation."""

    input_ids: list[list[int]]  # List of tokenized sequences
    attention_masks: list[list[int]]  # Attention masks (1=real token, 0=padding)
    indices: list[int]  # Original dataset indices

    def __len__(self) -> int:
        return len(self.input_ids)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "input_ids": self.input_ids,
            "attention_masks": self.attention_masks,
            "indices": self.indices,
            "length": len(self.input_ids),
        }
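
Because `EvaluationWindow` is a plain `NamedTuple`, it can be constructed and serialized without a model or tokenizer; a small sketch with illustrative token ids:

    window = EvaluationWindow(
        input_ids=[[101, 2023, 102], [101, 2003, 102]],
        attention_masks=[[1, 1, 1], [1, 1, 1]],
        indices=[7, 42],
    )
    assert len(window) == 2
    payload = window.to_dict()  # JSON-serializable; includes a "length" field
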
class DatasetProvider(Protocol):
    """
    Protocol for pluggable dataset providers.

    Enables extensible dataset support while maintaining deterministic evaluation.
    """

    name: str

    @abstractmethod
    def load(self, split: str = "validation", **kwargs) -> list[str]:
        """
        Load raw text samples from the dataset.

        Args:
            split: Dataset split to load ("validation", "test", "train")
            **kwargs: Provider-specific parameters

        Returns:
            List of text strings
        """
        ...

    @abstractmethod
    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        """
        Create deterministic preview and final evaluation windows.

        Args:
            tokenizer: Tokenizer to use for text encoding
            seq_len: Maximum sequence length
            stride: Stride for overlapping windows (unused in current impl)
            preview_n: Number of preview samples
            final_n: Number of final samples
            seed: Random seed for deterministic sampling
            split: Dataset split to use

        Returns:
            Tuple of (preview_window, final_window)
        """
        ...

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        """
        Estimate number of non-overlapping, deduplicated windows available for evaluation.

        Returns metadata describing the available capacity (total tokens, usable windows, dedupe rate).
        """
        ...

    def info(self) -> dict[str, Any]:
        """Get information about this dataset provider."""
        return {"name": self.name, "type": "dataset_provider"}
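
Since `DatasetProvider` is a `Protocol`, any class with matching attributes satisfies it structurally; no inheritance is required. A hypothetical in-memory provider, shown only to illustrate the contract (it assumes an HF-style tokenizer with `.encode` and omits `estimate_capacity`/`info` for brevity):

    class InMemoryProvider:
        name = "in_memory"

        def __init__(self, texts: list[str]) -> None:
            self._texts = texts

        def load(self, split: str = "validation", **kwargs) -> list[str]:
            return list(self._texts)

        def windows(self, tokenizer, *, seq_len=128, stride=64, preview_n=100,
                    final_n=100, seed=42, split="validation"):
            ids = [tokenizer.encode(t, truncation=True, max_length=seq_len)
                   for t in self._texts[: preview_n + final_n]]
            masks = [[1] * len(seq) for seq in ids]
            preview = EvaluationWindow(ids[:preview_n], masks[:preview_n],
                                       list(range(preview_n)))
            final = EvaluationWindow(ids[preview_n:], masks[preview_n:],
                                     list(range(preview_n, preview_n + final_n)))
            return preview, final
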
class WikiText2Provider:
    """
    WikiText-2 dataset provider with deterministic windowing.

    Implements the canonical WT-2 evaluation setup with fixed 100+100 preview/final samples.
    """

    name = "wikitext2"
    _MODEL_CACHE: Any | None | bool = None
    _MODEL_DEVICE: Any | None = None
    _CLEANUP_REGISTERED: bool = False

    def __init__(
        self,
        cache_dir: Path | None = None,
        device_hint: str | None = None,
        **_: Any,
    ):
        """
        Initialize WikiText-2 provider.

        Args:
            cache_dir: Optional cache directory for dataset storage
            device_hint: Optional device hint from the CLI/resolved run device
                (e.g. "cpu", "cuda", "mps", "auto")
        """
        self.cache_dir = cache_dir
        self._validate_dependencies()
        self._register_cleanup()
        self._difficulty_model = self.__class__._MODEL_CACHE
        self._difficulty_device = self.__class__._MODEL_DEVICE
        self._last_stratification_stats: dict[str, Any] | None = None
        self._last_batch_size_used: int = 0
        self._last_scorer_profile: dict[str, Any] | None = None
        self._scorer_warmed: bool = False
        # In-process cache for loaded/filtered texts to avoid repeated
        # load_dataset() calls across stratification retries.
        self._texts_cache: dict[str, list[str]] = {}
        # Optional device hint from CLI/resolved run device (e.g. "cpu", "cuda", "mps", "auto")
        normalized_hint = (device_hint or "").strip().lower()
        self._device_hint: str | None = normalized_hint or None

    @classmethod
    def _register_cleanup(cls) -> None:
        """Register an atexit hook once per process to release cached models."""
        if cls._CLEANUP_REGISTERED or not HAS_TORCH:
            return

        def _cleanup() -> None:
            cls._cleanup_model_cache()

        atexit.register(_cleanup)
        cls._CLEANUP_REGISTERED = True

    @classmethod
    def _cleanup_model_cache(cls) -> None:
        """Release cached models to avoid leaking multiprocessing semaphores."""
        cache = cls._MODEL_CACHE
        if cache is not None and cache is not False and HAS_TORCH:
            try:
                cache.to("cpu")
            except Exception:
                pass
        cls._MODEL_CACHE = None
        cls._MODEL_DEVICE = None

    @staticmethod
    def _pick_default_scorer_device() -> torch.device:
        """
        Choose a default device for the difficulty scorer model.

        Prefers CUDA → MPS → CPU when available.
        """
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def _validate_dependencies(self) -> None:
        """Check that required dependencies are available."""
        if not HAS_DATASETS:
            if _LIGHT_IMPORT:
                return
            raise _DepErr(
                code="E301",
                message=(
                    "DEPENDENCY-MISSING: datasets library required for WikiText-2 loading"
                ),
                details={"dependency": "datasets"},
            )
    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        """Estimate available non-overlapping windows for evaluation."""
        texts = self.load(split=split, max_samples=2000)
        if not texts:
            return {
                "total_tokens": 0,
                "available_nonoverlap": 0,
                "available_unique": 0,
                "dedupe_rate": 0.0,
                "stride": stride,
                "seq_len": seq_len,
                "candidate_unique": 0,
                "candidate_limit": 0,
            }

        env_fast = os.environ.get("INVARLOCK_CAPACITY_FAST", "")
        env_fast_flag = isinstance(env_fast, str) and env_fast.strip().lower() in {
            "1",
            "true",
            "yes",
            "on",
        }
        use_fast = bool(fast_mode) or env_fast_flag
        if use_fast:
            base_available = len(texts)
            target_total = int(target_total or 0)
            approx_available = base_available
            if target_total > 0:
                approx_available = max(base_available, target_total)
            total_tokens = int(max(approx_available, 0) * seq_len)
            approx_available = int(max(approx_available, 0))
            return {
                "total_tokens": total_tokens,
                "available_nonoverlap": approx_available,
                "available_unique": approx_available,
                "dedupe_rate": 0.0,
                "stride": int(stride),
                "seq_len": int(seq_len),
                "candidate_unique": approx_available,
                "candidate_limit": approx_available,
            }

        tokenized = self._collect_tokenized_samples(
            texts, list(range(len(texts))), tokenizer, seq_len
        )

        total_tokens = sum(item[3] for item in tokenized)
        available_nonoverlap = len(tokenized)

        unique_sequences: set[tuple[int, ...]] = set()
        for _, input_ids, attention_mask, _ in tokenized:
            seq = tuple(
                int(tok_id)
                for tok_id, mask in zip(input_ids, attention_mask, strict=False)
                if mask
            )
            unique_sequences.add(seq)

        available_unique = len(unique_sequences)
        dedupe_rate = (
            0.0
            if available_nonoverlap == 0
            else max(
                0.0,
                1.0 - (available_unique / float(max(available_nonoverlap, 1))),
            )
        )

        candidate_unique = None
        candidate_limit = None
        if target_total is not None and target_total > 0:
            reserve_buffer = max(int(target_total * 0.2), 64)
            candidate_limit = min(len(texts), target_total + reserve_buffer)
            tokenized_subset = self._collect_tokenized_samples(
                texts, list(range(candidate_limit)), tokenizer, seq_len
            )
            subset_signatures = {
                tuple(
                    int(tok)
                    for tok, mask in zip(entry[1], entry[2], strict=False)
                    if mask
                )
                for entry in tokenized_subset
            }
            candidate_unique = len(subset_signatures)

        result = {
            "total_tokens": int(total_tokens),
            "available_nonoverlap": int(available_nonoverlap),
            "available_unique": int(available_unique),
            "dedupe_rate": float(dedupe_rate),
            "stride": int(stride),
            "seq_len": int(seq_len),
        }
        if candidate_unique is not None:
            result["candidate_unique"] = int(candidate_unique)
            result["candidate_limit"] = int(candidate_limit or 0)
        return result
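
The capacity metadata obeys a simple identity: `dedupe_rate` is, up to clamping at zero, `1 - available_unique / available_nonoverlap`. A sketch of a caller-side sanity check, assuming a provider and a HuggingFace tokenizer are already constructed:

    cap = provider.estimate_capacity(tokenizer, seq_len=128, stride=64)
    if cap["available_nonoverlap"]:
        expected = 1.0 - cap["available_unique"] / cap["available_nonoverlap"]
        assert abs(cap["dedupe_rate"] - max(0.0, expected)) < 1e-9
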
    def load(
        self, split: str = "validation", max_samples: int = 2000, **kwargs
    ) -> list[str]:
        """
        Load WikiText-2 text samples.

        Args:
            split: Dataset split ("validation", "test", "train")
            max_samples: Maximum samples to load
            **kwargs: Additional parameters (ignored)

        Returns:
            List of filtered text strings
        """
        print(f"📚 Loading WikiText-2 {split} split...")

        # Serve from cache when possible (load the largest slice once)
        cached = self._texts_cache.get(split)
        if cached is not None and len(cached) >= max_samples:
            return cached[:max_samples]

        if not HAS_DATASETS and _LIGHT_IMPORT:
            texts = ["hello world", "invarlock synthetic text"] * max(
                1, max_samples // 2
            )
            self._texts_cache[split] = texts
            return texts[:max_samples]

        # Load dataset with size limit for efficiency
        dataset_slice = f"{split}[:{max_samples}]" if max_samples > 0 else split
        dataset = load_dataset(
            "wikitext",
            "wikitext-2-raw-v1",
            split=dataset_slice,
            cache_dir=str(self.cache_dir) if self.cache_dir else None,
        )

        # Filter out empty/short texts
        valid_texts: list[str] = []
        for item in dataset:
            text = str(item.get("text", "")).strip()
            # Keep texts with at least 20 characters and some alphabetic content
            if len(text) >= 20 and any(c.isalpha() for c in text):
                valid_texts.append(text)

        # Optional exact-text dedupe to reduce duplicate-token windows.
        # Enable via INVARLOCK_DEDUP_TEXTS=1 (keeps first occurrence, preserves order).
        if str(os.environ.get("INVARLOCK_DEDUP_TEXTS", "")).strip().lower() in {
            "1",
            "true",
            "yes",
            "on",
        }:
            seen: set[str] = set()
            deduped: list[str] = []
            for t in valid_texts:
                if t not in seen:
                    seen.add(t)
                    deduped.append(t)
            valid_texts = deduped

        # Cache the largest slice we've seen for this split
        prev = self._texts_cache.get(split)
        if prev is None or len(valid_texts) > len(prev):
            self._texts_cache[split] = list(valid_texts)

        print(f" ✓ Loaded {len(valid_texts)} valid samples from {len(dataset)} total")
        return valid_texts
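
Exact-text dedupe is opt-in through the environment; a usage sketch (the flag name is the one read above, everything else is illustrative):

    import os

    os.environ["INVARLOCK_DEDUP_TEXTS"] = "1"
    texts = WikiText2Provider().load(split="validation", max_samples=500)
    assert len(texts) == len(set(texts))  # first occurrences only, order preserved
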
    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        """
        Create deterministic preview and final evaluation windows.

        This implements the core deterministic evaluation requirement:
        - Fixed seed ensures reproducible sample selection
        - Non-overlapping preview and final samples
        - Consistent tokenization parameters

        Args:
            tokenizer: HuggingFace tokenizer for text encoding
            seq_len: Maximum sequence length for tokenization
            stride: Stride parameter (reserved for future use)
            preview_n: Number of preview samples (default: 100)
            final_n: Number of final samples (default: 100)
            seed: Random seed for reproducible sampling
            split: Dataset split to use

        Returns:
            Tuple of (preview_window, final_window) with deterministic samples
        """
        total_required = preview_n + final_n
        if total_required <= 0:
            raise _ValErr(
                code="E302", message="VALIDATION-FAILED: preview/final must be positive"
            )

        # Load text data with additional buffer to ensure enough valid samples for release windows.
        extra_pool = max(500, int(0.5 * total_required))
        max_samples = max(total_required + extra_pool, 2000)
        texts = self.load(split=split, max_samples=max_samples)

        rng = np.random.RandomState(seed)
        shuffled_indices = rng.permutation(len(texts)).tolist()

        reserve = max(16, int(0.1 * total_required))
        target_pool = min(len(texts), total_required + reserve * 2)

        if target_pool < total_required:
            raise _DataErr(
                code="E303",
                message=(
                    "CAPACITY-INSUFFICIENT: not enough valid samples for requested preview/final"
                ),
                details={
                    "have": int(len(texts)),
                    "preview": int(preview_n),
                    "final": int(final_n),
                },
            )

        candidates: list[dict[str, Any]] = []
        used_indices: set[int] = set()
        cursor = 0
        chunk_size = max(64, min(256, target_pool))

        print(" 📊 Creating evaluation windows:")
        print(f" Requested preview/final: {preview_n}/{final_n}")
        print(f" Sampling pool target: {target_pool} (reserve {reserve})")

        while len(candidates) < total_required + reserve and cursor < len(
            shuffled_indices
        ):
            batch = shuffled_indices[cursor : cursor + chunk_size]
            cursor += chunk_size

            tokenized_batch = self._collect_tokenized_samples(
                texts, batch, tokenizer, seq_len
            )

            for (
                idx,
                input_ids_list,
                attention_mask_list,
                real_tokens,
            ) in tokenized_batch:
                if idx in used_indices:
                    continue
                used_indices.add(idx)
                candidates.append(
                    {
                        "dataset_index": idx,
                        "input_ids": input_ids_list,
                        "attention_mask": attention_mask_list,
                        "token_count": real_tokens,
                    }
                )

            if cursor >= len(shuffled_indices) and len(candidates) < total_required:
                break

        if len(candidates) < total_required:
            raise _DataErr(
                code="E304",
                message=(
                    "TOKENIZE-INSUFFICIENT: failed to gather enough tokenized samples"
                ),
                details={"needed": int(total_required), "got": int(len(candidates))},
            )

        if not self._score_candidates_with_model(candidates):
            token_counter: Counter[int] = Counter()
            for candidate in candidates:
                for token_id, mask in zip(
                    candidate["input_ids"], candidate["attention_mask"], strict=False
                ):
                    if mask:
                        token_counter[int(token_id)] += 1

            total_tokens = sum(token_counter.values()) or 1
            vocab_size = max(len(token_counter), 1)

            for candidate in candidates:
                difficulty = 0.0
                real_tokens = 0
                for token_id, mask in zip(
                    candidate["input_ids"], candidate["attention_mask"], strict=False
                ):
                    if not mask:
                        continue
                    freq = (token_counter[int(token_id)] + 1.0) / (
                        total_tokens + vocab_size
                    )
                    difficulty -= math.log(freq)
                    real_tokens += 1
                candidate["difficulty"] = difficulty / max(real_tokens, 1)

        sorted_candidates = sorted(
            candidates, key=lambda item: (item["difficulty"], item["dataset_index"])
        )

        total_candidates = len(sorted_candidates)
        selection_count = total_required
        selected_positions: list[int] = []
        used_positions: set[int] = set()

        for k in range(selection_count):
            target_position = (k + 0.5) * total_candidates / selection_count
            base_idx = int(round(target_position))
            offset = 0
            chosen: int | None = None

            while offset < total_candidates:
                for candidate_idx in (base_idx + offset, base_idx - offset):
                    if (
                        0 <= candidate_idx < total_candidates
                        and candidate_idx not in used_positions
                    ):
                        chosen = candidate_idx
                        break
                if chosen is not None:
                    break
                offset += 1

            if chosen is not None:
                used_positions.add(chosen)
                selected_positions.append(chosen)

        if len(selected_positions) < selection_count:
            for candidate_idx in range(total_candidates):
                if candidate_idx not in used_positions:
                    used_positions.add(candidate_idx)
                    selected_positions.append(candidate_idx)
                    if len(selected_positions) == selection_count:
                        break

        if len(selected_positions) < selection_count:
            raise _DataErr(
                code="E305", message="STRATIFY-FAILED: candidate pool insufficient"
            )

        selected_candidates = [sorted_candidates[idx] for idx in selected_positions]
        selected_candidates.sort(
            key=lambda item: (item["difficulty"], item["dataset_index"])
        )

        preview_candidates: list[dict[str, Any]] = []
        final_candidates: list[dict[str, Any]] = []

        def assign_candidate(
            candidate: dict[str, Any],
            primary: list[dict[str, Any]],
            secondary: list[dict[str, Any]],
            primary_capacity: int,
            secondary_capacity: int,
        ) -> None:
            if len(primary) < primary_capacity:
                primary.append(candidate)
            elif len(secondary) < secondary_capacity:
                secondary.append(candidate)

        for pair_start in range(0, len(selected_candidates), 2):
            pair = selected_candidates[pair_start : pair_start + 2]
            if not pair:
                continue
            if len(pair) == 2:
                easy, hard = pair
                pair_index = pair_start // 2
                if pair_index % 2 == 0:
                    assign_candidate(
                        easy, preview_candidates, final_candidates, preview_n, final_n
                    )
                    assign_candidate(
                        hard, final_candidates, preview_candidates, final_n, preview_n
                    )
                else:
                    assign_candidate(
                        easy, final_candidates, preview_candidates, final_n, preview_n
                    )
                    assign_candidate(
                        hard, preview_candidates, final_candidates, preview_n, final_n
                    )
            else:
                lone_candidate = pair[0]
                assign_candidate(
                    lone_candidate,
                    preview_candidates,
                    final_candidates,
                    preview_n,
                    final_n,
                )

        assigned_ids = {
            id(candidate) for candidate in preview_candidates + final_candidates
        }
        remaining = [
            candidate
            for candidate in selected_candidates
            if id(candidate) not in assigned_ids
        ]
        for candidate in remaining:
            if len(preview_candidates) < preview_n:
                preview_candidates.append(candidate)
            elif len(final_candidates) < final_n:
                final_candidates.append(candidate)

        def _mean_difficulty(candidates: list[dict[str, Any]]) -> float:
            if not candidates:
                return 0.0
            return float(
                sum(candidate["difficulty"] for candidate in candidates)
                / len(candidates)
            )

        for _ in range(100):
            if not preview_candidates or not final_candidates:
                break
            diff = _mean_difficulty(preview_candidates) - _mean_difficulty(
                final_candidates
            )
            if abs(diff) <= 1e-4:
                break
            if diff < 0:
                preview_candidate = min(
                    preview_candidates, key=lambda c: c["difficulty"]
                )
                final_candidate = max(final_candidates, key=lambda c: c["difficulty"])
            else:
                preview_candidate = max(
                    preview_candidates, key=lambda c: c["difficulty"]
                )
                final_candidate = min(final_candidates, key=lambda c: c["difficulty"])

            if preview_candidate is final_candidate:
                break

            preview_candidates.remove(preview_candidate)
            final_candidates.remove(final_candidate)
            preview_candidates.append(final_candidate)
            final_candidates.append(preview_candidate)

            new_diff = _mean_difficulty(preview_candidates) - _mean_difficulty(
                final_candidates
            )
            if abs(new_diff) >= abs(diff) - 1e-6:
                # Swap did not improve; revert and stop.
                preview_candidates.remove(final_candidate)
                final_candidates.remove(preview_candidate)
                preview_candidates.append(preview_candidate)
                final_candidates.append(final_candidate)
                break

        if len(preview_candidates) != preview_n or len(final_candidates) != final_n:
            raise _DataErr(
                code="E305",
                message=(
                    "STRATIFY-FAILED: failed to allocate preview/final windows with equal counts"
                ),
                details={
                    "preview_target": int(preview_n),
                    "final_target": int(final_n),
                    "preview_got": int(len(preview_candidates)),
                    "final_got": int(len(final_candidates)),
                },
            )

        preview_candidates.sort(
            key=lambda item: (item["difficulty"], item["dataset_index"])
        )
        final_candidates.sort(
            key=lambda item: (item["difficulty"], item["dataset_index"])
        )

        preview_window = EvaluationWindow(
            input_ids=[c["input_ids"] for c in preview_candidates],
            attention_masks=[c["attention_mask"] for c in preview_candidates],
            indices=[c["dataset_index"] for c in preview_candidates],
        )

        final_window = EvaluationWindow(
            input_ids=[c["input_ids"] for c in final_candidates],
            attention_masks=[c["attention_mask"] for c in final_candidates],
            indices=[c["dataset_index"] for c in final_candidates],
        )

        if len(preview_window) != preview_n or len(final_window) != final_n:
            raise _DataErr(
                code="E305",
                message="STRATIFY-FAILED: window stratification mismatch",
                details={
                    "preview_target": int(preview_n),
                    "final_target": int(final_n),
                    "preview_got": int(len(preview_window)),
                    "final_got": int(len(final_window)),
                },
            )

        preview_difficulties = [c["difficulty"] for c in preview_candidates]
        final_difficulties = [c["difficulty"] for c in final_candidates]
        self._last_stratification_stats = {
            "pool_size": len(selected_candidates),
            "reserve": reserve,
            "batch_size_used": int(self._last_batch_size_used),
            "preview_mean_difficulty": float(np.mean(preview_difficulties))
            if preview_difficulties
            else 0.0,
            "final_mean_difficulty": float(np.mean(final_difficulties))
            if final_difficulties
            else 0.0,
            "preview_std_difficulty": float(np.std(preview_difficulties))
            if preview_difficulties
            else 0.0,
            "final_std_difficulty": float(np.std(final_difficulties))
            if final_difficulties
            else 0.0,
            "difficulty_gap": float(
                (np.mean(final_difficulties) - np.mean(preview_difficulties))
                if (preview_difficulties and final_difficulties)
                else 0.0
            ),
        }

        print(f" Seed: {seed}, Seq length: {seq_len}")
        print(f" Preview: {len(preview_window)} samples")
        print(f" Final: {len(final_window)} samples")

        return preview_window, final_window
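
The pool selection above reduces to a small kernel: sort candidates by difficulty, then for the k-th of K picks aim at quantile position (k + 0.5) * N / K and fall back to the nearest unused index. A self-contained sketch of just that kernel (names are local to the example):

    def pick_stratified(n_candidates: int, k_picks: int) -> list[int]:
        used: set[int] = set()
        picks: list[int] = []
        for k in range(k_picks):
            base = int(round((k + 0.5) * n_candidates / k_picks))
            offset = 0
            while offset < n_candidates:
                for idx in (base + offset, base - offset):
                    if 0 <= idx < n_candidates and idx not in used:
                        used.add(idx)
                        picks.append(idx)
                        break
                else:
                    offset += 1
                    continue
                break
        return picks

    # e.g. pick_stratified(10, 4) -> [1, 4, 6, 9], spread across the sorted pool
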
    def _collect_tokenized_samples(
        self,
        texts: Sequence[str],
        indices: Sequence[int],
        tokenizer: Any,
        seq_len: int,
    ) -> list[tuple[int, list[int], list[int], int]]:
        """Tokenize samples and return raw sequences without logging."""
        results: list[tuple[int, list[int], list[int], int]] = []
        for idx in indices:
            if idx >= len(texts):
                continue

            text = texts[idx]

            try:
                tokens = tokenizer(
                    text,
                    truncation=True,
                    padding="max_length",
                    max_length=seq_len,
                    return_tensors="pt" if HAS_TORCH else None,
                )

                if HAS_TORCH and hasattr(tokens["input_ids"], "squeeze"):
                    input_ids = tokens["input_ids"].squeeze(0).tolist()
                    attention_mask = (
                        tokens.get(
                            "attention_mask", torch.ones_like(tokens["input_ids"])
                        )
                        .squeeze(0)
                        .tolist()
                    )
                else:
                    input_ids = tokens["input_ids"]
                    attention_mask = tokens.get("attention_mask", [1] * len(input_ids))

                real_tokens = int(sum(attention_mask))
                if real_tokens > 1:
                    results.append(
                        (
                            idx,
                            [int(token) for token in input_ids],
                            [int(mask) for mask in attention_mask],
                            real_tokens,
                        )
                    )

            except Exception as e:
                warnings.warn(f"Failed to tokenize sample {idx}: {e}", stacklevel=2)
                continue

        return results
    def _score_candidates_with_model(self, candidates: list[dict[str, Any]]) -> bool:
        """Score candidate windows using a pretrained GPT-2 model if available."""
        if not HAS_TORCH:
            return False

        if self._difficulty_model is False:
            return False

        try:
            eval_device_override = os.environ.get("INVARLOCK_EVAL_DEVICE")
            device_hint = getattr(self, "_device_hint", None)

            if self._difficulty_model is None:
                from transformers import GPT2LMHeadModel

                model = GPT2LMHeadModel.from_pretrained("gpt2")
                model.eval()
                # Decide initial scorer device: env override → provider hint → heuristic
                if eval_device_override:
                    try:
                        device = torch.device(eval_device_override)
                    except Exception:
                        device = self._pick_default_scorer_device()
                elif device_hint and device_hint != "auto":
                    try:
                        device = torch.device(device_hint)
                    except Exception:
                        device = self._pick_default_scorer_device()
                else:
                    device = self._pick_default_scorer_device()

                model.to(device)
                self._difficulty_model = model
                self._difficulty_device = device
                self.__class__._MODEL_CACHE = model
                self.__class__._MODEL_DEVICE = device

            assert self._difficulty_model is not None
            model = self._difficulty_model
            device = self._difficulty_device or torch.device("cpu")

            # If a new override/hint is provided, move the cached model if needed.
            desired_device = device
            if eval_device_override:
                try:
                    desired_device = torch.device(eval_device_override)
                except Exception:
                    desired_device = device
            elif device_hint and device_hint != "auto":
                try:
                    desired_device = torch.device(device_hint)
                except Exception:
                    desired_device = device

            if desired_device != device:
                try:
                    model.to(desired_device)
                    device = desired_device
                    self._difficulty_device = desired_device
                    self.__class__._MODEL_DEVICE = desired_device
                except Exception as exc:
                    warnings.warn(
                        f"Failed to move GPT-2 difficulty scorer to {desired_device}: {exc}",
                        stacklevel=2,
                    )

            if not self._scorer_warmed:
                with torch.no_grad():
                    dummy_input = torch.zeros((1, 8), dtype=torch.long, device=device)
                    dummy_attention = torch.ones_like(dummy_input)
                    model(dummy_input, attention_mask=dummy_attention)
                self._scorer_warmed = True

            batch_override = os.environ.get("INVARLOCK_SCORES_BATCH_SIZE")
            override_size = None
            if batch_override:
                try:
                    override_size = max(1, int(batch_override))
                except ValueError:
                    override_size = None

            batch_size = min(32, max(4, len(candidates)))
            if override_size is not None:
                batch_size = max(1, min(override_size, len(candidates)))

            input_batch: list[list[int]] = []
            attention_batch: list[list[int]] = []
            candidate_batch: list[dict[str, Any]] = []
            total_tokens = 0
            start_time = time.perf_counter()

            with torch.no_grad():
                for candidate in candidates:
                    input_batch.append(candidate["input_ids"])
                    attention_batch.append(candidate["attention_mask"])
                    candidate_batch.append(candidate)

                    if len(input_batch) == batch_size or candidate is candidates[-1]:
                        input_tensor = torch.tensor(
                            input_batch, dtype=torch.long, device=device
                        )
                        attention_tensor = torch.tensor(
                            attention_batch, dtype=torch.long, device=device
                        )

                        outputs = model(input_tensor, attention_mask=attention_tensor)
                        shift_logits = outputs.logits[:, :-1, :].contiguous()
                        shift_labels = input_tensor[:, 1:].contiguous()
                        shift_mask = attention_tensor[:, 1:].contiguous()
                        shift_labels = shift_labels.masked_fill(shift_mask == 0, 0)

                        vocab_size = shift_logits.size(-1)
                        losses = F.cross_entropy(
                            shift_logits.view(-1, vocab_size),
                            shift_labels.view(-1),
                            reduction="none",
                        )
                        losses = losses.view(shift_labels.size()) * shift_mask
                        token_counts = shift_mask.sum(dim=1).clamp(min=1)
                        loss_per_example = (
                            (losses.sum(dim=1) / token_counts).cpu().tolist()
                        )

                        for cand_obj, loss_value in zip(
                            candidate_batch, loss_per_example, strict=False
                        ):
                            cand_obj["difficulty"] = float(loss_value)
                        total_tokens += int(token_counts.sum().item())

                        input_batch.clear()
                        attention_batch.clear()
                        candidate_batch.clear()
            self._last_batch_size_used = batch_size
            elapsed = max(time.perf_counter() - start_time, 1e-9)
            tokens_per_sec = total_tokens / elapsed if total_tokens else 0.0
            self._last_scorer_profile = {
                "batch_size": batch_size,
                "tokens_processed": total_tokens,
                "elapsed_seconds": elapsed,
                "tokens_per_second": tokens_per_sec,
            }
            return True
        except Exception as exc:  # pragma: no cover - defensive
            warnings.warn(
                f"Failed to compute GPT-2 difficulty scores: {exc}", stacklevel=2
            )
            self._difficulty_model = False
            self._difficulty_device = None
            self.__class__._MODEL_CACHE = False
            self.__class__._MODEL_DEVICE = None
            self._last_batch_size_used = 0
            self._last_scorer_profile = None
            return False
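
The scorer's difficulty value is the mean next-token cross-entropy per window: logits shifted left, labels shifted right, with padded positions masked out of both the loss and the token count. The same computation in isolation, assuming batched `logits`, `input_ids`, and `attention_mask` tensors:

    import torch
    import torch.nn.functional as F

    def per_example_nll(logits, input_ids, attention_mask):
        shift_logits = logits[:, :-1, :]          # predict token t+1 from position t
        shift_mask = attention_mask[:, 1:]
        shift_labels = input_ids[:, 1:].masked_fill(shift_mask == 0, 0)
        losses = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        ).view(shift_labels.shape)
        # Zero out padded positions, then average over real tokens only.
        return (losses * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)
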
    def _tokenize_samples(
        self,
        texts: list[str],
        indices: list[int],
        tokenizer: Any,
        seq_len: int,
        window_name: str,
    ) -> EvaluationWindow:
        """Tokenize a set of text samples with consistent parameters."""
        collected = self._collect_tokenized_samples(texts, indices, tokenizer, seq_len)

        input_ids_list = [entry[1] for entry in collected]
        attention_masks_list = [entry[2] for entry in collected]
        valid_indices = [entry[0] for entry in collected]

        print(
            f" ✓ {window_name}: {len(valid_indices)}/{len(indices)} samples tokenized successfully"
        )

        return EvaluationWindow(
            input_ids=input_ids_list,
            attention_masks=attention_masks_list,
            indices=valid_indices,
        )

    @property
    def stratification_stats(self) -> dict[str, Any] | None:
        """Return summary statistics for the most recent stratified split."""
        return self._last_stratification_stats

    @property
    def scorer_profile(self) -> dict[str, Any] | None:
        """Return performance statistics for the most recent scorer run."""
        return self._last_scorer_profile

    def info(self) -> dict[str, Any]:
        """Get information about WikiText-2 provider."""
        return {
            "name": self.name,
            "type": "dataset_provider",
            "dataset": "wikitext-2-raw-v1",
            "source": "huggingface/datasets",
            "deterministic": True,
            "default_split": "validation",
            "requires": ["datasets"],
        }
class SyntheticProvider:
    """
    Synthetic text provider for testing and development.

    Generates coherent text samples when WikiText-2 is not available.
    """

    name = "synthetic"

    def __init__(self, base_samples: list[str] | None = None):
        """Initialize with optional base text samples."""
        self.base_samples = base_samples or self._default_samples()

    def _default_samples(self) -> list[str]:
        """Generate default synthetic text samples."""
        return [
            "The weather today is quite pleasant with clear skies and gentle winds.",
            "Scientists have discovered a new species in the Amazon rainforest region.",
            "The stock market showed significant gains during this quarter's trading.",
            "Technology companies are investing heavily in artificial intelligence research.",
            "The new restaurant downtown serves excellent Mediterranean cuisine daily.",
            "Climate change continues to affect global weather patterns significantly.",
            "The university announced new programs in data science and engineering.",
            "Renewable energy sources are becoming more cost-effective than fossil fuels.",
            "The museum exhibition features artwork from the Renaissance period.",
            "Public transportation systems are being upgraded in major cities worldwide.",
            "Medical researchers published breakthrough findings about genetic therapy.",
            "The concert hall will host a performance by the symphony orchestra.",
            "Local farmers are adopting sustainable agricultural practices this season.",
            "The new software update includes enhanced security features and performance.",
            "International trade agreements are being renegotiated between countries.",
        ]

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        """Synthetic provider offers deterministic capacity based on base samples."""
        total_tokens = len(self.base_samples) * seq_len
        available = len(self.base_samples)
        return {
            "total_tokens": int(total_tokens),
            "available_nonoverlap": int(available),
            "available_unique": int(available),
            "dedupe_rate": 0.0,
            "stride": int(stride),
            "seq_len": int(seq_len),
            "candidate_unique": int(available),
            "candidate_limit": int(available),
        }

    def load(
        self, split: str = "validation", max_samples: int = 500, **kwargs
    ) -> list[str]:
        """Generate synthetic text samples."""
        # Expand base samples to meet requirement
        expanded_samples: list[str] = []
        variations = [
            lambda s: s,
            lambda s: f"Recently, {s.lower()}",
            lambda s: f"According to reports, {s.lower()}",
            lambda s: f"It is notable that {s.lower()}",
            lambda s: f"Furthermore, {s.lower()}",
            lambda s: f"In addition, {s.lower()}",
        ]

        # Use a deterministic approach based on max_samples
        rng = np.random.RandomState(42)  # Fixed seed for reproducibility

        while len(expanded_samples) < max_samples:
            for base_text in self.base_samples:
                if len(expanded_samples) >= max_samples:
                    break
                variation = rng.choice(variations)
                expanded_samples.append(variation(base_text))

        return expanded_samples[:max_samples]

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        """Create synthetic evaluation windows."""
        texts = self.load(split=split, max_samples=preview_n + final_n)

        # Deterministic split
        preview_texts = texts[:preview_n]
        final_texts = texts[preview_n : preview_n + final_n]

        # Create windows (simplified tokenization)
        preview_window = self._simple_tokenize(
            preview_texts, tokenizer, seq_len, list(range(preview_n))
        )
        final_window = self._simple_tokenize(
            final_texts, tokenizer, seq_len, list(range(preview_n, preview_n + final_n))
        )

        return preview_window, final_window

    def _simple_tokenize(
        self, texts: list[str], tokenizer: Any, seq_len: int, indices: list[int]
    ) -> EvaluationWindow:
        """Simple tokenization for synthetic samples."""
        input_ids_list = []
        attention_masks_list = []

        for text in texts:
            # Simple tokenization fallback
            if hasattr(tokenizer, "encode"):
                input_ids = tokenizer.encode(
                    text, max_length=seq_len, truncation=True, padding="max_length"
                )
                attention_mask = (
                    [
                        1 if token_id != tokenizer.pad_token_id else 0
                        for token_id in input_ids
                    ]
                    if hasattr(tokenizer, "pad_token_id")
                    else [1] * len(input_ids)
                )
            else:
                # Fallback for test scenarios
                input_ids = list(range(1, min(seq_len + 1, 50))) + [0] * max(
                    0, seq_len - 49
                )
                attention_mask = [1] * min(seq_len, 49) + [0] * max(0, seq_len - 49)

            input_ids_list.append(input_ids)
            attention_masks_list.append(attention_mask)

        return EvaluationWindow(
            input_ids=input_ids_list,
            attention_masks=attention_masks_list,
            indices=indices,
        )

    def info(self) -> dict[str, Any]:
        """Get information about synthetic provider."""
        return {
            "name": self.name,
            "type": "dataset_provider",
            "dataset": "synthetic",
            "source": "generated",
            "deterministic": True,
            "base_samples": len(self.base_samples),
        }
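
This makes SyntheticProvider handy for smoke-testing the windowing path offline; a sketch with a deliberately fake tokenizer (the stub is illustrative, not part of the package):

    class _StubTokenizer:
        pad_token_id = 0

        def encode(self, text, max_length=128, truncation=True, padding="max_length"):
            ids = [hash(w) % 1000 + 1 for w in text.split()][:max_length]
            return ids + [0] * (max_length - len(ids))

    provider = SyntheticProvider()
    preview, final = provider.windows(
        _StubTokenizer(), seq_len=32, preview_n=4, final_n=4
    )
    assert len(preview) == 4 and len(final) == 4
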
class HFTextProvider:
    """
    Generic HuggingFace datasets text provider.

    Loads a text dataset by name/config and extracts a specified text field.
    Provides simple deterministic windowing suitable for CI/demo usage.
    """

    name = "hf_text"

    def __init__(
        self,
        dataset_name: str | None = None,
        config_name: str | None = None,
        text_field: str = "text",
        cache_dir: str | None = None,
        max_samples: int = 2000,
    ):
        if not HAS_DATASETS:
            if not _LIGHT_IMPORT:
                raise _DepErr(
                    code="E301",
                    message=(
                        "DEPENDENCY-MISSING: datasets library required for hf_text provider"
                    ),
                    details={"dependency": "datasets"},
                )
        self.dataset_name = dataset_name or "wikitext"
        self.config_name = config_name or None
        self.text_field = text_field
        self.cache_dir = cache_dir
        self.max_samples = int(max_samples)

    def load(self, split: str = "validation", **kwargs) -> list[str]:
        if not HAS_DATASETS and _LIGHT_IMPORT:
            return ["synthetic dataset text"] * int(self.max_samples or 1)

        ds = load_dataset(
            path=self.dataset_name,
            name=self.config_name,
            split=split,
            cache_dir=self.cache_dir,
        )
        texts: list[str] = []
        # Limit to max_samples for CI friendliness
        count = 0
        for row in ds:
            if self.text_field not in row:
                continue
            val = row[self.text_field]
            if isinstance(val, str) and val.strip():
                texts.append(val)
                count += 1
                if count >= self.max_samples:
                    break
        return texts

    def _simple_tokenize(
        self, texts: list[str], tokenizer: Any, seq_len: int, indices: list[int]
    ) -> EvaluationWindow:
        input_ids_list: list[list[int]] = []
        attention_masks_list: list[list[int]] = []
        for text in texts:
            try:
                if hasattr(tokenizer, "encode"):
                    input_ids = tokenizer.encode(
                        text, truncation=True, max_length=seq_len
                    )
                else:
                    encoded = tokenizer(text, truncation=True, max_length=seq_len)
                    input_ids = encoded["input_ids"]
                # Pad if needed
                pad_id = getattr(tokenizer, "pad_token_id", 0)
                input_ids = (input_ids + [pad_id] * (seq_len - len(input_ids)))[
                    :seq_len
                ]
                attn = [1 if tid != pad_id else 0 for tid in input_ids]
                input_ids_list.append(input_ids)
                attention_masks_list.append(attn)
            except Exception:
                # Skip bad rows
                continue
        return EvaluationWindow(
            input_ids_list, attention_masks_list, indices[: len(input_ids_list)]
        )

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        texts = self.load(split=split)
        total = len(texts)
        if total == 0:
            # Typed-only: no-samples is a DataError for consistency
            raise _DataErr(
                code="E306",
                message=(
                    "NO-SAMPLES: hf_text produced no samples; check dataset_name/config_name/text_field"
                ),
            )
        # Deterministic selection: first N for preview, next N for final
        preview_texts = texts[:preview_n]
        final_texts = texts[preview_n : preview_n + final_n]
        preview_window = self._simple_tokenize(
            preview_texts, tokenizer, seq_len, list(range(preview_n))
        )
        final_window = self._simple_tokenize(
            final_texts, tokenizer, seq_len, list(range(preview_n, preview_n + final_n))
        )
        return preview_window, final_window

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        texts = self.load(split=split)
        return {
            "total_tokens": 0,
            "available_nonoverlap": len(texts),
            "available_unique": len(texts),
            "dedupe_rate": 0.0,
            "stride": stride,
            "seq_len": seq_len,
            "candidate_unique": len(texts),
            "candidate_limit": min(len(texts), self.max_samples),
        }
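
Constructor arguments mirror `datasets.load_dataset`; a hedged example pointing the generic provider at the same WikiText-2 config the dedicated provider uses (the `tokenizer` is assumed to be any HuggingFace tokenizer):

    provider = HFTextProvider(
        dataset_name="wikitext",
        config_name="wikitext-2-raw-v1",
        text_field="text",
        max_samples=500,
    )
    preview, final = provider.windows(tokenizer, seq_len=128, preview_n=50, final_n=50)
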
class HFSeq2SeqProvider:
    """HuggingFace seq2seq provider with paired source/target fields.

    Loads a dataset with text pairs and exposes encoder input_ids/attention_masks.
    Decoder target token ids are exposed via last_preview_labels / last_final_labels
    for the runner to attach as labels.
    """

    name = "hf_seq2seq"

    def __init__(
        self,
        dataset_name: str,
        config_name: str | None = None,
        src_field: str = "source",
        tgt_field: str = "target",
        cache_dir: str | None = None,
        max_samples: int = 2000,
    ) -> None:
        if not HAS_DATASETS:
            if not _LIGHT_IMPORT:
                raise _DepErr(
                    code="E301",
                    message=(
                        "DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
                    ),
                    details={"dependency": "datasets"},
                )
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.src_field = src_field
        self.tgt_field = tgt_field
        self.cache_dir = cache_dir
        self.max_samples = int(max_samples)
        self.last_preview_labels: list[list[int]] | None = None
        self.last_final_labels: list[list[int]] | None = None

    def _load_pairs(self, split: str) -> list[tuple[str, str]]:
        ds = load_dataset(
            path=self.dataset_name,
            name=self.config_name,
            split=split,
            cache_dir=self.cache_dir,
        )
        out: list[tuple[str, str]] = []
        count = 0
        for row in ds:
            src = row.get(self.src_field)
            tgt = row.get(self.tgt_field)
            if (
                isinstance(src, str)
                and src.strip()
                and isinstance(tgt, str)
                and tgt.strip()
            ):
                out.append((src, tgt))
                count += 1
                if count >= self.max_samples:
                    break
        return out

    def windows(
        self,
        tokenizer: Any,
        *,
        seq_len: int = 128,
        stride: int = 64,
        preview_n: int = 100,
        final_n: int = 100,
        seed: int = 42,
        split: str = "validation",
    ) -> tuple[EvaluationWindow, EvaluationWindow]:
        pairs = self._load_pairs(split)
        if not pairs:
            raise _DataErr(
                code="E307",
                message=(
                    "NO-PAIRS: hf_seq2seq produced no pairs; check src_field/tgt_field"
                ),
            )
        # Deterministic slicing
        prev_pairs = pairs[:preview_n]
        fin_pairs = pairs[preview_n : preview_n + final_n]

        def _tok_src(src: str) -> list[int]:
            ids = (
                tokenizer.encode(src, truncation=True, max_length=seq_len)
                if hasattr(tokenizer, "encode")
                else tokenizer(src, truncation=True, max_length=seq_len)["input_ids"]
            )
            pad_id = getattr(tokenizer, "pad_token_id", 0)
            return (ids + [pad_id] * (seq_len - len(ids)))[:seq_len]

        def _tok_tgt(tgt: str) -> list[int]:
            ids = (
                tokenizer.encode(tgt, truncation=True, max_length=seq_len)
                if hasattr(tokenizer, "encode")
                else tokenizer(tgt, truncation=True, max_length=seq_len)["input_ids"]
            )
            # Use -100 for ignored positions to align with HF loss expectations
            return (ids + [-100] * (seq_len - len(ids)))[:seq_len]

        prev_ids = [_tok_src(s) for s, _ in prev_pairs]
        prev_masks = [
            [1 if t != getattr(tokenizer, "pad_token_id", 0) else 0 for t in seq]
            for seq in prev_ids
        ]
        fin_ids = [_tok_src(s) for s, _ in fin_pairs]
        fin_masks = [
            [1 if t != getattr(tokenizer, "pad_token_id", 0) else 0 for t in seq]
            for seq in fin_ids
        ]

        # Prepare labels
        self.last_preview_labels = [_tok_tgt(t) for _, t in prev_pairs]
        self.last_final_labels = [_tok_tgt(t) for _, t in fin_pairs]

        preview_window = EvaluationWindow(
            prev_ids, prev_masks, list(range(len(prev_ids)))
        )
        final_window = EvaluationWindow(
            fin_ids, fin_masks, list(range(preview_n, preview_n + len(fin_ids)))
        )
        return preview_window, final_window

    def estimate_capacity(
        self,
        tokenizer: Any,
        *,
        seq_len: int,
        stride: int,
        split: str = "validation",
        target_total: int | None = None,
        fast_mode: bool = False,
    ) -> dict[str, Any]:
        pairs = self._load_pairs(split)
        n = len(pairs)
        return {
            "total_tokens": int(n * seq_len),
            "available_nonoverlap": n,
            "available_unique": n,
            "dedupe_rate": 0.0,
            "stride": stride,
            "seq_len": seq_len,
            "candidate_unique": n,
            "candidate_limit": n,
            "tokens_available": int(n * seq_len),
            "examples_available": n,
        }
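
Padding targets with -100 matters because `torch.nn.functional.cross_entropy` (and the HF seq2seq losses built on it) defaults to `ignore_index=-100`, so padded label positions contribute nothing to the loss. A minimal demonstration, assuming torch is installed:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 4, 10)               # (batch, seq, vocab)
    labels = torch.tensor([[3, 7, -100, -100]])  # last two positions are padding
    # Mean is taken over the two real positions only; -100 entries are ignored.
    loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1))
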
1501
+ class LocalJSONLProvider:
1502
+ """
1503
+ Local JSONL provider for BYOD text datasets.
1504
+
1505
+ Accepts a single `file`, a `path` (file or directory), or `data_files`
1506
+ (glob or list of paths). Extracts a `text_field` (defaults to "text").
1507
+ """
1508
+
1509
+ name = "local_jsonl"
1510
+
1511
+ def __init__(
1512
+ self,
1513
+ file: str | None = None,
1514
+ path: str | None = None,
1515
+ data_files: str | list[str] | None = None,
1516
+ text_field: str = "text",
1517
+ max_samples: int = 2000,
1518
+ ) -> None:
1519
+ self.file = file
1520
+ self.path = path
1521
+ self.data_files = data_files
1522
+ self.text_field = text_field or "text"
1523
+ self.max_samples = int(max_samples)
1524
+
1525
+ def _resolve_files(self) -> list[Path]:
1526
+ files: list[Path] = []
1527
+ # Explicit file
1528
+ if isinstance(self.file, str) and self.file:
1529
+ p = Path(self.file)
1530
+ if p.exists() and p.is_file():
1531
+ files.append(p)
1532
+ # Path can be file or directory
1533
+ if isinstance(self.path, str) and self.path:
1534
+ p = Path(self.path)
1535
+ if p.is_file():
1536
+ files.append(p)
1537
+ elif p.is_dir():
1538
+ files.extend(sorted(p.glob("*.jsonl")))
1539
+ # data_files may be a glob or list
1540
+ if isinstance(self.data_files, str) and self.data_files:
1541
+ from glob import glob as _glob
1542
+
1543
+ files.extend(Path(p) for p in _glob(self.data_files))
1544
+ elif isinstance(self.data_files, list):
1545
+ for item in self.data_files:
1546
+ try:
1547
+ pp = Path(str(item))
1548
+ if pp.exists() and pp.is_file():
1549
+ files.append(pp)
1550
+ except Exception:
1551
+ continue
1552
+ # Deduplicate while preserving order
1553
+ seen: set[str] = set()
1554
+ uniq: list[Path] = []
1555
+ for f in files:
1556
+ fp = f.resolve().as_posix()
1557
+ if fp not in seen:
1558
+ seen.add(fp)
1559
+ uniq.append(f)
1560
+ return uniq
1561
+
+    def load(self, split: str = "validation", **kwargs) -> list[str]:
+        texts: list[str] = []
+        count = 0
+        for fp in self._resolve_files():
+            try:
+                with fp.open("r", encoding="utf-8") as handle:
+                    for line in handle:
+                        line = line.strip()
+                        if not line:
+                            continue
+                        try:
+                            obj = json.loads(line)
+                        except Exception:
+                            continue
+                        val = obj.get(self.text_field)
+                        if isinstance(val, str) and val.strip():
+                            texts.append(val)
+                            count += 1
+                            if count >= self.max_samples:
+                                return texts
+            except Exception:
+                continue
+        return texts
+
+    def _simple_tokenize(
+        self, texts: list[str], tokenizer: Any, seq_len: int, indices: list[int]
+    ) -> EvaluationWindow:
+        input_ids_list: list[list[int]] = []
+        attention_masks_list: list[list[int]] = []
+        for text in texts:
+            try:
+                if hasattr(tokenizer, "encode"):
+                    input_ids = tokenizer.encode(
+                        text, truncation=True, max_length=seq_len
+                    )
+                else:
+                    encoded = tokenizer(text, truncation=True, max_length=seq_len)
+                    input_ids = encoded["input_ids"]
+                pad_id = getattr(tokenizer, "pad_token_id", 0)
+                input_ids = (input_ids + [pad_id] * (seq_len - len(input_ids)))[
+                    :seq_len
+                ]
+                attn = [1 if tid != pad_id else 0 for tid in input_ids]
+                input_ids_list.append(input_ids)
+                attention_masks_list.append(attn)
+            except Exception:
+                continue
+        return EvaluationWindow(
+            input_ids_list, attention_masks_list, indices[: len(input_ids_list)]
+        )
+
+    def windows(
+        self,
+        tokenizer: Any,
+        *,
+        seq_len: int = 128,
+        stride: int = 64,
+        preview_n: int = 100,
+        final_n: int = 100,
+        seed: int = 42,
+        split: str = "validation",
+    ) -> tuple[EvaluationWindow, EvaluationWindow]:
+        texts = self.load(split=split)
+        if not texts:
+            raise _DataErr(
+                code="E306",
+                message=(
+                    "NO-SAMPLES: local_jsonl produced no samples; check file/path/data_files"
+                ),
+            )
+        preview_texts = texts[:preview_n]
+        final_texts = texts[preview_n : preview_n + final_n]
+        preview_window = self._simple_tokenize(
+            preview_texts, tokenizer, seq_len, list(range(preview_n))
+        )
+        final_window = self._simple_tokenize(
+            final_texts,
+            tokenizer,
+            seq_len,
+            list(range(preview_n, preview_n + final_n)),
+        )
+        return preview_window, final_window
+
+    def estimate_capacity(
+        self,
+        tokenizer: Any,
+        *,
+        seq_len: int,
+        stride: int,
+        split: str = "validation",
+        target_total: int | None = None,
+        fast_mode: bool = False,
+    ) -> dict[str, Any]:
+        texts = self.load(split=split)
+        # Texts are not tokenized here, so the token total is reported as 0;
+        # only example counts are known cheaply.
+        return {
+            "total_tokens": 0,
+            "available_nonoverlap": len(texts),
+            "available_unique": len(texts),
+            "dedupe_rate": 0.0,
+            "stride": stride,
+            "seq_len": seq_len,
+            "candidate_unique": len(texts),
+            "candidate_limit": len(texts),
+        }
+
+
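+# A minimal end-to-end sketch of LocalJSONLProvider (editor's illustration; the
+# toy whitespace tokenizer below is hypothetical and only approximates the HF
+# tokenizer interface that the provider expects):
+def _sketch_local_jsonl_usage() -> None:  # pragma: no cover - editor sketch
+    import tempfile
+
+    class _ToyTok:
+        pad_token_id = 0
+
+        def encode(self, text, truncation=True, max_length=128):
+            return [(abs(hash(w)) % 50_000) + 1 for w in text.split()][:max_length]
+
+    with tempfile.NamedTemporaryFile(
+        "w", suffix=".jsonl", delete=False, encoding="utf-8"
+    ) as fh:
+        for i in range(8):
+            fh.write(json.dumps({"text": f"sample number {i} for evaluation"}) + "\n")
+        tmp_path = fh.name
+
+    provider = LocalJSONLProvider(file=tmp_path, max_samples=8)
+    preview, final = provider.windows(_ToyTok(), seq_len=16, preview_n=4, final_n=4)
+    assert all(len(row) == 16 for row in preview.input_ids)
+
+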
+class LocalJSONLPairsProvider:
+    """Local JSONL pairs provider with source/target fields.
+
+    Accepts a single `file`, a `path` (file or directory), or `data_files`
+    (a glob or list of paths). Extracts paired strings from `src_field` and
+    `tgt_field`.
+    """
+
+    name = "local_jsonl_pairs"
+
+    def __init__(
+        self,
+        file: str | None = None,
+        path: str | None = None,
+        data_files: str | list[str] | None = None,
+        src_field: str = "source",
+        tgt_field: str = "target",
+        max_samples: int = 2000,
+    ) -> None:
+        self.file = file
+        self.path = path
+        self.data_files = data_files
+        self.src_field = src_field or "source"
+        self.tgt_field = tgt_field or "target"
+        self.max_samples = int(max_samples)
+        self.last_preview_labels: list[list[int]] | None = None
+        self.last_final_labels: list[list[int]] | None = None
+
+    def _resolve_files(self) -> list[Path]:
+        files: list[Path] = []
+        if isinstance(self.file, str) and self.file:
+            p = Path(self.file)
+            if p.exists() and p.is_file():
+                files.append(p)
+        if isinstance(self.path, str) and self.path:
+            p = Path(self.path)
+            if p.is_file():
+                files.append(p)
+            elif p.is_dir():
+                files.extend(sorted(p.glob("*.jsonl")))
+        if isinstance(self.data_files, str) and self.data_files:
+            from glob import glob as _glob
+
+            files.extend(Path(p) for p in _glob(self.data_files))
+        elif isinstance(self.data_files, list):
+            for item in self.data_files:
+                try:
+                    pp = Path(str(item))
+                    if pp.exists() and pp.is_file():
+                        files.append(pp)
+                except Exception:
+                    continue
+        # Deduplicate while preserving order
+        seen: set[str] = set()
+        uniq: list[Path] = []
+        for f in files:
+            fp = f.resolve().as_posix()
+            if fp not in seen:
+                seen.add(fp)
+                uniq.append(f)
+        return uniq
+
+    def _load_pairs(self) -> list[tuple[str, str]]:
+        pairs: list[tuple[str, str]] = []
+        count = 0
+        for fp in self._resolve_files():
+            try:
+                with fp.open("r", encoding="utf-8") as handle:
+                    for line in handle:
+                        line = line.strip()
+                        if not line:
+                            continue
+                        try:
+                            obj = json.loads(line)
+                        except Exception:
+                            continue
+                        src = obj.get(self.src_field)
+                        tgt = obj.get(self.tgt_field)
+                        if (
+                            isinstance(src, str)
+                            and src.strip()
+                            and isinstance(tgt, str)
+                            and tgt.strip()
+                        ):
+                            pairs.append((src, tgt))
+                            count += 1
+                            if count >= self.max_samples:
+                                return pairs
+            except Exception:
+                continue
+        return pairs
+
+    def windows(
+        self,
+        tokenizer: Any,
+        *,
+        seq_len: int = 128,
+        stride: int = 64,
+        preview_n: int = 100,
+        final_n: int = 100,
+        seed: int = 42,
+        split: str = "validation",
+    ) -> tuple[EvaluationWindow, EvaluationWindow]:
+        pairs = self._load_pairs()
+        if not pairs:
+            raise ValueError(
+                "local_jsonl_pairs produced no pairs; check src_field/tgt_field and files"
+            )
+        prev_pairs = pairs[:preview_n]
+        fin_pairs = pairs[preview_n : preview_n + final_n]
+
+        pad_id = getattr(tokenizer, "pad_token_id", 0)
+
+        def _tok_src(src: str) -> list[int]:
+            ids = (
+                tokenizer.encode(src, truncation=True, max_length=seq_len)
+                if hasattr(tokenizer, "encode")
+                else tokenizer(src, truncation=True, max_length=seq_len)["input_ids"]
+            )
+            return (ids + [pad_id] * (seq_len - len(ids)))[:seq_len]
+
+        def _tok_tgt(tgt: str) -> list[int]:
+            ids = (
+                tokenizer.encode(tgt, truncation=True, max_length=seq_len)
+                if hasattr(tokenizer, "encode")
+                else tokenizer(tgt, truncation=True, max_length=seq_len)["input_ids"]
+            )
+            # -100 marks ignored label positions (HF loss convention)
+            return (ids + [-100] * (seq_len - len(ids)))[:seq_len]
+
+        prev_ids = [_tok_src(s) for s, _ in prev_pairs]
+        fin_ids = [_tok_src(s) for s, _ in fin_pairs]
+        prev_masks = [[1 if t != pad_id else 0 for t in seq] for seq in prev_ids]
+        fin_masks = [[1 if t != pad_id else 0 for t in seq] for seq in fin_ids]
+        self.last_preview_labels = [_tok_tgt(t) for _, t in prev_pairs]
+        self.last_final_labels = [_tok_tgt(t) for _, t in fin_pairs]
+
+        preview_window = EvaluationWindow(
+            prev_ids, prev_masks, list(range(len(prev_ids)))
+        )
+        final_window = EvaluationWindow(
+            fin_ids, fin_masks, list(range(preview_n, preview_n + len(fin_ids)))
+        )
+        return preview_window, final_window
+
+    def estimate_capacity(
+        self,
+        tokenizer: Any,
+        *,
+        seq_len: int,
+        stride: int,
+        split: str = "validation",
+        target_total: int | None = None,
+        fast_mode: bool = False,
+    ) -> dict[str, Any]:
+        pairs = self._load_pairs()
+        n = len(pairs)
+        return {
+            "total_tokens": int(n * seq_len),
+            "available_nonoverlap": n,
+            "available_unique": n,
+            "dedupe_rate": 0.0,
+            "stride": stride,
+            "seq_len": seq_len,
+            "candidate_unique": n,
+            "candidate_limit": n,
+            "tokens_available": int(n * seq_len),
+            "examples_available": n,
+        }
+
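+
+# Editor's sketch of LocalJSONLPairsProvider: windows() returns source-side
+# windows and stashes -100-padded target labels on last_preview_labels /
+# last_final_labels for the runner to pick up. The toy tokenizer is
+# hypothetical, as in the sketch above.
+def _sketch_local_jsonl_pairs_usage() -> None:  # pragma: no cover - editor sketch
+    import tempfile
+
+    class _ToyTok:
+        pad_token_id = 0
+
+        def encode(self, text, truncation=True, max_length=128):
+            return [(abs(hash(w)) % 50_000) + 1 for w in text.split()][:max_length]
+
+    with tempfile.NamedTemporaryFile(
+        "w", suffix=".jsonl", delete=False, encoding="utf-8"
+    ) as fh:
+        for i in range(4):
+            fh.write(json.dumps({"source": f"src {i}", "target": f"tgt {i}"}) + "\n")
+        tmp_path = fh.name
+
+    provider = LocalJSONLPairsProvider(file=tmp_path)
+    prev, fin = provider.windows(_ToyTok(), seq_len=8, preview_n=2, final_n=2)
+    assert provider.last_preview_labels is not None
+    assert all(-100 in lbl for lbl in provider.last_preview_labels)
+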
+# (text-only helpers removed; LocalJSONLProvider implements text tokenization)
+
+
+class Seq2SeqDataProvider:
+    """Wraps the synthetic seq2seq generator to fit the DatasetProvider interface.
+
+    Bridges invarlock.eval.providers.seq2seq.Seq2SeqProvider to the windowing
+    protocol used by the CLI runner. Builds encoder input_ids from src_ids and
+    attention_masks from src_mask, and lets the runner derive labels.
+    """
+
+    name = "seq2seq"
+
+    def __init__(self, **kwargs: Any) -> None:
+        # Pass kwargs through to the underlying provider
+        # (n, src_len, tgt_len, pad_id, bos_id, eos_id)
+        from invarlock.eval.providers.seq2seq import Seq2SeqProvider as _S2S
+
+        self._inner = _S2S(**kwargs)
+        self.last_preview_labels: list[list[int]] | None = None
+        self.last_final_labels: list[list[int]] | None = None
+
+    def load(
+        self, split: str = "validation", **kwargs
+    ) -> list[str]:  # pragma: no cover - not used
+        return []
+
+    def windows(
+        self,
+        tokenizer: Any,
+        *,
+        seq_len: int = 128,
+        stride: int = 64,
+        preview_n: int = 100,
+        final_n: int = 100,
+        seed: int = 42,
+        split: str = "validation",
+    ) -> tuple[EvaluationWindow, EvaluationWindow]:
+        # Generate exactly preview_n + final_n examples deterministically
+        total = max(0, int(preview_n) + int(final_n))
+        if total <= 0:
+            total = 1
+        # Ensure the inner generator can produce at least `total` examples
+        try:
+            # Prefer reconfiguring 'n' if the attribute is present
+            if getattr(self._inner, "_n", 0) < total:
+                self._inner._n = int(total)
+        except Exception:
+            pass
+        # Build a single batch of size `total`
+        batches = list(self._inner.batches(seed=seed, batch_size=total))
+        if not batches:
+            raise ValueError("seq2seq provider produced no examples")
+        batch = batches[0]
+        # Extract source tokens/masks and target ids for labels
+        src_ids_list = [list(x) for x in batch.get("src_ids", [])][:total]
+        src_mask_list = [list(x) for x in batch.get("src_mask", [])][:total]
+        tgt_ids_list = [list(x) for x in batch.get("tgt_ids", [])][:total]
+        # Right-pad/truncate to seq_len for runner compatibility
+        pad_id = getattr(tokenizer, "pad_token_id", 0)
+
+        def _pad(seq: list[int]) -> list[int]:
+            if len(seq) < seq_len:
+                return (seq + [pad_id] * (seq_len - len(seq)))[:seq_len]
+            return seq[:seq_len]
+
+        input_ids = [_pad(s) for s in src_ids_list]
+        attention_masks = []
+        for i, s in enumerate(input_ids):
+            # Prefer src_mask if lengths align; otherwise infer from pad_id
+            if i < len(src_mask_list) and len(src_mask_list[i]) == len(src_ids_list[i]):
+                # Adjust mask length to seq_len
+                m = src_mask_list[i]
+                if len(m) < seq_len:
+                    m = m + [0] * (seq_len - len(m))
+                attention_masks.append([int(v) for v in m[:seq_len]])
+            else:
+                attention_masks.append([1 if t != pad_id else 0 for t in s])
+
+        # Split into preview/final windows
+        prev_ids = input_ids[:preview_n]
+        prev_mask = attention_masks[:preview_n]
+        fin_ids = input_ids[preview_n : preview_n + final_n]
+        fin_mask = attention_masks[preview_n : preview_n + final_n]
+
+        # Prepare label sequences (decoder targets) padded to seq_len with -100
+        def _pad_label(seq: list[int]) -> list[int]:
+            if len(seq) < seq_len:
+                return (seq + [-100] * (seq_len - len(seq)))[:seq_len]
+            return seq[:seq_len]
+
+        prev_labels = [_pad_label(s) for s in tgt_ids_list[:preview_n]]
+        fin_labels = [
+            _pad_label(s) for s in tgt_ids_list[preview_n : preview_n + final_n]
+        ]
+        # Save for the runner to attach
+        self.last_preview_labels = prev_labels
+        self.last_final_labels = fin_labels
+
+        preview_window = EvaluationWindow(prev_ids, prev_mask, list(range(preview_n)))
+        final_window = EvaluationWindow(
+            fin_ids, fin_mask, list(range(preview_n, preview_n + final_n))
+        )
+        return preview_window, final_window
+
+    def estimate_capacity(
+        self,
+        tokenizer: Any,
+        *,
+        seq_len: int,
+        stride: int,
+        split: str = "validation",
+        target_total: int | None = None,
+        fast_mode: bool = False,
+    ) -> dict[str, Any]:
+        # Deterministic, bounded synthetic examples; assumed large enough for
+        # CI/release smokes
+        n = int(target_total or 800)
+        return {
+            "total_tokens": int(n * seq_len),
+            "available_nonoverlap": n,
+            "available_unique": n,
+            "dedupe_rate": 0.0,
+            "stride": stride,
+            "seq_len": seq_len,
+            "candidate_unique": n,
+            "candidate_limit": n,
+            "tokens_available": int(n * seq_len),
+            "examples_available": n,
+        }
+
+    def info(self) -> dict[str, Any]:  # pragma: no cover - trivial
+        return {"name": self.name, "type": "dataset_provider", "dataset": "seq2seq"}
+
+
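+# Editor's sketch of Seq2SeqDataProvider: inputs are generated synthetically,
+# so the tokenizer only needs a pad_token_id. The kwargs below assume the
+# underlying Seq2SeqProvider accepts n/src_len/tgt_len, as the pass-through
+# comment in __init__ suggests.
+def _sketch_seq2seq_windows() -> None:  # pragma: no cover - editor sketch
+    class _PadOnly:
+        pad_token_id = 0
+
+    provider = Seq2SeqDataProvider(n=8, src_len=12, tgt_len=6)
+    prev, fin = provider.windows(_PadOnly(), seq_len=16, preview_n=4, final_n=4)
+    assert len(prev.input_ids) == 4 and all(len(r) == 16 for r in prev.input_ids)
+    assert provider.last_preview_labels is not None  # -100-padded decoder targets
+
+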
+# Registry for dataset providers
+_PROVIDERS: dict[str, type] = {
+    "wikitext2": WikiText2Provider,
+    "synthetic": SyntheticProvider,
+    "hf_text": HFTextProvider,
+    "local_jsonl": LocalJSONLProvider,
+    "seq2seq": Seq2SeqDataProvider,
+    "hf_seq2seq": HFSeq2SeqProvider,
+    "local_jsonl_pairs": LocalJSONLPairsProvider,
+}
+
+
+def get_provider(name: str, **kwargs) -> DatasetProvider:
+    """
+    Get a dataset provider by name.
+
+    Args:
+        name: Registered provider name (e.g. "wikitext2", "synthetic",
+            "hf_text", "local_jsonl", "seq2seq", "hf_seq2seq",
+            "local_jsonl_pairs")
+        **kwargs: Provider-specific initialization parameters
+
+    Returns:
+        Initialized dataset provider
+
+    Raises:
+        ValidationError(E308): If the provider name is not registered
+    """
+    if name not in _PROVIDERS:
+        available = ", ".join(_PROVIDERS.keys())
+        # Typed-only error for provider lookup
+        raise _ValErr(
+            code="E308",
+            message="PROVIDER-NOT-FOUND: unknown dataset provider",
+            details={"provider": name, "available": available},
+        )
+
+    provider_class = _PROVIDERS[name]
+    return provider_class(**kwargs)
+
+
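+# Editor's sketch: lookup by name, with the typed E308 error on unknown names.
+# Assumes SyntheticProvider is default-constructible and that _ValErr exposes
+# its `code` as an attribute (both hypothetical here).
+def _sketch_get_provider() -> None:  # pragma: no cover - editor sketch
+    provider = get_provider("synthetic")
+    assert provider.name == "synthetic"
+    try:
+        get_provider("does_not_exist")
+    except Exception as exc:
+        assert getattr(exc, "code", None) == "E308"
+
+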
+def list_providers() -> list[str]:
+    """List the available dataset provider names."""
+    return list(_PROVIDERS.keys())
+
+
+def compute_window_hash(window: EvaluationWindow, include_data: bool = False) -> str:
+    """
+    Compute a deterministic hash of an evaluation window.
+
+    Args:
+        window: EvaluationWindow to hash
+        include_data: Whether to include the actual token data in the hash
+
+    Returns:
+        Hex digest string of the window hash
+    """
+    hasher = hashlib.sha256()
+
+    # Always include structural information
+    hasher.update(str(len(window)).encode())
+    hasher.update(str(sorted(window.indices)).encode())
+
+    if include_data:
+        # Include actual token sequences for data-integrity checking
+        for input_ids, attention_mask in zip(
+            window.input_ids, window.attention_masks, strict=False
+        ):
+            hasher.update(str(input_ids).encode())
+            hasher.update(str(attention_mask).encode())
+
+    return hasher.hexdigest()
+
+
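+# Editor's sketch: without include_data the hash covers only window structure
+# (length and indices), so windows that differ only in token content collide;
+# include_data=True distinguishes them.
+def _sketch_window_hash() -> None:  # pragma: no cover - editor sketch
+    w1 = EvaluationWindow([[1, 2]], [[1, 1]], [0])
+    w2 = EvaluationWindow([[9, 9]], [[1, 1]], [0])
+    assert compute_window_hash(w1) == compute_window_hash(w2)  # structure-only
+    assert compute_window_hash(w1, include_data=True) != compute_window_hash(
+        w2, include_data=True
+    )
+
+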
+# Export public API
+__all__ = [
+    "DatasetProvider",
+    "EvaluationWindow",
+    "WikiText2Provider",
+    "SyntheticProvider",
+    "HFTextProvider",
+    "LocalJSONLProvider",
+    "get_provider",
+    "list_providers",
+    "compute_window_hash",
+]