invarlock 0.3.6__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. invarlock/__init__.py +2 -2
  2. invarlock/adapters/__init__.py +10 -14
  3. invarlock/adapters/auto.py +35 -40
  4. invarlock/adapters/capabilities.py +2 -2
  5. invarlock/adapters/hf_causal.py +418 -0
  6. invarlock/adapters/{hf_onnx.py → hf_causal_onnx.py} +3 -3
  7. invarlock/adapters/hf_mixin.py +25 -4
  8. invarlock/adapters/{hf_bert.py → hf_mlm.py} +4 -11
  9. invarlock/adapters/{hf_t5.py → hf_seq2seq.py} +9 -9
  10. invarlock/cli/adapter_auto.py +31 -21
  11. invarlock/cli/app.py +73 -2
  12. invarlock/cli/commands/certify.py +600 -59
  13. invarlock/cli/commands/doctor.py +8 -10
  14. invarlock/cli/commands/plugins.py +13 -9
  15. invarlock/cli/commands/report.py +233 -69
  16. invarlock/cli/commands/run.py +907 -183
  17. invarlock/cli/commands/verify.py +76 -11
  18. invarlock/cli/config.py +1 -1
  19. invarlock/cli/doctor_helpers.py +4 -5
  20. invarlock/cli/output.py +193 -0
  21. invarlock/cli/provenance.py +1 -1
  22. invarlock/core/bootstrap.py +1 -1
  23. invarlock/core/registry.py +9 -11
  24. invarlock/core/runner.py +111 -25
  25. invarlock/edits/quant_rtn.py +65 -37
  26. invarlock/eval/bench.py +3 -3
  27. invarlock/eval/data.py +68 -23
  28. invarlock/eval/metrics.py +59 -1
  29. invarlock/eval/tasks/__init__.py +12 -0
  30. invarlock/eval/tasks/classification.py +48 -0
  31. invarlock/eval/tasks/qa.py +36 -0
  32. invarlock/eval/tasks/text_generation.py +102 -0
  33. invarlock/guards/invariants.py +19 -10
  34. invarlock/guards/rmt.py +2 -2
  35. invarlock/guards/variance.py +2 -2
  36. invarlock/model_profile.py +48 -27
  37. invarlock/observability/health.py +6 -6
  38. invarlock/observability/metrics.py +108 -0
  39. invarlock/reporting/certificate.py +159 -9
  40. invarlock/reporting/certificate_schema.py +1 -1
  41. invarlock/reporting/guards_analysis.py +154 -4
  42. invarlock/reporting/html.py +55 -5
  43. invarlock/reporting/normalizer.py +7 -0
  44. invarlock/reporting/render.py +791 -431
  45. invarlock/reporting/report.py +39 -3
  46. invarlock/reporting/report_types.py +6 -1
  47. invarlock/reporting/telemetry.py +86 -0
  48. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/METADATA +23 -9
  49. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/RECORD +53 -48
  50. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/WHEEL +1 -1
  51. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/entry_points.txt +5 -3
  52. invarlock/adapters/hf_gpt2.py +0 -404
  53. invarlock/adapters/hf_llama.py +0 -487
  54. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/licenses/LICENSE +0 -0
  55. {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/top_level.txt +0 -0
invarlock/edits/quant_rtn.py CHANGED
@@ -86,6 +86,31 @@ class RTNQuantEdit(ModelEdit):
 
         # group_size is currently reserved for potential future variants; it is
         # ignored for the built-in INT8 demo edit.
+        self._emit_enabled = True
+        self._emit_console = None
+        self._output_style = None
+
+    def _configure_output(self, **kwargs: Any) -> None:
+        emit = kwargs.get("emit", True)
+        self._emit_enabled = bool(emit)
+        console = kwargs.get("console")
+        if console is not None and hasattr(console, "print"):
+            self._emit_console = console
+        else:
+            self._emit_console = None
+        self._output_style = kwargs.get("output_style")
+
+    def _emit(self, message: str) -> None:
+        if not self._emit_enabled:
+            return
+        line = f"[EDIT] {message}".rstrip()
+        if self._emit_console is not None:
+            try:
+                self._emit_console.print(line, markup=False)
+            except TypeError:
+                self._emit_console.print(line)
+        else:
+            print(line)
 
     def can_edit(self, model_desc: dict[str, Any]) -> bool:
         """Check if RTN quantization can be applied to this model."""
@@ -233,15 +258,18 @@ class RTNQuantEdit(ModelEdit):
         scope = kwargs.get("scope", self.scope)
         seed = kwargs.get("seed", self.seed)
 
+        self._configure_output(**kwargs)
+
         # Diagnostic reporting
-        print("🔧 RTN Quantization Configuration:")
-        print(
-            f" Bitwidth: {bitwidth} (from config: {kwargs.get('bitwidth', kwargs.get('bits', 'default'))})"
+        self._emit("RTN Quantization Configuration:")
+        self._emit(
+            "Bitwidth: "
+            f"{bitwidth} (from config: {kwargs.get('bitwidth', kwargs.get('bits', 'default'))})"
         )
-        print(f" Scope: {scope}")
-        print(f" Group size: {group_size}")
-        print(f" Clamp ratio: {clamp_ratio}")
-        print(f" Seed: {seed}")
+        self._emit(f"Scope: {scope}")
+        self._emit(f"Group size: {group_size}")
+        self._emit(f"Clamp ratio: {clamp_ratio}")
+        self._emit(f"Seed: {seed}")
 
         # Persist configuration overrides for downstream helpers
         self.bitwidth = bitwidth
@@ -256,22 +284,22 @@ class RTNQuantEdit(ModelEdit):
         np.random.seed(seed)
 
         # Identify target modules and get weight tying map
-        print(f"🎯 Identifying target modules for scope '{scope}'...")
+        self._emit(f"Identifying target modules for scope '{scope}'...")
         target_modules = self._identify_target_modules(model)
         total_identified = len(target_modules)
 
         max_modules = kwargs.get("max_modules")
         if isinstance(max_modules, int) and max_modules > 0:
             if max_modules < total_identified:
-                print(
-                    f" Limiting quantization to first {max_modules} modules "
+                self._emit(
+                    f"Limiting quantization to first {max_modules} modules "
                     f"(of {total_identified}) based on plan.max_modules"
                 )
                 target_modules = target_modules[:max_modules]
                 self.max_modules = max_modules
             else:
-                print(
-                    f" max_modules={max_modules} >= available modules "
+                self._emit(
+                    f"max_modules={max_modules} >= available modules "
                     f"({total_identified}); using all targets"
                 )
                 self.max_modules = None
@@ -280,33 +308,35 @@ class RTNQuantEdit(ModelEdit):
 
         tying_map = self._get_weight_tying_map(model)
 
-        print(f" Found {len(target_modules)} target modules:")
+        self._emit(f"Found {len(target_modules)} target modules:")
         for i, (name, module) in enumerate(target_modules):
             weight_shape = module.weight.shape
             param_count = module.weight.numel()
-            print(f" [{i + 1}] {name}: {weight_shape} ({param_count:,} params)")
+            self._emit(f"[{i + 1}] {name}: {weight_shape} ({param_count:,} params)")
 
         if len(target_modules) == 0:
-            print("❌ WARNING: No target modules found! Check scope configuration.")
-            print(" Available linear modules:")
+            self._emit(
+                "WARNING: No target modules found! Check scope configuration."
+            )
+            self._emit("Available linear modules:")
             linear_modules = []
             for name, module in model.named_modules():
                 if isinstance(module, nn.Linear | nn.Conv1d):
                     linear_modules.append((name, module.weight.shape))
             for name, shape in linear_modules[:10]:  # Show first 10
-                print(f" {name}: {shape}")
+                self._emit(f"{name}: {shape}")
             if len(linear_modules) > 10:
-                print(f" ... and {len(linear_modules) - 10} more")
+                self._emit(f"... and {len(linear_modules) - 10} more")
 
         # Execute GuardChain before edit (if provided)
         guard_results = {}
         if self.guard_chain is not None:
-            print(" Executing guard chain preparation...")
+            self._emit("Executing guard chain preparation...")
            guard_results["prepare"] = self.guard_chain.prepare_all(
                 model, adapter, None, {}
             )
 
-            print(" Executing before-edit guards...")
+            self._emit("Executing before-edit guards...")
             self.guard_chain.before_edit_all(model)
 
         # Apply quantization to each target module
@@ -314,12 +344,12 @@ class RTNQuantEdit(ModelEdit):
         total_params_quantized = 0
 
         for i, (module_name, module) in enumerate(target_modules):
-            print(f" [{i + 1}/{len(target_modules)}] Quantizing: {module_name}")
-            print(
-                f" Shape: {module.weight.shape}, Params: {module.weight.numel():,}"
+            self._emit(f"[{i + 1}/{len(target_modules)}] Quantizing: {module_name}")
+            self._emit(
+                f"Shape: {module.weight.shape}, Params: {module.weight.numel():,}"
             )
-            print(
-                f" Weight range: [{module.weight.min():.4f}, {module.weight.max():.4f}]"
+            self._emit(
+                f"Weight range: [{module.weight.min():.4f}, {module.weight.max():.4f}]"
             )
 
             # Apply RTN quantization
@@ -335,24 +365,22 @@ class RTNQuantEdit(ModelEdit):
             quantization_results.append(quant_result)
             total_params_quantized += quant_result["params_quantized"]
 
-            print(
-                f" ✓ Quantized {quant_result['params_quantized']:,} parameters"
-            )
+            self._emit(f"Quantized {quant_result['params_quantized']:,} parameters")
 
         # Execute GuardChain after edit (if provided)
         if self.guard_chain is not None:
-            print(" Executing after-edit guards...")
+            self._emit("Executing after-edit guards...")
             self.guard_chain.after_edit_all(model)
 
-            print(" Finalizing guard chain...")
+            self._emit("Finalizing guard chain...")
             guard_results["finalize"] = self.guard_chain.finalize_all(model)
 
             # Check if all guards passed
             if not self.guard_chain.all_passed(guard_results["finalize"]):
-                print(" ⚠️ Guard chain validation failed!")
+                self._emit("Guard chain validation failed!")
                 guard_results["all_passed"] = False
             else:
-                print("All guards passed")
+                self._emit("All guards passed")
                 guard_results["all_passed"] = True
 
         # Create bitwidth map
@@ -490,11 +518,11 @@ class RTNQuantEdit(ModelEdit):
 
         # Log diagnostic information
         if skipped_modules:
-            print(f" Skipped {len(skipped_modules)} modules:")
+            self._emit(f"Skipped {len(skipped_modules)} modules:")
             for name, reason in skipped_modules[:5]:  # Show first 5
-                print(f" {name}: {reason}")
+                self._emit(f"{name}: {reason}")
             if len(skipped_modules) > 5:
-                print(f" ... and {len(skipped_modules) - 5} more")
+                self._emit(f"... and {len(skipped_modules) - 5} more")
 
         return target_modules
 
@@ -625,7 +653,7 @@ class RTNQuantEdit(ModelEdit):
         # Ensure actual quantization occurred by applying quantization loss
         # This guarantees the weights are actually modified
         quantization_error = (quantized_weight - original_weight).abs().mean()
-        print(f" Quantization error: {quantization_error:.6f}")
+        self._emit(f"Quantization error: {quantization_error:.6f}")
 
         # Write back to module (preserving tying if needed)
         module.weight.data.copy_(quantized_weight)
@@ -634,7 +662,7 @@ class RTNQuantEdit(ModelEdit):
         final_weight = module.weight.data
         actual_change = not torch.allclose(original_weight, final_weight, atol=1e-6)
         if not actual_change:
-            print(f" WARNING: No actual weight change detected for {module}")
+            self._emit(f"WARNING: No actual weight change detected for {module}")
 
         # Handle tied weights
         if tied_modules:
invarlock/eval/bench.py CHANGED
@@ -47,7 +47,7 @@ class ScenarioConfig:
     probes: int
     profile: str = "ci"  # "ci" or "release"
     model_id: str = "gpt2"
-    adapter: str = "hf_gpt2"
+    adapter: str = "hf_causal"
     device: str = "auto"
     seq_len: int = 512
     stride: int = 128
@@ -81,7 +81,7 @@ class BenchmarkConfig:
     profile: str = "ci"  # "ci" or "release"
     dataset: str = "wikitext2"
     model_id: str = "gpt2"
-    adapter: str = "hf_gpt2"
+    adapter: str = "hf_causal"
     device: str = "auto"
     seq_len: int = 512
     stride: int = 128
@@ -1423,7 +1423,7 @@ def main():
         "--dataset", default="wikitext2", help="Dataset to use for benchmarking"
     )
     parser.add_argument("--model-id", default="gpt2", help="Model identifier")
-    parser.add_argument("--adapter", default="hf_gpt2", help="Model adapter to use")
+    parser.add_argument("--adapter", default="hf_causal", help="Model adapter to use")
     parser.add_argument(
         "--device", default="auto", help="Device to use (auto|cuda|mps|cpu)"
     )
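These defaults complete the move from the per-architecture `hf_gpt2` adapter to the consolidated `hf_causal` one (the `hf_gpt2.py` and `hf_llama.py` adapters are deleted in this release, per the file list). A small sketch of the new default, assuming `ScenarioConfig` is a dataclass as its annotated-field syntax suggests and that `probes` is its only required field above this hunk, which the context shown here cannot confirm:

```python
from invarlock.eval.bench import ScenarioConfig

# The adapter default now resolves to the consolidated causal-LM adapter.
cfg = ScenarioConfig(probes=8)
assert cfg.adapter == "hf_causal"

# Other causal models go through the same adapter rather than a
# per-architecture one; "hf_gpt2" itself no longer ships in 0.3.7.
cfg_llama = ScenarioConfig(probes=8, model_id="meta-llama/Llama-2-7b-hf")
assert cfg_llama.adapter == "hf_causal"
```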
invarlock/eval/data.py CHANGED
@@ -15,7 +15,7 @@ import time
 import warnings
 from abc import abstractmethod
 from collections import Counter
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from pathlib import Path
 from typing import Any, NamedTuple, Protocol
 
@@ -56,6 +56,9 @@ except ImportError:
     HAS_TORCH = False
 
 
+EventEmitter = Callable[[str, str, str | None], None]
+
+
 class EvaluationWindow(NamedTuple):
     """A window of tokenized samples for evaluation."""
 
@@ -166,6 +169,7 @@ class WikiText2Provider:
         self,
         cache_dir: Path | None = None,
         device_hint: str | None = None,
+        emit: EventEmitter | None = None,
         **_: Any,
     ):
         """
@@ -175,6 +179,7 @@ class WikiText2Provider:
             cache_dir: Optional cache directory for dataset storage
         """
         self.cache_dir = cache_dir
+        self._emit_event = emit
         self._validate_dependencies()
         self._last_stratification_stats: dict[str, Any] | None = None
         self._last_batch_size_used: int = 0
@@ -186,6 +191,20 @@ class WikiText2Provider:
         normalized_hint = (device_hint or "").strip().lower()
         self._device_hint: str | None = normalized_hint or None
 
+    def _event(self, tag: str, message: str, *, emoji: str | None = None) -> None:
+        """Emit a dataset event via an optional CLI-provided sink."""
+        if self._emit_event is None:
+            if emoji:
+                print(f"{emoji} {message}")
+            else:
+                print(message)
+            return
+        try:
+            self._emit_event(tag, message, emoji)
+        except TypeError:
+            # Back-compat: tolerate sinks that only accept (tag, message).
+            self._emit_event(tag, message)  # type: ignore[misc]
+
     def _validate_dependencies(self) -> None:
         """Check that required dependencies are available."""
         if not HAS_DATASETS:
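The sink contract is positional `(tag, message, emoji)`, with a `TypeError` fallback for older two-argument sinks and a plain `print` fallback when no sink is given. A minimal conforming sink, assuming only the constructor signature shown above; the logging setup itself is illustrative:

```python
import logging

from invarlock.eval.data import WikiText2Provider

logger = logging.getLogger("invarlock.data")


def log_sink(tag: str, message: str, emoji: str | None = None) -> None:
    # The emoji is presentation-only; a logging sink can simply drop it.
    logger.info("[%s] %s", tag, message)


provider = WikiText2Provider(emit=log_sink)  # dataset events now bypass print()
```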
@@ -319,7 +338,11 @@ class WikiText2Provider:
         Returns:
             List of filtered text strings
         """
-        print(f"📚 Loading WikiText-2 {split} split...")
+        self._event(
+            "DATA",
+            f"WikiText-2 {split}: loading split...",
+            emoji="📚",
+        )
 
         # Serve from cache when possible (load the largest slice once)
         cached = self._texts_cache.get(split)
@@ -366,7 +389,10 @@ class WikiText2Provider:
         if prev is None or len(valid_texts) > len(prev):
             self._texts_cache[split] = list(valid_texts)
 
-        print(f" ✓ Loaded {len(valid_texts)} valid samples from {len(dataset)} total")
+        self._event(
+            "DATA",
+            f"Loaded {len(valid_texts)}/{len(dataset)} valid samples",
+        )
         return valid_texts
 
     def windows(
@@ -435,9 +461,13 @@ class WikiText2Provider:
         cursor = 0
         chunk_size = max(64, min(256, target_pool))
 
-        print(" 📊 Creating evaluation windows:")
-        print(f" Requested preview/final: {preview_n}/{final_n}")
-        print(f" Sampling pool target: {target_pool} (reserve {reserve})")
+        self._event(
+            "DATA",
+            "Creating evaluation windows:",
+            emoji="📊",
+        )
+        self._event("DATA", f"Requested preview/final: {preview_n}/{final_n}")
+        self._event("DATA", f"Sampling pool target: {target_pool} (reserve {reserve})")
 
         while len(candidates) < total_required + reserve and cursor < len(
             shuffled_indices
@@ -708,9 +738,9 @@ class WikiText2Provider:
             ),
         }
 
-        print(f" Seed: {seed}, Seq length: {seq_len}")
-        print(f" Preview: {len(preview_window)} samples")
-        print(f" Final: {len(final_window)} samples")
+        self._event("DATA", f"Seed: {seed}, Seq length: {seq_len}")
+        self._event("DATA", f"Preview: {len(preview_window)} samples")
+        self._event("DATA", f"Final: {len(final_window)} samples")
 
         return preview_window, final_window
 
@@ -840,8 +870,9 @@ class WikiText2Provider:
         attention_masks_list = [entry[2] for entry in collected]
         valid_indices = [entry[0] for entry in collected]
 
-        print(
-            f" ✓ {window_name}: {len(valid_indices)}/{len(indices)} samples tokenized successfully"
+        self._event(
+            "DATA",
+            f"{window_name}: {len(valid_indices)}/{len(indices)} samples tokenized",
         )
 
         return EvaluationWindow(
@@ -934,7 +965,8 @@ class SyntheticProvider:
         self, split: str = "validation", max_samples: int = 500, **kwargs
     ) -> list[str]:
         """Generate synthetic text samples."""
-        # Expand base samples to meet requirement
+        # Expand base samples to meet requirement, preferring unique variations
+        # to avoid duplicate-token windows (important for stratified pairing).
         expanded_samples: list[str] = []
         variations = [
             lambda s: s,
@@ -944,18 +976,25 @@ class SyntheticProvider:
             lambda s: f"Furthermore, {s.lower()}",
             lambda s: f"In addition, {s.lower()}",
         ]
-
-        # Use a deterministic approach based on max_samples
-        rng = np.random.RandomState(42)  # Fixed seed for reproducibility
-
-        while len(expanded_samples) < max_samples:
+        # Deterministic coverage of (variation × base sample) combinations first.
+        for variation in variations:
             for base_text in self.base_samples:
-                if len(expanded_samples) >= max_samples:
-                    break
-                variation = rng.choice(variations)
                 expanded_samples.append(variation(base_text))
+                if len(expanded_samples) >= max_samples:
+                    return expanded_samples
+
+        # If callers request more than the unique combination space, keep
+        # extending deterministically while ensuring uniqueness via a suffix.
+        idx = 0
+        while len(expanded_samples) < max_samples:
+            base_text = self.base_samples[idx % len(self.base_samples)]
+            variation = variations[(idx // len(self.base_samples)) % len(variations)]
+            expanded_samples.append(
+                f"{variation(base_text)} [synthetic #{len(expanded_samples)}]"
+            )
+            idx += 1
 
-        return expanded_samples[:max_samples]
+        return expanded_samples
 
     def windows(
         self,
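The rewritten expansion walks the full variation-by-base-sample grid in a fixed order and only then falls back to index-suffixed repeats, so the output is deterministic and duplicate-free without an RNG. A standalone sketch of the same ordering, using toy data rather than the provider's actual base samples:

```python
# Standalone re-implementation of the expansion order above (illustrative only).
base = ["alpha", "beta"]
variations = [lambda s: s, lambda s: f"Moreover, {s}"]

samples: list[str] = []
for variation in variations:  # cover the full grid first: 4 unique samples
    for text in base:
        samples.append(variation(text))

idx = 0
while len(samples) < 6:  # overflow stays deterministic and unique via the suffix
    text = base[idx % len(base)]
    variation = variations[(idx // len(base)) % len(variations)]
    samples.append(f"{variation(text)} [synthetic #{len(samples)}]")
    idx += 1

assert samples[:4] == ["alpha", "beta", "Moreover, alpha", "Moreover, beta"]
assert samples[4] == "alpha [synthetic #4]"
```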
@@ -1801,12 +1840,15 @@ _PROVIDERS: dict[str, type] = {
 }
 
 
-def get_provider(name: str, **kwargs) -> DatasetProvider:
+def get_provider(
+    name: str, *, emit: EventEmitter | None = None, **kwargs: Any
+) -> DatasetProvider:
     """
     Get a dataset provider by name.
 
     Args:
         name: Provider name ("wikitext2", "synthetic")
+        emit: Optional event sink for dataset/provider logs.
         **kwargs: Provider-specific initialization parameters
 
     Returns:
@@ -1825,7 +1867,10 @@ def get_provider(name: str, **kwargs) -> DatasetProvider:
         )
 
     provider_class = _PROVIDERS[name]
-    return provider_class(**kwargs)
+    init_kwargs = dict(kwargs)
+    if emit is not None and name == "wikitext2":
+        init_kwargs["emit"] = emit
+    return provider_class(**init_kwargs)
 
 
 def list_providers() -> list[str]:
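Note that `emit` is keyword-only and is forwarded only to the `wikitext2` provider; all other providers are constructed exactly as before. Hypothetical usage, assuming the public import path matches this module:

```python
from invarlock.eval.data import get_provider


def sink(tag: str, message: str, emoji: str | None = None) -> None:
    print(f"[{tag}] {message}")


wikitext = get_provider("wikitext2", emit=sink)   # sink forwarded to the provider
synthetic = get_provider("synthetic", emit=sink)  # sink accepted but not forwarded
```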
invarlock/eval/metrics.py CHANGED
@@ -18,9 +18,10 @@ import gc
 import logging
 import math
 import time
+from collections.abc import Iterable
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any
+from typing import Any, Protocol
 
 import numpy as np
 import psutil
@@ -2269,6 +2270,57 @@ def analyze_rmt_changes(
         return {"error": str(e)}
 
 
+class Metric(Protocol):
+    name: str
+    kind: str  # "ppl", "accuracy", "exact_match", "bleu", "rouge"
+
+    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float: ...
+
+
+class PerplexityMetric:
+    """Lightweight perplexity metric from per-record logloss + token counts."""
+
+    name = "perplexity"
+    kind = "ppl"
+
+    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float:  # noqa: ARG002
+        total_loss = 0.0
+        total_tokens = 0.0
+        for record in dataset:
+            if not isinstance(record, dict):
+                continue
+            loss = record.get("logloss", record.get("loss"))
+            tokens = record.get("token_count", record.get("tokens", 1))
+            try:
+                loss_val = float(loss)
+                tok_val = float(tokens)
+            except Exception:
+                continue
+            if (
+                not math.isfinite(loss_val)
+                or not math.isfinite(tok_val)
+                or tok_val <= 0
+            ):
+                continue
+            total_loss += loss_val * tok_val
+            total_tokens += tok_val
+        if total_tokens <= 0:
+            return float("nan")
+        return float(math.exp(total_loss / total_tokens))
+
+
+class AccuracyMetric:
+    """Classification accuracy metric from label/prediction records."""
+
+    name = "accuracy"
+    kind = "accuracy"
+
+    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float:  # noqa: ARG002
+        from invarlock.eval.tasks.classification import accuracy_from_records
+
+        return accuracy_from_records(dataset)
+
+
 # ── Integration with existing system ───────────────────────────────────────
 
 # Update exports to include new functions (add to existing __all__ if it exists)
@@ -2282,6 +2334,9 @@ try:
             "compute_parameter_deltas",
             "analyze_spectral_changes",
             "analyze_rmt_changes",
+            "Metric",
+            "PerplexityMetric",
+            "AccuracyMetric",
         ]
     )
 except NameError:
@@ -2294,4 +2349,7 @@ except NameError:
         "compute_parameter_deltas",
         "analyze_spectral_changes",
         "analyze_rmt_changes",
+        "Metric",
+        "PerplexityMetric",
+        "AccuracyMetric",
     ]
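The `Metric` protocol is duck-typed over plain record dicts, so metrics can run without touching the model (hence the ignored `model` argument). `PerplexityMetric` computes the exponential of the token-weighted mean logloss, exp(sum(loss_i * tok_i) / sum(tok_i)), skipping malformed or non-finite records. A small worked example under that reading:

```python
import math

from invarlock.eval.metrics import PerplexityMetric

records = [
    {"logloss": 2.0, "token_count": 100},
    {"logloss": 3.0, "token_count": 300},
    {"logloss": float("nan"), "token_count": 50},  # skipped: non-finite loss
]

ppl = PerplexityMetric().compute(model=None, dataset=records)
# Token-weighted mean logloss = (2.0*100 + 3.0*300) / 400 = 2.75
assert math.isclose(ppl, math.exp(2.75))
```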
invarlock/eval/tasks/__init__.py ADDED
@@ -0,0 +1,12 @@
+from __future__ import annotations
+
+from .classification import accuracy_from_records
+from .qa import exact_match_from_records
+from .text_generation import bleu1_from_records, rouge_l_from_records
+
+__all__ = [
+    "accuracy_from_records",
+    "exact_match_from_records",
+    "bleu1_from_records",
+    "rouge_l_from_records",
+]
invarlock/eval/tasks/classification.py ADDED
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+
+def _iter_pairs(record: dict[str, Any]) -> list[tuple[Any, Any]]:
+    if "correct" in record:
+        return [(bool(record.get("correct")), True)]
+
+    label = record.get("label")
+    pred = record.get("prediction")
+    if label is None:
+        label = record.get("labels")
+    if pred is None:
+        pred = record.get("pred")
+    if pred is None:
+        pred = record.get("predictions")
+
+    if isinstance(label, list) and isinstance(pred, list):
+        return list(zip(label, pred, strict=False))
+    if label is None or pred is None:
+        return []
+    return [(label, pred)]
+
+
+def accuracy_from_records(records: Iterable[dict[str, Any]]) -> float:
+    """Compute accuracy from records with labels/predictions.
+
+    Accepted record shapes:
+    - {"label": <label>, "prediction": <label>}
+    - {"labels": [...], "predictions": [...]}
+    - {"correct": <bool>}
+    """
+    total = 0
+    correct = 0
+    for record in records:
+        if not isinstance(record, dict):
+            continue
+        for label, pred in _iter_pairs(record):
+            total += 1
+            if isinstance(label, bool):
+                correct += int(label is pred)
+            else:
+                correct += int(label == pred)
+    if total == 0:
+        return float("nan")
+    return float(correct / total)
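A quick sketch of the three accepted record shapes; batched pairs are scored individually, and a `correct` record counts as one pre-judged pair:

```python
from invarlock.eval.tasks.classification import accuracy_from_records

records = [
    {"label": "pos", "prediction": "pos"},       # scalar pair: correct
    {"labels": [0, 1], "predictions": [0, 0]},   # batched pair: 1 of 2 correct
    {"correct": False},                          # pre-judged record: incorrect
]

# 2 correct out of 4 scored pairs -> 0.5
assert accuracy_from_records(records) == 0.5
```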
invarlock/eval/tasks/qa.py ADDED
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any
+
+
+def _normalize(text: str) -> str:
+    return " ".join(str(text).strip().lower().split())
+
+
+def exact_match_from_records(records: Iterable[dict[str, Any]]) -> float:
+    """Compute exact-match accuracy for QA-style records.
+
+    Accepted record shapes:
+    - {"prediction": "...", "answer": "..."}
+    - {"prediction": "...", "answers": ["...", ...]}
+    """
+    total = 0
+    correct = 0
+    for record in records:
+        if not isinstance(record, dict):
+            continue
+        pred = record.get("prediction")
+        answers = record.get("answers")
+        if answers is None and "answer" in record:
+            answers = [record.get("answer")]
+        if pred is None or answers is None:
+            continue
+        pred_norm = _normalize(pred)
+        answer_list = answers if isinstance(answers, list) else [answers]
+        total += 1
+        if any(_normalize(a) == pred_norm for a in answer_list if a is not None):
+            correct += 1
+    if total == 0:
+        return float("nan")
+    return float(correct / total)
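Normalization lowercases, trims, and collapses whitespace before comparison, and a prediction matching any answer alias counts. A small usage sketch:

```python
from invarlock.eval.tasks.qa import exact_match_from_records

records = [
    {"prediction": "  Paris ", "answer": "paris"},         # matches after normalization
    {"prediction": "42", "answers": ["forty-two", "42"]},  # any alias counts
    {"prediction": "London", "answer": "Berlin"},          # miss
]

# 2 exact matches out of 3 scored records
assert abs(exact_match_from_records(records) - 2 / 3) < 1e-9
```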