invarlock 0.3.6__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +2 -2
- invarlock/adapters/__init__.py +10 -14
- invarlock/adapters/auto.py +35 -40
- invarlock/adapters/capabilities.py +2 -2
- invarlock/adapters/hf_causal.py +418 -0
- invarlock/adapters/{hf_onnx.py → hf_causal_onnx.py} +3 -3
- invarlock/adapters/hf_mixin.py +25 -4
- invarlock/adapters/{hf_bert.py → hf_mlm.py} +4 -11
- invarlock/adapters/{hf_t5.py → hf_seq2seq.py} +9 -9
- invarlock/cli/adapter_auto.py +31 -21
- invarlock/cli/app.py +73 -2
- invarlock/cli/commands/certify.py +600 -59
- invarlock/cli/commands/doctor.py +8 -10
- invarlock/cli/commands/plugins.py +13 -9
- invarlock/cli/commands/report.py +233 -69
- invarlock/cli/commands/run.py +907 -183
- invarlock/cli/commands/verify.py +76 -11
- invarlock/cli/config.py +1 -1
- invarlock/cli/doctor_helpers.py +4 -5
- invarlock/cli/output.py +193 -0
- invarlock/cli/provenance.py +1 -1
- invarlock/core/bootstrap.py +1 -1
- invarlock/core/registry.py +9 -11
- invarlock/core/runner.py +111 -25
- invarlock/edits/quant_rtn.py +65 -37
- invarlock/eval/bench.py +3 -3
- invarlock/eval/data.py +68 -23
- invarlock/eval/metrics.py +59 -1
- invarlock/eval/tasks/__init__.py +12 -0
- invarlock/eval/tasks/classification.py +48 -0
- invarlock/eval/tasks/qa.py +36 -0
- invarlock/eval/tasks/text_generation.py +102 -0
- invarlock/guards/invariants.py +19 -10
- invarlock/guards/rmt.py +2 -2
- invarlock/guards/variance.py +2 -2
- invarlock/model_profile.py +48 -27
- invarlock/observability/health.py +6 -6
- invarlock/observability/metrics.py +108 -0
- invarlock/reporting/certificate.py +159 -9
- invarlock/reporting/certificate_schema.py +1 -1
- invarlock/reporting/guards_analysis.py +154 -4
- invarlock/reporting/html.py +55 -5
- invarlock/reporting/normalizer.py +7 -0
- invarlock/reporting/render.py +791 -431
- invarlock/reporting/report.py +39 -3
- invarlock/reporting/report_types.py +6 -1
- invarlock/reporting/telemetry.py +86 -0
- {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/METADATA +23 -9
- {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/RECORD +53 -48
- {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/WHEEL +1 -1
- {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/entry_points.txt +5 -3
- invarlock/adapters/hf_gpt2.py +0 -404
- invarlock/adapters/hf_llama.py +0 -487
- {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.6.dist-info → invarlock-0.3.7.dist-info}/top_level.txt +0 -0
invarlock/edits/quant_rtn.py
CHANGED
|
@@ -86,6 +86,31 @@ class RTNQuantEdit(ModelEdit):
|
|
|
86
86
|
|
|
87
87
|
# group_size is currently reserved for potential future variants; it is
|
|
88
88
|
# ignored for the built-in INT8 demo edit.
|
|
89
|
+
self._emit_enabled = True
|
|
90
|
+
self._emit_console = None
|
|
91
|
+
self._output_style = None
|
|
92
|
+
|
|
93
|
+
def _configure_output(self, **kwargs: Any) -> None:
    """Read output-related options from ``kwargs`` and store them on the edit.

    Recognised keys:
        emit: Truthiness decides whether ``_emit`` produces any output.
            Defaults to ``True`` when absent.
        console: Object used as the output sink. It is kept only when it is
            not ``None`` and exposes a ``print`` attribute; otherwise the
            sink is reset to ``None``.
        output_style: Stored verbatim on ``self._output_style``.
    """
    self._emit_enabled = bool(kwargs.get("emit", True))

    sink = kwargs.get("console")
    usable = sink is not None and hasattr(sink, "print")
    self._emit_console = sink if usable else None

    self._output_style = kwargs.get("output_style")
|
|
102
|
+
|
|
103
|
+
def _emit(self, message: str) -> None:
    """Write one ``[EDIT]``-prefixed diagnostic line, if emission is enabled.

    The line goes to ``self._emit_console.print`` when a console is
    configured, otherwise to the built-in ``print``. Trailing whitespace is
    stripped from the formatted line.
    """
    if not self._emit_enabled:
        return

    text = f"[EDIT] {message}".rstrip()
    console = self._emit_console
    if console is None:
        print(text)
        return

    try:
        # Ask the console to print the text literally (no markup parsing).
        console.print(text, markup=False)
    except TypeError:
        # The console's print() rejected the call; retry without the keyword.
        console.print(text)
|
|
89
114
|
|
|
90
115
|
def can_edit(self, model_desc: dict[str, Any]) -> bool:
|
|
91
116
|
"""Check if RTN quantization can be applied to this model."""
|
|
@@ -233,15 +258,18 @@ class RTNQuantEdit(ModelEdit):
|
|
|
233
258
|
scope = kwargs.get("scope", self.scope)
|
|
234
259
|
seed = kwargs.get("seed", self.seed)
|
|
235
260
|
|
|
261
|
+
self._configure_output(**kwargs)
|
|
262
|
+
|
|
236
263
|
# Diagnostic reporting
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
264
|
+
self._emit("RTN Quantization Configuration:")
|
|
265
|
+
self._emit(
|
|
266
|
+
"Bitwidth: "
|
|
267
|
+
f"{bitwidth} (from config: {kwargs.get('bitwidth', kwargs.get('bits', 'default'))})"
|
|
240
268
|
)
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
269
|
+
self._emit(f"Scope: {scope}")
|
|
270
|
+
self._emit(f"Group size: {group_size}")
|
|
271
|
+
self._emit(f"Clamp ratio: {clamp_ratio}")
|
|
272
|
+
self._emit(f"Seed: {seed}")
|
|
245
273
|
|
|
246
274
|
# Persist configuration overrides for downstream helpers
|
|
247
275
|
self.bitwidth = bitwidth
|
|
@@ -256,22 +284,22 @@ class RTNQuantEdit(ModelEdit):
|
|
|
256
284
|
np.random.seed(seed)
|
|
257
285
|
|
|
258
286
|
# Identify target modules and get weight tying map
|
|
259
|
-
|
|
287
|
+
self._emit(f"Identifying target modules for scope '{scope}'...")
|
|
260
288
|
target_modules = self._identify_target_modules(model)
|
|
261
289
|
total_identified = len(target_modules)
|
|
262
290
|
|
|
263
291
|
max_modules = kwargs.get("max_modules")
|
|
264
292
|
if isinstance(max_modules, int) and max_modules > 0:
|
|
265
293
|
if max_modules < total_identified:
|
|
266
|
-
|
|
267
|
-
f"
|
|
294
|
+
self._emit(
|
|
295
|
+
f"Limiting quantization to first {max_modules} modules "
|
|
268
296
|
f"(of {total_identified}) based on plan.max_modules"
|
|
269
297
|
)
|
|
270
298
|
target_modules = target_modules[:max_modules]
|
|
271
299
|
self.max_modules = max_modules
|
|
272
300
|
else:
|
|
273
|
-
|
|
274
|
-
f"
|
|
301
|
+
self._emit(
|
|
302
|
+
f"max_modules={max_modules} >= available modules "
|
|
275
303
|
f"({total_identified}); using all targets"
|
|
276
304
|
)
|
|
277
305
|
self.max_modules = None
|
|
@@ -280,33 +308,35 @@ class RTNQuantEdit(ModelEdit):
|
|
|
280
308
|
|
|
281
309
|
tying_map = self._get_weight_tying_map(model)
|
|
282
310
|
|
|
283
|
-
|
|
311
|
+
self._emit(f"Found {len(target_modules)} target modules:")
|
|
284
312
|
for i, (name, module) in enumerate(target_modules):
|
|
285
313
|
weight_shape = module.weight.shape
|
|
286
314
|
param_count = module.weight.numel()
|
|
287
|
-
|
|
315
|
+
self._emit(f"[{i + 1}] {name}: {weight_shape} ({param_count:,} params)")
|
|
288
316
|
|
|
289
317
|
if len(target_modules) == 0:
|
|
290
|
-
|
|
291
|
-
|
|
318
|
+
self._emit(
|
|
319
|
+
"WARNING: No target modules found! Check scope configuration."
|
|
320
|
+
)
|
|
321
|
+
self._emit("Available linear modules:")
|
|
292
322
|
linear_modules = []
|
|
293
323
|
for name, module in model.named_modules():
|
|
294
324
|
if isinstance(module, nn.Linear | nn.Conv1d):
|
|
295
325
|
linear_modules.append((name, module.weight.shape))
|
|
296
326
|
for name, shape in linear_modules[:10]: # Show first 10
|
|
297
|
-
|
|
327
|
+
self._emit(f"{name}: {shape}")
|
|
298
328
|
if len(linear_modules) > 10:
|
|
299
|
-
|
|
329
|
+
self._emit(f"... and {len(linear_modules) - 10} more")
|
|
300
330
|
|
|
301
331
|
# Execute GuardChain before edit (if provided)
|
|
302
332
|
guard_results = {}
|
|
303
333
|
if self.guard_chain is not None:
|
|
304
|
-
|
|
334
|
+
self._emit("Executing guard chain preparation...")
|
|
305
335
|
guard_results["prepare"] = self.guard_chain.prepare_all(
|
|
306
336
|
model, adapter, None, {}
|
|
307
337
|
)
|
|
308
338
|
|
|
309
|
-
|
|
339
|
+
self._emit("Executing before-edit guards...")
|
|
310
340
|
self.guard_chain.before_edit_all(model)
|
|
311
341
|
|
|
312
342
|
# Apply quantization to each target module
|
|
@@ -314,12 +344,12 @@ class RTNQuantEdit(ModelEdit):
|
|
|
314
344
|
total_params_quantized = 0
|
|
315
345
|
|
|
316
346
|
for i, (module_name, module) in enumerate(target_modules):
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
f"
|
|
347
|
+
self._emit(f"[{i + 1}/{len(target_modules)}] Quantizing: {module_name}")
|
|
348
|
+
self._emit(
|
|
349
|
+
f"Shape: {module.weight.shape}, Params: {module.weight.numel():,}"
|
|
320
350
|
)
|
|
321
|
-
|
|
322
|
-
f"
|
|
351
|
+
self._emit(
|
|
352
|
+
f"Weight range: [{module.weight.min():.4f}, {module.weight.max():.4f}]"
|
|
323
353
|
)
|
|
324
354
|
|
|
325
355
|
# Apply RTN quantization
|
|
@@ -335,24 +365,22 @@ class RTNQuantEdit(ModelEdit):
|
|
|
335
365
|
quantization_results.append(quant_result)
|
|
336
366
|
total_params_quantized += quant_result["params_quantized"]
|
|
337
367
|
|
|
338
|
-
|
|
339
|
-
f" ✓ Quantized {quant_result['params_quantized']:,} parameters"
|
|
340
|
-
)
|
|
368
|
+
self._emit(f"Quantized {quant_result['params_quantized']:,} parameters")
|
|
341
369
|
|
|
342
370
|
# Execute GuardChain after edit (if provided)
|
|
343
371
|
if self.guard_chain is not None:
|
|
344
|
-
|
|
372
|
+
self._emit("Executing after-edit guards...")
|
|
345
373
|
self.guard_chain.after_edit_all(model)
|
|
346
374
|
|
|
347
|
-
|
|
375
|
+
self._emit("Finalizing guard chain...")
|
|
348
376
|
guard_results["finalize"] = self.guard_chain.finalize_all(model)
|
|
349
377
|
|
|
350
378
|
# Check if all guards passed
|
|
351
379
|
if not self.guard_chain.all_passed(guard_results["finalize"]):
|
|
352
|
-
|
|
380
|
+
self._emit("Guard chain validation failed!")
|
|
353
381
|
guard_results["all_passed"] = False
|
|
354
382
|
else:
|
|
355
|
-
|
|
383
|
+
self._emit("All guards passed")
|
|
356
384
|
guard_results["all_passed"] = True
|
|
357
385
|
|
|
358
386
|
# Create bitwidth map
|
|
@@ -490,11 +518,11 @@ class RTNQuantEdit(ModelEdit):
|
|
|
490
518
|
|
|
491
519
|
# Log diagnostic information
|
|
492
520
|
if skipped_modules:
|
|
493
|
-
|
|
521
|
+
self._emit(f"Skipped {len(skipped_modules)} modules:")
|
|
494
522
|
for name, reason in skipped_modules[:5]: # Show first 5
|
|
495
|
-
|
|
523
|
+
self._emit(f"{name}: {reason}")
|
|
496
524
|
if len(skipped_modules) > 5:
|
|
497
|
-
|
|
525
|
+
self._emit(f"... and {len(skipped_modules) - 5} more")
|
|
498
526
|
|
|
499
527
|
return target_modules
|
|
500
528
|
|
|
@@ -625,7 +653,7 @@ class RTNQuantEdit(ModelEdit):
|
|
|
625
653
|
# Ensure actual quantization occurred by applying quantization loss
|
|
626
654
|
# This guarantees the weights are actually modified
|
|
627
655
|
quantization_error = (quantized_weight - original_weight).abs().mean()
|
|
628
|
-
|
|
656
|
+
self._emit(f"Quantization error: {quantization_error:.6f}")
|
|
629
657
|
|
|
630
658
|
# Write back to module (preserving tying if needed)
|
|
631
659
|
module.weight.data.copy_(quantized_weight)
|
|
@@ -634,7 +662,7 @@ class RTNQuantEdit(ModelEdit):
|
|
|
634
662
|
final_weight = module.weight.data
|
|
635
663
|
actual_change = not torch.allclose(original_weight, final_weight, atol=1e-6)
|
|
636
664
|
if not actual_change:
|
|
637
|
-
|
|
665
|
+
self._emit(f"WARNING: No actual weight change detected for {module}")
|
|
638
666
|
|
|
639
667
|
# Handle tied weights
|
|
640
668
|
if tied_modules:
|
invarlock/eval/bench.py
CHANGED
|
@@ -47,7 +47,7 @@ class ScenarioConfig:
|
|
|
47
47
|
probes: int
|
|
48
48
|
profile: str = "ci" # "ci" or "release"
|
|
49
49
|
model_id: str = "gpt2"
|
|
50
|
-
adapter: str = "
|
|
50
|
+
adapter: str = "hf_causal"
|
|
51
51
|
device: str = "auto"
|
|
52
52
|
seq_len: int = 512
|
|
53
53
|
stride: int = 128
|
|
@@ -81,7 +81,7 @@ class BenchmarkConfig:
|
|
|
81
81
|
profile: str = "ci" # "ci" or "release"
|
|
82
82
|
dataset: str = "wikitext2"
|
|
83
83
|
model_id: str = "gpt2"
|
|
84
|
-
adapter: str = "
|
|
84
|
+
adapter: str = "hf_causal"
|
|
85
85
|
device: str = "auto"
|
|
86
86
|
seq_len: int = 512
|
|
87
87
|
stride: int = 128
|
|
@@ -1423,7 +1423,7 @@ def main():
|
|
|
1423
1423
|
"--dataset", default="wikitext2", help="Dataset to use for benchmarking"
|
|
1424
1424
|
)
|
|
1425
1425
|
parser.add_argument("--model-id", default="gpt2", help="Model identifier")
|
|
1426
|
-
parser.add_argument("--adapter", default="
|
|
1426
|
+
parser.add_argument("--adapter", default="hf_causal", help="Model adapter to use")
|
|
1427
1427
|
parser.add_argument(
|
|
1428
1428
|
"--device", default="auto", help="Device to use (auto|cuda|mps|cpu)"
|
|
1429
1429
|
)
|
invarlock/eval/data.py
CHANGED
|
@@ -15,7 +15,7 @@ import time
|
|
|
15
15
|
import warnings
|
|
16
16
|
from abc import abstractmethod
|
|
17
17
|
from collections import Counter
|
|
18
|
-
from collections.abc import Sequence
|
|
18
|
+
from collections.abc import Callable, Sequence
|
|
19
19
|
from pathlib import Path
|
|
20
20
|
from typing import Any, NamedTuple, Protocol
|
|
21
21
|
|
|
@@ -56,6 +56,9 @@ except ImportError:
|
|
|
56
56
|
HAS_TORCH = False
|
|
57
57
|
|
|
58
58
|
|
|
59
|
+
EventEmitter = Callable[[str, str, str | None], None]
|
|
60
|
+
|
|
61
|
+
|
|
59
62
|
class EvaluationWindow(NamedTuple):
|
|
60
63
|
"""A window of tokenized samples for evaluation."""
|
|
61
64
|
|
|
@@ -166,6 +169,7 @@ class WikiText2Provider:
|
|
|
166
169
|
self,
|
|
167
170
|
cache_dir: Path | None = None,
|
|
168
171
|
device_hint: str | None = None,
|
|
172
|
+
emit: EventEmitter | None = None,
|
|
169
173
|
**_: Any,
|
|
170
174
|
):
|
|
171
175
|
"""
|
|
@@ -175,6 +179,7 @@ class WikiText2Provider:
|
|
|
175
179
|
cache_dir: Optional cache directory for dataset storage
|
|
176
180
|
"""
|
|
177
181
|
self.cache_dir = cache_dir
|
|
182
|
+
self._emit_event = emit
|
|
178
183
|
self._validate_dependencies()
|
|
179
184
|
self._last_stratification_stats: dict[str, Any] | None = None
|
|
180
185
|
self._last_batch_size_used: int = 0
|
|
@@ -186,6 +191,20 @@ class WikiText2Provider:
|
|
|
186
191
|
normalized_hint = (device_hint or "").strip().lower()
|
|
187
192
|
self._device_hint: str | None = normalized_hint or None
|
|
188
193
|
|
|
194
|
+
def _event(self, tag: str, message: str, *, emoji: str | None = None) -> None:
|
|
195
|
+
"""Emit a dataset event via an optional CLI-provided sink."""
|
|
196
|
+
if self._emit_event is None:
|
|
197
|
+
if emoji:
|
|
198
|
+
print(f"{emoji} {message}")
|
|
199
|
+
else:
|
|
200
|
+
print(message)
|
|
201
|
+
return
|
|
202
|
+
try:
|
|
203
|
+
self._emit_event(tag, message, emoji)
|
|
204
|
+
except TypeError:
|
|
205
|
+
# Back-compat: tolerate sinks that only accept (tag, message).
|
|
206
|
+
self._emit_event(tag, message) # type: ignore[misc]
|
|
207
|
+
|
|
189
208
|
def _validate_dependencies(self) -> None:
|
|
190
209
|
"""Check that required dependencies are available."""
|
|
191
210
|
if not HAS_DATASETS:
|
|
@@ -319,7 +338,11 @@ class WikiText2Provider:
|
|
|
319
338
|
Returns:
|
|
320
339
|
List of filtered text strings
|
|
321
340
|
"""
|
|
322
|
-
|
|
341
|
+
self._event(
|
|
342
|
+
"DATA",
|
|
343
|
+
f"WikiText-2 {split}: loading split...",
|
|
344
|
+
emoji="📚",
|
|
345
|
+
)
|
|
323
346
|
|
|
324
347
|
# Serve from cache when possible (load the largest slice once)
|
|
325
348
|
cached = self._texts_cache.get(split)
|
|
@@ -366,7 +389,10 @@ class WikiText2Provider:
|
|
|
366
389
|
if prev is None or len(valid_texts) > len(prev):
|
|
367
390
|
self._texts_cache[split] = list(valid_texts)
|
|
368
391
|
|
|
369
|
-
|
|
392
|
+
self._event(
|
|
393
|
+
"DATA",
|
|
394
|
+
f"Loaded {len(valid_texts)}/{len(dataset)} valid samples",
|
|
395
|
+
)
|
|
370
396
|
return valid_texts
|
|
371
397
|
|
|
372
398
|
def windows(
|
|
@@ -435,9 +461,13 @@ class WikiText2Provider:
|
|
|
435
461
|
cursor = 0
|
|
436
462
|
chunk_size = max(64, min(256, target_pool))
|
|
437
463
|
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
464
|
+
self._event(
|
|
465
|
+
"DATA",
|
|
466
|
+
"Creating evaluation windows:",
|
|
467
|
+
emoji="📊",
|
|
468
|
+
)
|
|
469
|
+
self._event("DATA", f"Requested preview/final: {preview_n}/{final_n}")
|
|
470
|
+
self._event("DATA", f"Sampling pool target: {target_pool} (reserve {reserve})")
|
|
441
471
|
|
|
442
472
|
while len(candidates) < total_required + reserve and cursor < len(
|
|
443
473
|
shuffled_indices
|
|
@@ -708,9 +738,9 @@ class WikiText2Provider:
|
|
|
708
738
|
),
|
|
709
739
|
}
|
|
710
740
|
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
741
|
+
self._event("DATA", f"Seed: {seed}, Seq length: {seq_len}")
|
|
742
|
+
self._event("DATA", f"Preview: {len(preview_window)} samples")
|
|
743
|
+
self._event("DATA", f"Final: {len(final_window)} samples")
|
|
714
744
|
|
|
715
745
|
return preview_window, final_window
|
|
716
746
|
|
|
@@ -840,8 +870,9 @@ class WikiText2Provider:
|
|
|
840
870
|
attention_masks_list = [entry[2] for entry in collected]
|
|
841
871
|
valid_indices = [entry[0] for entry in collected]
|
|
842
872
|
|
|
843
|
-
|
|
844
|
-
|
|
873
|
+
self._event(
|
|
874
|
+
"DATA",
|
|
875
|
+
f"{window_name}: {len(valid_indices)}/{len(indices)} samples tokenized",
|
|
845
876
|
)
|
|
846
877
|
|
|
847
878
|
return EvaluationWindow(
|
|
@@ -934,7 +965,8 @@ class SyntheticProvider:
|
|
|
934
965
|
self, split: str = "validation", max_samples: int = 500, **kwargs
|
|
935
966
|
) -> list[str]:
|
|
936
967
|
"""Generate synthetic text samples."""
|
|
937
|
-
# Expand base samples to meet requirement
|
|
968
|
+
# Expand base samples to meet requirement, preferring unique variations
|
|
969
|
+
# to avoid duplicate-token windows (important for stratified pairing).
|
|
938
970
|
expanded_samples: list[str] = []
|
|
939
971
|
variations = [
|
|
940
972
|
lambda s: s,
|
|
@@ -944,18 +976,25 @@ class SyntheticProvider:
|
|
|
944
976
|
lambda s: f"Furthermore, {s.lower()}",
|
|
945
977
|
lambda s: f"In addition, {s.lower()}",
|
|
946
978
|
]
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
rng = np.random.RandomState(42) # Fixed seed for reproducibility
|
|
950
|
-
|
|
951
|
-
while len(expanded_samples) < max_samples:
|
|
979
|
+
# Deterministic coverage of (variation × base sample) combinations first.
|
|
980
|
+
for variation in variations:
|
|
952
981
|
for base_text in self.base_samples:
|
|
953
|
-
if len(expanded_samples) >= max_samples:
|
|
954
|
-
break
|
|
955
|
-
variation = rng.choice(variations)
|
|
956
982
|
expanded_samples.append(variation(base_text))
|
|
983
|
+
if len(expanded_samples) >= max_samples:
|
|
984
|
+
return expanded_samples
|
|
985
|
+
|
|
986
|
+
# If callers request more than the unique combination space, keep
|
|
987
|
+
# extending deterministically while ensuring uniqueness via a suffix.
|
|
988
|
+
idx = 0
|
|
989
|
+
while len(expanded_samples) < max_samples:
|
|
990
|
+
base_text = self.base_samples[idx % len(self.base_samples)]
|
|
991
|
+
variation = variations[(idx // len(self.base_samples)) % len(variations)]
|
|
992
|
+
expanded_samples.append(
|
|
993
|
+
f"{variation(base_text)} [synthetic #{len(expanded_samples)}]"
|
|
994
|
+
)
|
|
995
|
+
idx += 1
|
|
957
996
|
|
|
958
|
-
return expanded_samples
|
|
997
|
+
return expanded_samples
|
|
959
998
|
|
|
960
999
|
def windows(
|
|
961
1000
|
self,
|
|
@@ -1801,12 +1840,15 @@ _PROVIDERS: dict[str, type] = {
|
|
|
1801
1840
|
}
|
|
1802
1841
|
|
|
1803
1842
|
|
|
1804
|
-
def get_provider(name: str, **kwargs) -> DatasetProvider:
|
|
1843
|
+
def get_provider(
|
|
1844
|
+
name: str, *, emit: EventEmitter | None = None, **kwargs: Any
|
|
1845
|
+
) -> DatasetProvider:
|
|
1805
1846
|
"""
|
|
1806
1847
|
Get a dataset provider by name.
|
|
1807
1848
|
|
|
1808
1849
|
Args:
|
|
1809
1850
|
name: Provider name ("wikitext2", "synthetic")
|
|
1851
|
+
emit: Optional event sink for dataset/provider logs.
|
|
1810
1852
|
**kwargs: Provider-specific initialization parameters
|
|
1811
1853
|
|
|
1812
1854
|
Returns:
|
|
@@ -1825,7 +1867,10 @@ def get_provider(name: str, **kwargs) -> DatasetProvider:
|
|
|
1825
1867
|
)
|
|
1826
1868
|
|
|
1827
1869
|
provider_class = _PROVIDERS[name]
|
|
1828
|
-
|
|
1870
|
+
init_kwargs = dict(kwargs)
|
|
1871
|
+
if emit is not None and name == "wikitext2":
|
|
1872
|
+
init_kwargs["emit"] = emit
|
|
1873
|
+
return provider_class(**init_kwargs)
|
|
1829
1874
|
|
|
1830
1875
|
|
|
1831
1876
|
def list_providers() -> list[str]:
|
invarlock/eval/metrics.py
CHANGED
|
@@ -18,9 +18,10 @@ import gc
|
|
|
18
18
|
import logging
|
|
19
19
|
import math
|
|
20
20
|
import time
|
|
21
|
+
from collections.abc import Iterable
|
|
21
22
|
from dataclasses import dataclass
|
|
22
23
|
from pathlib import Path
|
|
23
|
-
from typing import Any
|
|
24
|
+
from typing import Any, Protocol
|
|
24
25
|
|
|
25
26
|
import numpy as np
|
|
26
27
|
import psutil
|
|
@@ -2269,6 +2270,57 @@ def analyze_rmt_changes(
|
|
|
2269
2270
|
return {"error": str(e)}
|
|
2270
2271
|
|
|
2271
2272
|
|
|
2273
|
+
class Metric(Protocol):
    """Structural interface for an evaluation metric.

    An implementation provides a ``name``, a ``kind`` tag, and a ``compute``
    method that reduces a dataset of record dicts to a single float.
    """

    # Identifier of the metric.
    name: str
    # Metric family tag: "ppl", "accuracy", "exact_match", "bleu", "rouge"
    kind: str

    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float:
        """Return the metric value for ``dataset``, given ``model``."""
|
|
2278
|
+
|
|
2279
|
+
|
|
2280
|
+
class PerplexityMetric:
    """Lightweight perplexity metric from per-record logloss + token counts.

    Each record contributes ``loss * tokens`` to a running total. The result
    is ``exp(total_loss / total_tokens)``, the perplexity of the
    token-weighted mean log-loss.
    """

    name = "perplexity"
    kind = "ppl"

    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float:  # noqa: ARG002
        """Return the token-weighted perplexity over ``dataset``.

        Args:
            model: Unused; accepted so the signature matches other metrics.
            dataset: Iterable of records. The loss is read from ``"logloss"``
                (falling back to ``"loss"``) and the token count from
                ``"token_count"`` (falling back to ``"tokens"``, then ``1``).

        Returns:
            The perplexity as a float. ``nan`` when no record contributes a
            usable loss/token pair, and ``inf`` when the mean log-loss is too
            large for ``math.exp`` to represent.

        Records that are not dicts, whose values cannot be converted to
        float, that are non-finite, or whose token count is not positive are
        skipped.
        """
        total_loss = 0.0
        total_tokens = 0.0
        for record in dataset:
            if not isinstance(record, dict):
                continue
            loss = record.get("logloss", record.get("loss"))
            tokens = record.get("token_count", record.get("tokens", 1))
            try:
                loss_val = float(loss)
                tok_val = float(tokens)
            except Exception:
                # Best-effort by design: values may be arbitrary objects whose
                # __float__ can raise anything, so unusable records are
                # skipped rather than aborting the whole computation.
                continue
            if (
                not math.isfinite(loss_val)
                or not math.isfinite(tok_val)
                or tok_val <= 0
            ):
                continue
            total_loss += loss_val * tok_val
            total_tokens += tok_val
        if total_tokens <= 0:
            return float("nan")
        try:
            return float(math.exp(total_loss / total_tokens))
        except OverflowError:
            # math.exp raises instead of returning inf once the exponent
            # exceeds the float range (about 709.78); report an unbounded
            # perplexity rather than crashing the evaluation.
            return float("inf")
|
|
2310
|
+
|
|
2311
|
+
|
|
2312
|
+
class AccuracyMetric:
    """Classification accuracy metric from label/prediction records."""

    name = "accuracy"
    kind = "accuracy"

    def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float:  # noqa: ARG002
        """Return the accuracy over ``dataset``; ``model`` is unused.

        The computation is delegated to ``accuracy_from_records``, imported
        inside the method rather than at module import time.
        """
        from invarlock.eval.tasks.classification import (
            accuracy_from_records as _accuracy,
        )

        return _accuracy(dataset)
|
|
2322
|
+
|
|
2323
|
+
|
|
2272
2324
|
# ── Integration with existing system ───────────────────────────────────────
|
|
2273
2325
|
|
|
2274
2326
|
# Update exports to include new functions (add to existing __all__ if it exists)
|
|
@@ -2282,6 +2334,9 @@ try:
|
|
|
2282
2334
|
"compute_parameter_deltas",
|
|
2283
2335
|
"analyze_spectral_changes",
|
|
2284
2336
|
"analyze_rmt_changes",
|
|
2337
|
+
"Metric",
|
|
2338
|
+
"PerplexityMetric",
|
|
2339
|
+
"AccuracyMetric",
|
|
2285
2340
|
]
|
|
2286
2341
|
)
|
|
2287
2342
|
except NameError:
|
|
@@ -2294,4 +2349,7 @@ except NameError:
|
|
|
2294
2349
|
"compute_parameter_deltas",
|
|
2295
2350
|
"analyze_spectral_changes",
|
|
2296
2351
|
"analyze_rmt_changes",
|
|
2352
|
+
"Metric",
|
|
2353
|
+
"PerplexityMetric",
|
|
2354
|
+
"AccuracyMetric",
|
|
2297
2355
|
]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .classification import accuracy_from_records
|
|
4
|
+
from .qa import exact_match_from_records
|
|
5
|
+
from .text_generation import bleu1_from_records, rouge_l_from_records
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"accuracy_from_records",
|
|
9
|
+
"exact_match_from_records",
|
|
10
|
+
"bleu1_from_records",
|
|
11
|
+
"rouge_l_from_records",
|
|
12
|
+
]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _iter_pairs(record: dict[str, Any]) -> list[tuple[Any, Any]]:
|
|
8
|
+
if "correct" in record:
|
|
9
|
+
return [(bool(record.get("correct")), True)]
|
|
10
|
+
|
|
11
|
+
label = record.get("label")
|
|
12
|
+
pred = record.get("prediction")
|
|
13
|
+
if label is None:
|
|
14
|
+
label = record.get("labels")
|
|
15
|
+
if pred is None:
|
|
16
|
+
pred = record.get("pred")
|
|
17
|
+
if pred is None:
|
|
18
|
+
pred = record.get("predictions")
|
|
19
|
+
|
|
20
|
+
if isinstance(label, list) and isinstance(pred, list):
|
|
21
|
+
return list(zip(label, pred, strict=False))
|
|
22
|
+
if label is None or pred is None:
|
|
23
|
+
return []
|
|
24
|
+
return [(label, pred)]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def accuracy_from_records(records: Iterable[dict[str, Any]]) -> float:
|
|
28
|
+
"""Compute accuracy from records with labels/predictions.
|
|
29
|
+
|
|
30
|
+
Accepted record shapes:
|
|
31
|
+
- {"label": <label>, "prediction": <label>}
|
|
32
|
+
- {"labels": [...], "predictions": [...]}
|
|
33
|
+
- {"correct": <bool>}
|
|
34
|
+
"""
|
|
35
|
+
total = 0
|
|
36
|
+
correct = 0
|
|
37
|
+
for record in records:
|
|
38
|
+
if not isinstance(record, dict):
|
|
39
|
+
continue
|
|
40
|
+
for label, pred in _iter_pairs(record):
|
|
41
|
+
total += 1
|
|
42
|
+
if isinstance(label, bool):
|
|
43
|
+
correct += int(label is pred)
|
|
44
|
+
else:
|
|
45
|
+
correct += int(label == pred)
|
|
46
|
+
if total == 0:
|
|
47
|
+
return float("nan")
|
|
48
|
+
return float(correct / total)
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _normalize(text: str) -> str:
    """Lower-case ``text``, trim it, and collapse whitespace runs to one space."""
    return " ".join(str(text).strip().lower().split())


def exact_match_from_records(records: Iterable[dict[str, Any]]) -> float:
    """Compute exact-match accuracy for QA-style records.

    Accepted record shapes:
    - {"prediction": "...", "answer": "..."}
    - {"prediction": "...", "answers": ["...", ...]}

    ``answers`` may be a list, tuple, set or frozenset of acceptable answers;
    any other value is treated as a single answer. Comparison is done on
    normalized text (case-insensitive, whitespace-collapsed), and ``None``
    entries among the answers are ignored.

    Records that are not dicts, or that lack a prediction or any answers
    field, are skipped.

    Returns:
        The fraction of scored records whose prediction matches at least one
        answer, or ``nan`` when no record was scored.
    """
    total = 0
    correct = 0
    for record in records:
        if not isinstance(record, dict):
            continue
        pred = record.get("prediction")
        answers = record.get("answers")
        if answers is None and "answer" in record:
            answers = [record.get("answer")]
        if pred is None or answers is None:
            continue
        pred_norm = _normalize(pred)
        # Treat any plain collection of answers as alternatives. Wrapping a
        # tuple or set as a single answer would stringify the whole
        # container, which can never equal a prediction.
        if isinstance(answers, (list, tuple, set, frozenset)):
            candidates = answers
        else:
            candidates = [answers]
        total += 1
        if any(_normalize(a) == pred_norm for a in candidates if a is not None):
            correct += 1
    if total == 0:
        return float("nan")
    return float(correct / total)
|