invarlock 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. invarlock/__init__.py +2 -2
  2. invarlock/_data/runtime/tiers.yaml +57 -30
  3. invarlock/adapters/__init__.py +11 -15
  4. invarlock/adapters/auto.py +35 -40
  5. invarlock/adapters/capabilities.py +2 -2
  6. invarlock/adapters/hf_causal.py +418 -0
  7. invarlock/adapters/{hf_onnx.py → hf_causal_onnx.py} +3 -3
  8. invarlock/adapters/hf_mixin.py +25 -4
  9. invarlock/adapters/{hf_bert.py → hf_mlm.py} +4 -11
  10. invarlock/adapters/{hf_t5.py → hf_seq2seq.py} +9 -9
  11. invarlock/calibration/spectral_null.py +15 -10
  12. invarlock/calibration/variance_ve.py +0 -2
  13. invarlock/cli/adapter_auto.py +31 -21
  14. invarlock/cli/app.py +73 -2
  15. invarlock/cli/commands/calibrate.py +6 -2
  16. invarlock/cli/commands/certify.py +651 -91
  17. invarlock/cli/commands/doctor.py +11 -11
  18. invarlock/cli/commands/explain_gates.py +57 -8
  19. invarlock/cli/commands/plugins.py +13 -9
  20. invarlock/cli/commands/report.py +233 -69
  21. invarlock/cli/commands/run.py +1066 -244
  22. invarlock/cli/commands/verify.py +154 -15
  23. invarlock/cli/config.py +22 -6
  24. invarlock/cli/doctor_helpers.py +4 -5
  25. invarlock/cli/output.py +193 -0
  26. invarlock/cli/provenance.py +1 -1
  27. invarlock/core/api.py +45 -5
  28. invarlock/core/auto_tuning.py +65 -20
  29. invarlock/core/bootstrap.py +1 -1
  30. invarlock/core/contracts.py +7 -1
  31. invarlock/core/registry.py +11 -13
  32. invarlock/core/runner.py +425 -75
  33. invarlock/edits/quant_rtn.py +65 -37
  34. invarlock/eval/bench.py +3 -16
  35. invarlock/eval/data.py +82 -51
  36. invarlock/eval/metrics.py +63 -2
  37. invarlock/eval/primary_metric.py +23 -0
  38. invarlock/eval/tail_stats.py +230 -0
  39. invarlock/eval/tasks/__init__.py +12 -0
  40. invarlock/eval/tasks/classification.py +48 -0
  41. invarlock/eval/tasks/qa.py +36 -0
  42. invarlock/eval/tasks/text_generation.py +102 -0
  43. invarlock/guards/_estimators.py +154 -0
  44. invarlock/guards/invariants.py +19 -10
  45. invarlock/guards/policies.py +16 -6
  46. invarlock/guards/rmt.py +627 -546
  47. invarlock/guards/spectral.py +348 -110
  48. invarlock/guards/tier_config.py +32 -30
  49. invarlock/guards/variance.py +7 -31
  50. invarlock/guards_ref/rmt_ref.py +23 -23
  51. invarlock/model_profile.py +90 -42
  52. invarlock/observability/health.py +6 -6
  53. invarlock/observability/metrics.py +108 -0
  54. invarlock/reporting/certificate.py +384 -55
  55. invarlock/reporting/certificate_schema.py +3 -2
  56. invarlock/reporting/dataset_hashing.py +15 -2
  57. invarlock/reporting/guards_analysis.py +350 -277
  58. invarlock/reporting/html.py +55 -5
  59. invarlock/reporting/normalizer.py +13 -0
  60. invarlock/reporting/policy_utils.py +38 -36
  61. invarlock/reporting/primary_metric_utils.py +71 -17
  62. invarlock/reporting/render.py +852 -431
  63. invarlock/reporting/report.py +40 -4
  64. invarlock/reporting/report_types.py +11 -3
  65. invarlock/reporting/telemetry.py +86 -0
  66. invarlock/reporting/validate.py +1 -18
  67. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/METADATA +27 -13
  68. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/RECORD +72 -65
  69. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/WHEEL +1 -1
  70. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/entry_points.txt +5 -3
  71. invarlock/adapters/hf_gpt2.py +0 -404
  72. invarlock/adapters/hf_llama.py +0 -487
  73. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/licenses/LICENSE +0 -0
  74. {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/top_level.txt +0 -0
@@ -86,6 +86,31 @@ class RTNQuantEdit(ModelEdit):
 
         # group_size is currently reserved for potential future variants; it is
         # ignored for the built-in INT8 demo edit.
+        self._emit_enabled = True
+        self._emit_console = None
+        self._output_style = None
+
+    def _configure_output(self, **kwargs: Any) -> None:
+        emit = kwargs.get("emit", True)
+        self._emit_enabled = bool(emit)
+        console = kwargs.get("console")
+        if console is not None and hasattr(console, "print"):
+            self._emit_console = console
+        else:
+            self._emit_console = None
+        self._output_style = kwargs.get("output_style")
+
+    def _emit(self, message: str) -> None:
+        if not self._emit_enabled:
+            return
+        line = f"[EDIT] {message}".rstrip()
+        if self._emit_console is not None:
+            try:
+                self._emit_console.print(line, markup=False)
+            except TypeError:
+                self._emit_console.print(line)
+        else:
+            print(line)
 
     def can_edit(self, model_desc: dict[str, Any]) -> bool:
         """Check if RTN quantization can be applied to this model."""
@@ -233,15 +258,18 @@ class RTNQuantEdit(ModelEdit):
         scope = kwargs.get("scope", self.scope)
         seed = kwargs.get("seed", self.seed)
 
+        self._configure_output(**kwargs)
+
         # Diagnostic reporting
-        print("🔧 RTN Quantization Configuration:")
-        print(
-            f" Bitwidth: {bitwidth} (from config: {kwargs.get('bitwidth', kwargs.get('bits', 'default'))})"
+        self._emit("RTN Quantization Configuration:")
+        self._emit(
+            "Bitwidth: "
+            f"{bitwidth} (from config: {kwargs.get('bitwidth', kwargs.get('bits', 'default'))})"
         )
-        print(f" Scope: {scope}")
-        print(f" Group size: {group_size}")
-        print(f" Clamp ratio: {clamp_ratio}")
-        print(f" Seed: {seed}")
+        self._emit(f"Scope: {scope}")
+        self._emit(f"Group size: {group_size}")
+        self._emit(f"Clamp ratio: {clamp_ratio}")
+        self._emit(f"Seed: {seed}")
 
         # Persist configuration overrides for downstream helpers
         self.bitwidth = bitwidth
@@ -256,22 +284,22 @@ class RTNQuantEdit(ModelEdit):
         np.random.seed(seed)
 
         # Identify target modules and get weight tying map
-        print(f"🎯 Identifying target modules for scope '{scope}'...")
+        self._emit(f"Identifying target modules for scope '{scope}'...")
         target_modules = self._identify_target_modules(model)
         total_identified = len(target_modules)
 
         max_modules = kwargs.get("max_modules")
         if isinstance(max_modules, int) and max_modules > 0:
             if max_modules < total_identified:
-                print(
-                    f" Limiting quantization to first {max_modules} modules "
+                self._emit(
+                    f"Limiting quantization to first {max_modules} modules "
                     f"(of {total_identified}) based on plan.max_modules"
                 )
                 target_modules = target_modules[:max_modules]
                 self.max_modules = max_modules
             else:
-                print(
-                    f" max_modules={max_modules} >= available modules "
+                self._emit(
+                    f"max_modules={max_modules} >= available modules "
                     f"({total_identified}); using all targets"
                 )
                 self.max_modules = None
@@ -280,33 +308,35 @@ class RTNQuantEdit(ModelEdit):
 
         tying_map = self._get_weight_tying_map(model)
 
-        print(f" Found {len(target_modules)} target modules:")
+        self._emit(f"Found {len(target_modules)} target modules:")
         for i, (name, module) in enumerate(target_modules):
             weight_shape = module.weight.shape
             param_count = module.weight.numel()
-            print(f" [{i + 1}] {name}: {weight_shape} ({param_count:,} params)")
+            self._emit(f"[{i + 1}] {name}: {weight_shape} ({param_count:,} params)")
 
         if len(target_modules) == 0:
-            print("❌ WARNING: No target modules found! Check scope configuration.")
-            print(" Available linear modules:")
+            self._emit(
+                "WARNING: No target modules found! Check scope configuration."
+            )
+            self._emit("Available linear modules:")
             linear_modules = []
             for name, module in model.named_modules():
                 if isinstance(module, nn.Linear | nn.Conv1d):
                     linear_modules.append((name, module.weight.shape))
             for name, shape in linear_modules[:10]:  # Show first 10
-                print(f" {name}: {shape}")
+                self._emit(f"{name}: {shape}")
             if len(linear_modules) > 10:
-                print(f" ... and {len(linear_modules) - 10} more")
+                self._emit(f"... and {len(linear_modules) - 10} more")
 
         # Execute GuardChain before edit (if provided)
         guard_results = {}
         if self.guard_chain is not None:
-            print(" Executing guard chain preparation...")
+            self._emit("Executing guard chain preparation...")
            guard_results["prepare"] = self.guard_chain.prepare_all(
                 model, adapter, None, {}
             )
 
-            print(" Executing before-edit guards...")
+            self._emit("Executing before-edit guards...")
             self.guard_chain.before_edit_all(model)
 
         # Apply quantization to each target module
@@ -314,12 +344,12 @@ class RTNQuantEdit(ModelEdit):
         total_params_quantized = 0
 
         for i, (module_name, module) in enumerate(target_modules):
-            print(f" [{i + 1}/{len(target_modules)}] Quantizing: {module_name}")
-            print(
-                f" Shape: {module.weight.shape}, Params: {module.weight.numel():,}"
+            self._emit(f"[{i + 1}/{len(target_modules)}] Quantizing: {module_name}")
+            self._emit(
+                f"Shape: {module.weight.shape}, Params: {module.weight.numel():,}"
             )
-            print(
-                f" Weight range: [{module.weight.min():.4f}, {module.weight.max():.4f}]"
+            self._emit(
+                f"Weight range: [{module.weight.min():.4f}, {module.weight.max():.4f}]"
             )
 
             # Apply RTN quantization
@@ -335,24 +365,22 @@ class RTNQuantEdit(ModelEdit):
             quantization_results.append(quant_result)
             total_params_quantized += quant_result["params_quantized"]
 
-            print(
-                f" ✓ Quantized {quant_result['params_quantized']:,} parameters"
-            )
+            self._emit(f"Quantized {quant_result['params_quantized']:,} parameters")
 
         # Execute GuardChain after edit (if provided)
         if self.guard_chain is not None:
-            print(" Executing after-edit guards...")
+            self._emit("Executing after-edit guards...")
             self.guard_chain.after_edit_all(model)
 
-            print(" Finalizing guard chain...")
+            self._emit("Finalizing guard chain...")
             guard_results["finalize"] = self.guard_chain.finalize_all(model)
 
             # Check if all guards passed
             if not self.guard_chain.all_passed(guard_results["finalize"]):
-                print(" ⚠️ Guard chain validation failed!")
+                self._emit("Guard chain validation failed!")
                 guard_results["all_passed"] = False
             else:
-                print("All guards passed")
+                self._emit("All guards passed")
                 guard_results["all_passed"] = True
 
         # Create bitwidth map
@@ -490,11 +518,11 @@ class RTNQuantEdit(ModelEdit):
 
         # Log diagnostic information
         if skipped_modules:
-            print(f" Skipped {len(skipped_modules)} modules:")
+            self._emit(f"Skipped {len(skipped_modules)} modules:")
             for name, reason in skipped_modules[:5]:  # Show first 5
-                print(f" {name}: {reason}")
+                self._emit(f"{name}: {reason}")
             if len(skipped_modules) > 5:
-                print(f" ... and {len(skipped_modules) - 5} more")
+                self._emit(f"... and {len(skipped_modules) - 5} more")
 
         return target_modules
 
@@ -625,7 +653,7 @@ class RTNQuantEdit(ModelEdit):
         # Ensure actual quantization occurred by applying quantization loss
         # This guarantees the weights are actually modified
        quantization_error = (quantized_weight - original_weight).abs().mean()
-        print(f" Quantization error: {quantization_error:.6f}")
+        self._emit(f"Quantization error: {quantization_error:.6f}")
 
         # Write back to module (preserving tying if needed)
         module.weight.data.copy_(quantized_weight)
@@ -634,7 +662,7 @@ class RTNQuantEdit(ModelEdit):
         final_weight = module.weight.data
         actual_change = not torch.allclose(original_weight, final_weight, atol=1e-6)
         if not actual_change:
-            print(f" WARNING: No actual weight change detected for {module}")
+            self._emit(f"WARNING: No actual weight change detected for {module}")
 
         # Handle tied weights
         if tied_modules:
invarlock/eval/bench.py CHANGED
@@ -47,7 +47,7 @@ class ScenarioConfig:
     probes: int
     profile: str = "ci"  # "ci" or "release"
     model_id: str = "gpt2"
-    adapter: str = "hf_gpt2"
+    adapter: str = "hf_causal"
     device: str = "auto"
     seq_len: int = 512
     stride: int = 128
@@ -81,7 +81,7 @@ class BenchmarkConfig:
     profile: str = "ci"  # "ci" or "release"
     dataset: str = "wikitext2"
     model_id: str = "gpt2"
-    adapter: str = "hf_gpt2"
+    adapter: str = "hf_causal"
    device: str = "auto"
     seq_len: int = 512
     stride: int = 128
@@ -92,7 +92,6 @@
     epsilon: float | None = (
         None  # RMT deadband tolerance (None = use resolved deadband)
     )
-    strict: bool = False  # If True, sets epsilon = 0
     ppl_overhead_threshold: float = 0.01  # 1%
     guard_overhead_time_threshold: float = 0.15  # 15%
     guard_overhead_mem_threshold: float = 0.10  # 10%
@@ -104,10 +103,6 @@
         """Apply post-initialization logic."""
         self.output_dir = Path(self.output_dir)
 
-        # Handle strict mode
-        if self.strict:
-            self.epsilon = 0.0
-
 
 @dataclass
 class ScenarioResult:
@@ -1043,7 +1038,6 @@
     profile: str = "ci",
    output_dir: str | Path = "benchmarks",
     epsilon: float | None = None,
-    strict: bool = False,
     **kwargs,
 ) -> dict[str, Any]:
     """
@@ -1056,7 +1050,6 @@
         profile: "ci" (50/50 windows) or "release" (100/100 windows)
         output_dir: Directory to save results
         epsilon: Optional epsilon override
-        strict: If True, sets epsilon = 0
         **kwargs: Additional configuration options
 
     Returns:
@@ -1075,7 +1068,6 @@
         profile=profile,
         output_dir=Path(output_dir),
         epsilon=epsilon,
-        strict=strict,
         **kwargs,
     )
 
@@ -1384,7 +1376,6 @@ def _config_to_dict(config: BenchmarkConfig) -> dict[str, Any]:
         "stride": config.stride,
         "seed": config.seed,
         "epsilon": config.epsilon,
-        "strict": config.strict,
         "ppl_overhead_threshold": config.ppl_overhead_threshold,
         "guard_overhead_time_threshold": config.guard_overhead_time_threshold,
         "guard_overhead_mem_threshold": config.guard_overhead_mem_threshold,
@@ -1426,16 +1417,13 @@ def main():
         type=float,
         help="RMT outliers epsilon threshold (default: use resolved RMT deadband)",
     )
-    parser.add_argument(
-        "--strict", action="store_true", help="Set epsilon=0 (overrides --epsilon)"
-    )
 
     # Model and dataset configuration
     parser.add_argument(
         "--dataset", default="wikitext2", help="Dataset to use for benchmarking"
     )
     parser.add_argument("--model-id", default="gpt2", help="Model identifier")
-    parser.add_argument("--adapter", default="hf_gpt2", help="Model adapter to use")
+    parser.add_argument("--adapter", default="hf_causal", help="Model adapter to use")
     parser.add_argument(
         "--device", default="auto", help="Device to use (auto|cuda|mps|cpu)"
     )
@@ -1505,7 +1493,6 @@
         profile=args.profile,
         output_dir=args.out,
         epsilon=args.epsilon,
-        strict=args.strict,
         **kwargs,
     )
 
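With `strict` removed from `BenchmarkConfig`, `run_guard_effect_benchmark`, and the `--strict` flag, callers that relied on it can pass the value it implied, `epsilon=0.0`, directly; the default adapter name also changes from `hf_gpt2` to `hf_causal`. A hedged sketch of the equivalent call (passing `adapter` through `**kwargs` is an assumption; only `profile`, `output_dir`, and `epsilon` are named parameters in this diff):

from invarlock.eval.bench import run_guard_effect_benchmark

# Previously: run_guard_effect_benchmark(profile="ci", strict=True)
results = run_guard_effect_benchmark(
    profile="ci",
    output_dir="benchmarks",
    epsilon=0.0,          # what strict=True used to force
    adapter="hf_causal",  # assumed to reach BenchmarkConfig via **kwargs
)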
invarlock/eval/data.py CHANGED
@@ -15,7 +15,7 @@ import time
 import warnings
 from abc import abstractmethod
 from collections import Counter
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from pathlib import Path
 from typing import Any, NamedTuple, Protocol
 
@@ -56,6 +56,9 @@ except ImportError:
     HAS_TORCH = False
 
 
+EventEmitter = Callable[[str, str, str | None], None]
+
+
 class EvaluationWindow(NamedTuple):
     """A window of tokenized samples for evaluation."""
 
@@ -166,6 +169,7 @@ class WikiText2Provider:
         self,
         cache_dir: Path | None = None,
         device_hint: str | None = None,
+        emit: EventEmitter | None = None,
         **_: Any,
     ):
         """
@@ -175,6 +179,7 @@
             cache_dir: Optional cache directory for dataset storage
         """
         self.cache_dir = cache_dir
+        self._emit_event = emit
         self._validate_dependencies()
         self._last_stratification_stats: dict[str, Any] | None = None
         self._last_batch_size_used: int = 0
@@ -186,11 +191,23 @@
         normalized_hint = (device_hint or "").strip().lower()
         self._device_hint: str | None = normalized_hint or None
 
+    def _event(self, tag: str, message: str, *, emoji: str | None = None) -> None:
+        """Emit a dataset event via an optional CLI-provided sink."""
+        if self._emit_event is None:
+            if emoji:
+                print(f"{emoji} {message}")
+            else:
+                print(message)
+            return
+        try:
+            self._emit_event(tag, message, emoji)
+        except TypeError:
+            # Back-compat: tolerate sinks that only accept (tag, message).
+            self._emit_event(tag, message)  # type: ignore[misc]
+
     def _validate_dependencies(self) -> None:
         """Check that required dependencies are available."""
         if not HAS_DATASETS:
-            if _LIGHT_IMPORT:
-                return
             raise _DepErr(
                 code="E301",
                 message=(
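`WikiText2Provider` now logs through an optional `EventEmitter` sink (`Callable[[str, str, str | None], None]`) instead of hard-coded `print` calls, falling back to the old emoji-prefixed console output when no sink is given; two-argument `(tag, message)` sinks are tolerated via the `TypeError` fallback. A minimal conforming sink, wired through the `get_provider(..., emit=...)` change shown further down in this file:

import logging

from invarlock.eval.data import get_provider

log = logging.getLogger("invarlock.data")

def log_sink(tag: str, message: str, emoji: str | None = None) -> None:
    # Drop the emoji decoration when routing to structured logging.
    log.info("[%s] %s", tag, message)

provider = get_provider("wikitext2", emit=log_sink)  # only the wikitext2 provider accepts emit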
@@ -321,20 +338,17 @@
         Returns:
             List of filtered text strings
         """
-        print(f"📚 Loading WikiText-2 {split} split...")
+        self._event(
+            "DATA",
+            f"WikiText-2 {split}: loading split...",
+            emoji="📚",
+        )
 
         # Serve from cache when possible (load the largest slice once)
         cached = self._texts_cache.get(split)
         if cached is not None and len(cached) >= max_samples:
             return cached[:max_samples]
 
-        if not HAS_DATASETS and _LIGHT_IMPORT:
-            texts = ["hello world", "invarlock synthetic text"] * max(
-                1, max_samples // 2
-            )
-            self._texts_cache[split] = texts
-            return texts[:max_samples]
-
         # Load dataset with size limit for efficiency
         dataset_slice = f"{split}[:{max_samples}]" if max_samples > 0 else split
         dataset = load_dataset(
@@ -375,7 +389,10 @@
         if prev is None or len(valid_texts) > len(prev):
             self._texts_cache[split] = list(valid_texts)
 
-        print(f" ✓ Loaded {len(valid_texts)} valid samples from {len(dataset)} total")
+        self._event(
+            "DATA",
+            f"Loaded {len(valid_texts)}/{len(dataset)} valid samples",
+        )
         return valid_texts
 
     def windows(
@@ -444,9 +461,13 @@
         cursor = 0
         chunk_size = max(64, min(256, target_pool))
 
-        print(" 📊 Creating evaluation windows:")
-        print(f" Requested preview/final: {preview_n}/{final_n}")
-        print(f" Sampling pool target: {target_pool} (reserve {reserve})")
+        self._event(
+            "DATA",
+            "Creating evaluation windows:",
+            emoji="📊",
+        )
+        self._event("DATA", f"Requested preview/final: {preview_n}/{final_n}")
+        self._event("DATA", f"Sampling pool target: {target_pool} (reserve {reserve})")
 
         while len(candidates) < total_required + reserve and cursor < len(
             shuffled_indices
@@ -717,9 +738,9 @@
             ),
         }
 
-        print(f" Seed: {seed}, Seq length: {seq_len}")
-        print(f" Preview: {len(preview_window)} samples")
-        print(f" Final: {len(final_window)} samples")
+        self._event("DATA", f"Seed: {seed}, Seq length: {seq_len}")
+        self._event("DATA", f"Preview: {len(preview_window)} samples")
+        self._event("DATA", f"Final: {len(final_window)} samples")
 
         return preview_window, final_window
 
@@ -849,8 +870,9 @@
         attention_masks_list = [entry[2] for entry in collected]
         valid_indices = [entry[0] for entry in collected]
 
-        print(
-            f" ✓ {window_name}: {len(valid_indices)}/{len(indices)} samples tokenized successfully"
+        self._event(
+            "DATA",
+            f"{window_name}: {len(valid_indices)}/{len(indices)} samples tokenized",
         )
 
         return EvaluationWindow(
@@ -943,7 +965,8 @@
         self, split: str = "validation", max_samples: int = 500, **kwargs
     ) -> list[str]:
         """Generate synthetic text samples."""
-        # Expand base samples to meet requirement
+        # Expand base samples to meet requirement, preferring unique variations
+        # to avoid duplicate-token windows (important for stratified pairing).
         expanded_samples: list[str] = []
         variations = [
             lambda s: s,
@@ -953,18 +976,25 @@
             lambda s: f"Furthermore, {s.lower()}",
             lambda s: f"In addition, {s.lower()}",
         ]
-
-        # Use a deterministic approach based on max_samples
-        rng = np.random.RandomState(42)  # Fixed seed for reproducibility
-
-        while len(expanded_samples) < max_samples:
+        # Deterministic coverage of (variation × base sample) combinations first.
+        for variation in variations:
             for base_text in self.base_samples:
-                if len(expanded_samples) >= max_samples:
-                    break
-                variation = rng.choice(variations)
                 expanded_samples.append(variation(base_text))
+                if len(expanded_samples) >= max_samples:
+                    return expanded_samples
+
+        # If callers request more than the unique combination space, keep
+        # extending deterministically while ensuring uniqueness via a suffix.
+        idx = 0
+        while len(expanded_samples) < max_samples:
+            base_text = self.base_samples[idx % len(self.base_samples)]
+            variation = variations[(idx // len(self.base_samples)) % len(variations)]
+            expanded_samples.append(
+                f"{variation(base_text)} [synthetic #{len(expanded_samples)}]"
+            )
+            idx += 1
 
-        return expanded_samples[:max_samples]
+        return expanded_samples
 
     def windows(
         self,
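`SyntheticProvider.load` no longer draws variations from a fixed-seed RNG; it walks the variation × base-sample grid in order and, past that space, appends an index suffix so texts stay unique. A small sketch of the resulting behaviour, assuming the synthetic provider needs no constructor arguments:

from invarlock.eval.data import get_provider

provider = get_provider("synthetic")
a = provider.load(max_samples=50)
b = provider.load(max_samples=50)

assert a == b  # fully deterministic: no RNG is involved any more
# Texts are kept distinct; once the variation × base-sample space is exhausted,
# each extra sample carries a "[synthetic #N]" suffix.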
@@ -1062,14 +1092,13 @@ class HFTextProvider:
         max_samples: int = 2000,
     ):
         if not HAS_DATASETS:
-            if not _LIGHT_IMPORT:
-                raise _DepErr(
-                    code="E301",
-                    message=(
-                        "DEPENDENCY-MISSING: datasets library required for hf_text provider"
-                    ),
-                    details={"dependency": "datasets"},
-                )
+            raise _DepErr(
+                code="E301",
+                message=(
+                    "DEPENDENCY-MISSING: datasets library required for hf_text provider"
+                ),
+                details={"dependency": "datasets"},
+            )
         self.dataset_name = dataset_name or "wikitext"
         self.config_name = config_name or None
         self.text_field = text_field
@@ -1077,9 +1106,6 @@
         self.max_samples = int(max_samples)
 
     def load(self, split: str = "validation", **kwargs) -> list[str]:
-        if not HAS_DATASETS and _LIGHT_IMPORT:
-            return ["synthetic dataset text"] * int(self.max_samples or 1)
-
         ds = load_dataset(
             path=self.dataset_name,
             name=self.config_name,
@@ -1204,14 +1230,13 @@
         max_samples: int = 2000,
     ) -> None:
         if not HAS_DATASETS:
-            if not _LIGHT_IMPORT:
-                raise _DepErr(
-                    code="E301",
-                    message=(
-                        "DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
-                    ),
-                    details={"dependency": "datasets"},
-                )
+            raise _DepErr(
+                code="E301",
+                message=(
+                    "DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
+                ),
+                details={"dependency": "datasets"},
+            )
         self.dataset_name = dataset_name
         self.config_name = config_name
         self.src_field = src_field
@@ -1815,12 +1840,15 @@ _PROVIDERS: dict[str, type] = {
 }
 
 
-def get_provider(name: str, **kwargs) -> DatasetProvider:
+def get_provider(
+    name: str, *, emit: EventEmitter | None = None, **kwargs: Any
+) -> DatasetProvider:
     """
     Get a dataset provider by name.
 
     Args:
         name: Provider name ("wikitext2", "synthetic")
+        emit: Optional event sink for dataset/provider logs.
         **kwargs: Provider-specific initialization parameters
 
     Returns:
@@ -1839,7 +1867,10 @@ def get_provider(name: str, **kwargs) -> DatasetProvider:
         )
 
     provider_class = _PROVIDERS[name]
-    return provider_class(**kwargs)
+    init_kwargs = dict(kwargs)
+    if emit is not None and name == "wikitext2":
+        init_kwargs["emit"] = emit
+    return provider_class(**init_kwargs)
 
 
 def list_providers() -> list[str]:
invarlock/eval/metrics.py CHANGED
@@ -18,9 +18,10 @@ import gc
 import logging
 import math
 import time
+from collections.abc import Iterable
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any
+from typing import Any, Protocol
 
 import numpy as np
 import psutil
@@ -723,7 +724,10 @@ def calculate_lens_metrics_for_model(
     except Exception as e:
         logger.error(f"Metrics calculation failed: {e}")
         if config.strict_validation:
-            raise MetricsError(f"Metrics calculation failed: {e}") from e
+            raise MetricsError(
+                code="E401",
+                message=f"METRICS-COMPUTE-FAILED: {e}",
+            ) from e
 
     finally:
         resource_manager.cleanup()
@@ -2266,6 +2270,57 @@
         return {"error": str(e)}
 
 
+class Metric(Protocol):
+    name: str
+    kind: str  # "ppl", "accuracy", "exact_match", "bleu", "rouge"
+
+    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float: ...
+
+
+class PerplexityMetric:
+    """Lightweight perplexity metric from per-record logloss + token counts."""
+
+    name = "perplexity"
+    kind = "ppl"
+
+    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float:  # noqa: ARG002
+        total_loss = 0.0
+        total_tokens = 0.0
+        for record in dataset:
+            if not isinstance(record, dict):
+                continue
+            loss = record.get("logloss", record.get("loss"))
+            tokens = record.get("token_count", record.get("tokens", 1))
+            try:
+                loss_val = float(loss)
+                tok_val = float(tokens)
+            except Exception:
+                continue
+            if (
+                not math.isfinite(loss_val)
+                or not math.isfinite(tok_val)
+                or tok_val <= 0
+            ):
+                continue
+            total_loss += loss_val * tok_val
+            total_tokens += tok_val
+        if total_tokens <= 0:
+            return float("nan")
+        return float(math.exp(total_loss / total_tokens))
+
+
+class AccuracyMetric:
+    """Classification accuracy metric from label/prediction records."""
+
+    name = "accuracy"
+    kind = "accuracy"
+
+    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float:  # noqa: ARG002
+        from invarlock.eval.tasks.classification import accuracy_from_records
+
+        return accuracy_from_records(dataset)
+
+
 # ── Integration with existing system ───────────────────────────────────────
 
 # Update exports to include new functions (add to existing __all__ if it exists)
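The new `Metric` protocol and its two reference implementations work on iterables of per-record dicts rather than re-running the model; `PerplexityMetric` computes a token-weighted mean logloss and exponentiates it, skipping malformed or non-finite records. A worked sketch with illustrative numbers:

from invarlock.eval.metrics import PerplexityMetric

records = [
    {"logloss": 2.0, "token_count": 100},
    {"logloss": 2.5, "token_count": 50},
    {"loss": 1.8, "tokens": 25},                    # alternate key names accepted
    {"logloss": float("nan"), "token_count": 10},   # non-finite entries skipped
]
ppl = PerplexityMetric().compute(model=None, dataset=records)
# (2.0*100 + 2.5*50 + 1.8*25) / 175 ≈ 2.114 mean logloss → ppl = exp(2.114) ≈ 8.28

# AccuracyMetric().compute(...) simply delegates to
# invarlock.eval.tasks.classification.accuracy_from_records; the expected record
# keys are defined by that helper (not shown in this diff).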
@@ -2279,6 +2334,9 @@ try:
             "compute_parameter_deltas",
             "analyze_spectral_changes",
             "analyze_rmt_changes",
+            "Metric",
+            "PerplexityMetric",
+            "AccuracyMetric",
         ]
     )
 except NameError:
@@ -2291,4 +2349,7 @@ except NameError:
         "compute_parameter_deltas",
         "analyze_spectral_changes",
         "analyze_rmt_changes",
+        "Metric",
+        "PerplexityMetric",
+        "AccuracyMetric",
     ]