invarlock 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invarlock/__init__.py +2 -2
- invarlock/_data/runtime/tiers.yaml +57 -30
- invarlock/adapters/__init__.py +11 -15
- invarlock/adapters/auto.py +35 -40
- invarlock/adapters/capabilities.py +2 -2
- invarlock/adapters/hf_causal.py +418 -0
- invarlock/adapters/{hf_onnx.py → hf_causal_onnx.py} +3 -3
- invarlock/adapters/hf_mixin.py +25 -4
- invarlock/adapters/{hf_bert.py → hf_mlm.py} +4 -11
- invarlock/adapters/{hf_t5.py → hf_seq2seq.py} +9 -9
- invarlock/calibration/spectral_null.py +15 -10
- invarlock/calibration/variance_ve.py +0 -2
- invarlock/cli/adapter_auto.py +31 -21
- invarlock/cli/app.py +73 -2
- invarlock/cli/commands/calibrate.py +6 -2
- invarlock/cli/commands/certify.py +651 -91
- invarlock/cli/commands/doctor.py +11 -11
- invarlock/cli/commands/explain_gates.py +57 -8
- invarlock/cli/commands/plugins.py +13 -9
- invarlock/cli/commands/report.py +233 -69
- invarlock/cli/commands/run.py +1066 -244
- invarlock/cli/commands/verify.py +154 -15
- invarlock/cli/config.py +22 -6
- invarlock/cli/doctor_helpers.py +4 -5
- invarlock/cli/output.py +193 -0
- invarlock/cli/provenance.py +1 -1
- invarlock/core/api.py +45 -5
- invarlock/core/auto_tuning.py +65 -20
- invarlock/core/bootstrap.py +1 -1
- invarlock/core/contracts.py +7 -1
- invarlock/core/registry.py +11 -13
- invarlock/core/runner.py +425 -75
- invarlock/edits/quant_rtn.py +65 -37
- invarlock/eval/bench.py +3 -16
- invarlock/eval/data.py +82 -51
- invarlock/eval/metrics.py +63 -2
- invarlock/eval/primary_metric.py +23 -0
- invarlock/eval/tail_stats.py +230 -0
- invarlock/eval/tasks/__init__.py +12 -0
- invarlock/eval/tasks/classification.py +48 -0
- invarlock/eval/tasks/qa.py +36 -0
- invarlock/eval/tasks/text_generation.py +102 -0
- invarlock/guards/_estimators.py +154 -0
- invarlock/guards/invariants.py +19 -10
- invarlock/guards/policies.py +16 -6
- invarlock/guards/rmt.py +627 -546
- invarlock/guards/spectral.py +348 -110
- invarlock/guards/tier_config.py +32 -30
- invarlock/guards/variance.py +7 -31
- invarlock/guards_ref/rmt_ref.py +23 -23
- invarlock/model_profile.py +90 -42
- invarlock/observability/health.py +6 -6
- invarlock/observability/metrics.py +108 -0
- invarlock/reporting/certificate.py +384 -55
- invarlock/reporting/certificate_schema.py +3 -2
- invarlock/reporting/dataset_hashing.py +15 -2
- invarlock/reporting/guards_analysis.py +350 -277
- invarlock/reporting/html.py +55 -5
- invarlock/reporting/normalizer.py +13 -0
- invarlock/reporting/policy_utils.py +38 -36
- invarlock/reporting/primary_metric_utils.py +71 -17
- invarlock/reporting/render.py +852 -431
- invarlock/reporting/report.py +40 -4
- invarlock/reporting/report_types.py +11 -3
- invarlock/reporting/telemetry.py +86 -0
- invarlock/reporting/validate.py +1 -18
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/METADATA +27 -13
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/RECORD +72 -65
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/WHEEL +1 -1
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/entry_points.txt +5 -3
- invarlock/adapters/hf_gpt2.py +0 -404
- invarlock/adapters/hf_llama.py +0 -487
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/licenses/LICENSE +0 -0
- {invarlock-0.3.5.dist-info → invarlock-0.3.7.dist-info}/top_level.txt +0 -0
invarlock/edits/quant_rtn.py
CHANGED
|
@@ -86,6 +86,31 @@ class RTNQuantEdit(ModelEdit):
|
|
|
86
86
|
|
|
87
87
|
# group_size is currently reserved for potential future variants; it is
|
|
88
88
|
# ignored for the built-in INT8 demo edit.
|
|
89
|
+
self._emit_enabled = True
|
|
90
|
+
self._emit_console = None
|
|
91
|
+
self._output_style = None
|
|
92
|
+
|
|
93
|
+
def _configure_output(self, **kwargs: Any) -> None:
|
|
94
|
+
emit = kwargs.get("emit", True)
|
|
95
|
+
self._emit_enabled = bool(emit)
|
|
96
|
+
console = kwargs.get("console")
|
|
97
|
+
if console is not None and hasattr(console, "print"):
|
|
98
|
+
self._emit_console = console
|
|
99
|
+
else:
|
|
100
|
+
self._emit_console = None
|
|
101
|
+
self._output_style = kwargs.get("output_style")
|
|
102
|
+
|
|
103
|
+
def _emit(self, message: str) -> None:
|
|
104
|
+
if not self._emit_enabled:
|
|
105
|
+
return
|
|
106
|
+
line = f"[EDIT] {message}".rstrip()
|
|
107
|
+
if self._emit_console is not None:
|
|
108
|
+
try:
|
|
109
|
+
self._emit_console.print(line, markup=False)
|
|
110
|
+
except TypeError:
|
|
111
|
+
self._emit_console.print(line)
|
|
112
|
+
else:
|
|
113
|
+
print(line)
|
|
89
114
|
|
|
90
115
|
def can_edit(self, model_desc: dict[str, Any]) -> bool:
|
|
91
116
|
"""Check if RTN quantization can be applied to this model."""
|
|
@@ -233,15 +258,18 @@ class RTNQuantEdit(ModelEdit):
|
|
|
233
258
|
scope = kwargs.get("scope", self.scope)
|
|
234
259
|
seed = kwargs.get("seed", self.seed)
|
|
235
260
|
|
|
261
|
+
self._configure_output(**kwargs)
|
|
262
|
+
|
|
236
263
|
# Diagnostic reporting
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
264
|
+
self._emit("RTN Quantization Configuration:")
|
|
265
|
+
self._emit(
|
|
266
|
+
"Bitwidth: "
|
|
267
|
+
f"{bitwidth} (from config: {kwargs.get('bitwidth', kwargs.get('bits', 'default'))})"
|
|
240
268
|
)
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
269
|
+
self._emit(f"Scope: {scope}")
|
|
270
|
+
self._emit(f"Group size: {group_size}")
|
|
271
|
+
self._emit(f"Clamp ratio: {clamp_ratio}")
|
|
272
|
+
self._emit(f"Seed: {seed}")
|
|
245
273
|
|
|
246
274
|
# Persist configuration overrides for downstream helpers
|
|
247
275
|
self.bitwidth = bitwidth
|
|
@@ -256,22 +284,22 @@ class RTNQuantEdit(ModelEdit):
|
|
|
256
284
|
np.random.seed(seed)
|
|
257
285
|
|
|
258
286
|
# Identify target modules and get weight tying map
|
|
259
|
-
|
|
287
|
+
self._emit(f"Identifying target modules for scope '{scope}'...")
|
|
260
288
|
target_modules = self._identify_target_modules(model)
|
|
261
289
|
total_identified = len(target_modules)
|
|
262
290
|
|
|
263
291
|
max_modules = kwargs.get("max_modules")
|
|
264
292
|
if isinstance(max_modules, int) and max_modules > 0:
|
|
265
293
|
if max_modules < total_identified:
|
|
266
|
-
|
|
267
|
-
f"
|
|
294
|
+
self._emit(
|
|
295
|
+
f"Limiting quantization to first {max_modules} modules "
|
|
268
296
|
f"(of {total_identified}) based on plan.max_modules"
|
|
269
297
|
)
|
|
270
298
|
target_modules = target_modules[:max_modules]
|
|
271
299
|
self.max_modules = max_modules
|
|
272
300
|
else:
|
|
273
|
-
|
|
274
|
-
f"
|
|
301
|
+
self._emit(
|
|
302
|
+
f"max_modules={max_modules} >= available modules "
|
|
275
303
|
f"({total_identified}); using all targets"
|
|
276
304
|
)
|
|
277
305
|
self.max_modules = None
|
|
@@ -280,33 +308,35 @@ class RTNQuantEdit(ModelEdit):
|
|
|
280
308
|
|
|
281
309
|
tying_map = self._get_weight_tying_map(model)
|
|
282
310
|
|
|
283
|
-
|
|
311
|
+
self._emit(f"Found {len(target_modules)} target modules:")
|
|
284
312
|
for i, (name, module) in enumerate(target_modules):
|
|
285
313
|
weight_shape = module.weight.shape
|
|
286
314
|
param_count = module.weight.numel()
|
|
287
|
-
|
|
315
|
+
self._emit(f"[{i + 1}] {name}: {weight_shape} ({param_count:,} params)")
|
|
288
316
|
|
|
289
317
|
if len(target_modules) == 0:
|
|
290
|
-
|
|
291
|
-
|
|
318
|
+
self._emit(
|
|
319
|
+
"WARNING: No target modules found! Check scope configuration."
|
|
320
|
+
)
|
|
321
|
+
self._emit("Available linear modules:")
|
|
292
322
|
linear_modules = []
|
|
293
323
|
for name, module in model.named_modules():
|
|
294
324
|
if isinstance(module, nn.Linear | nn.Conv1d):
|
|
295
325
|
linear_modules.append((name, module.weight.shape))
|
|
296
326
|
for name, shape in linear_modules[:10]: # Show first 10
|
|
297
|
-
|
|
327
|
+
self._emit(f"{name}: {shape}")
|
|
298
328
|
if len(linear_modules) > 10:
|
|
299
|
-
|
|
329
|
+
self._emit(f"... and {len(linear_modules) - 10} more")
|
|
300
330
|
|
|
301
331
|
# Execute GuardChain before edit (if provided)
|
|
302
332
|
guard_results = {}
|
|
303
333
|
if self.guard_chain is not None:
|
|
304
|
-
|
|
334
|
+
self._emit("Executing guard chain preparation...")
|
|
305
335
|
guard_results["prepare"] = self.guard_chain.prepare_all(
|
|
306
336
|
model, adapter, None, {}
|
|
307
337
|
)
|
|
308
338
|
|
|
309
|
-
|
|
339
|
+
self._emit("Executing before-edit guards...")
|
|
310
340
|
self.guard_chain.before_edit_all(model)
|
|
311
341
|
|
|
312
342
|
# Apply quantization to each target module
|
|
@@ -314,12 +344,12 @@ class RTNQuantEdit(ModelEdit):
|
|
|
314
344
|
total_params_quantized = 0
|
|
315
345
|
|
|
316
346
|
for i, (module_name, module) in enumerate(target_modules):
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
f"
|
|
347
|
+
self._emit(f"[{i + 1}/{len(target_modules)}] Quantizing: {module_name}")
|
|
348
|
+
self._emit(
|
|
349
|
+
f"Shape: {module.weight.shape}, Params: {module.weight.numel():,}"
|
|
320
350
|
)
|
|
321
|
-
|
|
322
|
-
f"
|
|
351
|
+
self._emit(
|
|
352
|
+
f"Weight range: [{module.weight.min():.4f}, {module.weight.max():.4f}]"
|
|
323
353
|
)
|
|
324
354
|
|
|
325
355
|
# Apply RTN quantization
|
|
@@ -335,24 +365,22 @@ class RTNQuantEdit(ModelEdit):
|
|
|
335
365
|
quantization_results.append(quant_result)
|
|
336
366
|
total_params_quantized += quant_result["params_quantized"]
|
|
337
367
|
|
|
338
|
-
|
|
339
|
-
f" ✓ Quantized {quant_result['params_quantized']:,} parameters"
|
|
340
|
-
)
|
|
368
|
+
self._emit(f"Quantized {quant_result['params_quantized']:,} parameters")
|
|
341
369
|
|
|
342
370
|
# Execute GuardChain after edit (if provided)
|
|
343
371
|
if self.guard_chain is not None:
|
|
344
|
-
|
|
372
|
+
self._emit("Executing after-edit guards...")
|
|
345
373
|
self.guard_chain.after_edit_all(model)
|
|
346
374
|
|
|
347
|
-
|
|
375
|
+
self._emit("Finalizing guard chain...")
|
|
348
376
|
guard_results["finalize"] = self.guard_chain.finalize_all(model)
|
|
349
377
|
|
|
350
378
|
# Check if all guards passed
|
|
351
379
|
if not self.guard_chain.all_passed(guard_results["finalize"]):
|
|
352
|
-
|
|
380
|
+
self._emit("Guard chain validation failed!")
|
|
353
381
|
guard_results["all_passed"] = False
|
|
354
382
|
else:
|
|
355
|
-
|
|
383
|
+
self._emit("All guards passed")
|
|
356
384
|
guard_results["all_passed"] = True
|
|
357
385
|
|
|
358
386
|
# Create bitwidth map
|
|
@@ -490,11 +518,11 @@ class RTNQuantEdit(ModelEdit):
|
|
|
490
518
|
|
|
491
519
|
# Log diagnostic information
|
|
492
520
|
if skipped_modules:
|
|
493
|
-
|
|
521
|
+
self._emit(f"Skipped {len(skipped_modules)} modules:")
|
|
494
522
|
for name, reason in skipped_modules[:5]: # Show first 5
|
|
495
|
-
|
|
523
|
+
self._emit(f"{name}: {reason}")
|
|
496
524
|
if len(skipped_modules) > 5:
|
|
497
|
-
|
|
525
|
+
self._emit(f"... and {len(skipped_modules) - 5} more")
|
|
498
526
|
|
|
499
527
|
return target_modules
|
|
500
528
|
|
|
@@ -625,7 +653,7 @@ class RTNQuantEdit(ModelEdit):
|
|
|
625
653
|
# Ensure actual quantization occurred by applying quantization loss
|
|
626
654
|
# This guarantees the weights are actually modified
|
|
627
655
|
quantization_error = (quantized_weight - original_weight).abs().mean()
|
|
628
|
-
|
|
656
|
+
self._emit(f"Quantization error: {quantization_error:.6f}")
|
|
629
657
|
|
|
630
658
|
# Write back to module (preserving tying if needed)
|
|
631
659
|
module.weight.data.copy_(quantized_weight)
|
|
@@ -634,7 +662,7 @@ class RTNQuantEdit(ModelEdit):
|
|
|
634
662
|
final_weight = module.weight.data
|
|
635
663
|
actual_change = not torch.allclose(original_weight, final_weight, atol=1e-6)
|
|
636
664
|
if not actual_change:
|
|
637
|
-
|
|
665
|
+
self._emit(f"WARNING: No actual weight change detected for {module}")
|
|
638
666
|
|
|
639
667
|
# Handle tied weights
|
|
640
668
|
if tied_modules:
|
invarlock/eval/bench.py
CHANGED
|
@@ -47,7 +47,7 @@ class ScenarioConfig:
|
|
|
47
47
|
probes: int
|
|
48
48
|
profile: str = "ci" # "ci" or "release"
|
|
49
49
|
model_id: str = "gpt2"
|
|
50
|
-
adapter: str = "
|
|
50
|
+
adapter: str = "hf_causal"
|
|
51
51
|
device: str = "auto"
|
|
52
52
|
seq_len: int = 512
|
|
53
53
|
stride: int = 128
|
|
@@ -81,7 +81,7 @@ class BenchmarkConfig:
|
|
|
81
81
|
profile: str = "ci" # "ci" or "release"
|
|
82
82
|
dataset: str = "wikitext2"
|
|
83
83
|
model_id: str = "gpt2"
|
|
84
|
-
adapter: str = "
|
|
84
|
+
adapter: str = "hf_causal"
|
|
85
85
|
device: str = "auto"
|
|
86
86
|
seq_len: int = 512
|
|
87
87
|
stride: int = 128
|
|
@@ -92,7 +92,6 @@ class BenchmarkConfig:
|
|
|
92
92
|
epsilon: float | None = (
|
|
93
93
|
None # RMT deadband tolerance (None = use resolved deadband)
|
|
94
94
|
)
|
|
95
|
-
strict: bool = False # If True, sets epsilon = 0
|
|
96
95
|
ppl_overhead_threshold: float = 0.01 # 1%
|
|
97
96
|
guard_overhead_time_threshold: float = 0.15 # 15%
|
|
98
97
|
guard_overhead_mem_threshold: float = 0.10 # 10%
|
|
@@ -104,10 +103,6 @@ class BenchmarkConfig:
|
|
|
104
103
|
"""Apply post-initialization logic."""
|
|
105
104
|
self.output_dir = Path(self.output_dir)
|
|
106
105
|
|
|
107
|
-
# Handle strict mode
|
|
108
|
-
if self.strict:
|
|
109
|
-
self.epsilon = 0.0
|
|
110
|
-
|
|
111
106
|
|
|
112
107
|
@dataclass
|
|
113
108
|
class ScenarioResult:
|
|
@@ -1043,7 +1038,6 @@ def run_guard_effect_benchmark(
|
|
|
1043
1038
|
profile: str = "ci",
|
|
1044
1039
|
output_dir: str | Path = "benchmarks",
|
|
1045
1040
|
epsilon: float | None = None,
|
|
1046
|
-
strict: bool = False,
|
|
1047
1041
|
**kwargs,
|
|
1048
1042
|
) -> dict[str, Any]:
|
|
1049
1043
|
"""
|
|
@@ -1056,7 +1050,6 @@ def run_guard_effect_benchmark(
|
|
|
1056
1050
|
profile: "ci" (50/50 windows) or "release" (100/100 windows)
|
|
1057
1051
|
output_dir: Directory to save results
|
|
1058
1052
|
epsilon: Optional epsilon override
|
|
1059
|
-
strict: If True, sets epsilon = 0
|
|
1060
1053
|
**kwargs: Additional configuration options
|
|
1061
1054
|
|
|
1062
1055
|
Returns:
|
|
@@ -1075,7 +1068,6 @@ def run_guard_effect_benchmark(
|
|
|
1075
1068
|
profile=profile,
|
|
1076
1069
|
output_dir=Path(output_dir),
|
|
1077
1070
|
epsilon=epsilon,
|
|
1078
|
-
strict=strict,
|
|
1079
1071
|
**kwargs,
|
|
1080
1072
|
)
|
|
1081
1073
|
|
|
@@ -1384,7 +1376,6 @@ def _config_to_dict(config: BenchmarkConfig) -> dict[str, Any]:
|
|
|
1384
1376
|
"stride": config.stride,
|
|
1385
1377
|
"seed": config.seed,
|
|
1386
1378
|
"epsilon": config.epsilon,
|
|
1387
|
-
"strict": config.strict,
|
|
1388
1379
|
"ppl_overhead_threshold": config.ppl_overhead_threshold,
|
|
1389
1380
|
"guard_overhead_time_threshold": config.guard_overhead_time_threshold,
|
|
1390
1381
|
"guard_overhead_mem_threshold": config.guard_overhead_mem_threshold,
|
|
@@ -1426,16 +1417,13 @@ def main():
|
|
|
1426
1417
|
type=float,
|
|
1427
1418
|
help="RMT outliers epsilon threshold (default: use resolved RMT deadband)",
|
|
1428
1419
|
)
|
|
1429
|
-
parser.add_argument(
|
|
1430
|
-
"--strict", action="store_true", help="Set epsilon=0 (overrides --epsilon)"
|
|
1431
|
-
)
|
|
1432
1420
|
|
|
1433
1421
|
# Model and dataset configuration
|
|
1434
1422
|
parser.add_argument(
|
|
1435
1423
|
"--dataset", default="wikitext2", help="Dataset to use for benchmarking"
|
|
1436
1424
|
)
|
|
1437
1425
|
parser.add_argument("--model-id", default="gpt2", help="Model identifier")
|
|
1438
|
-
parser.add_argument("--adapter", default="
|
|
1426
|
+
parser.add_argument("--adapter", default="hf_causal", help="Model adapter to use")
|
|
1439
1427
|
parser.add_argument(
|
|
1440
1428
|
"--device", default="auto", help="Device to use (auto|cuda|mps|cpu)"
|
|
1441
1429
|
)
|
|
@@ -1505,7 +1493,6 @@ def main():
|
|
|
1505
1493
|
profile=args.profile,
|
|
1506
1494
|
output_dir=args.out,
|
|
1507
1495
|
epsilon=args.epsilon,
|
|
1508
|
-
strict=args.strict,
|
|
1509
1496
|
**kwargs,
|
|
1510
1497
|
)
|
|
1511
1498
|
|
invarlock/eval/data.py
CHANGED
|
@@ -15,7 +15,7 @@ import time
|
|
|
15
15
|
import warnings
|
|
16
16
|
from abc import abstractmethod
|
|
17
17
|
from collections import Counter
|
|
18
|
-
from collections.abc import Sequence
|
|
18
|
+
from collections.abc import Callable, Sequence
|
|
19
19
|
from pathlib import Path
|
|
20
20
|
from typing import Any, NamedTuple, Protocol
|
|
21
21
|
|
|
@@ -56,6 +56,9 @@ except ImportError:
|
|
|
56
56
|
HAS_TORCH = False
|
|
57
57
|
|
|
58
58
|
|
|
59
|
+
EventEmitter = Callable[[str, str, str | None], None]
|
|
60
|
+
|
|
61
|
+
|
|
59
62
|
class EvaluationWindow(NamedTuple):
|
|
60
63
|
"""A window of tokenized samples for evaluation."""
|
|
61
64
|
|
|
@@ -166,6 +169,7 @@ class WikiText2Provider:
|
|
|
166
169
|
self,
|
|
167
170
|
cache_dir: Path | None = None,
|
|
168
171
|
device_hint: str | None = None,
|
|
172
|
+
emit: EventEmitter | None = None,
|
|
169
173
|
**_: Any,
|
|
170
174
|
):
|
|
171
175
|
"""
|
|
@@ -175,6 +179,7 @@ class WikiText2Provider:
|
|
|
175
179
|
cache_dir: Optional cache directory for dataset storage
|
|
176
180
|
"""
|
|
177
181
|
self.cache_dir = cache_dir
|
|
182
|
+
self._emit_event = emit
|
|
178
183
|
self._validate_dependencies()
|
|
179
184
|
self._last_stratification_stats: dict[str, Any] | None = None
|
|
180
185
|
self._last_batch_size_used: int = 0
|
|
@@ -186,11 +191,23 @@ class WikiText2Provider:
|
|
|
186
191
|
normalized_hint = (device_hint or "").strip().lower()
|
|
187
192
|
self._device_hint: str | None = normalized_hint or None
|
|
188
193
|
|
|
194
|
+
def _event(self, tag: str, message: str, *, emoji: str | None = None) -> None:
|
|
195
|
+
"""Emit a dataset event via an optional CLI-provided sink."""
|
|
196
|
+
if self._emit_event is None:
|
|
197
|
+
if emoji:
|
|
198
|
+
print(f"{emoji} {message}")
|
|
199
|
+
else:
|
|
200
|
+
print(message)
|
|
201
|
+
return
|
|
202
|
+
try:
|
|
203
|
+
self._emit_event(tag, message, emoji)
|
|
204
|
+
except TypeError:
|
|
205
|
+
# Back-compat: tolerate sinks that only accept (tag, message).
|
|
206
|
+
self._emit_event(tag, message) # type: ignore[misc]
|
|
207
|
+
|
|
189
208
|
def _validate_dependencies(self) -> None:
|
|
190
209
|
"""Check that required dependencies are available."""
|
|
191
210
|
if not HAS_DATASETS:
|
|
192
|
-
if _LIGHT_IMPORT:
|
|
193
|
-
return
|
|
194
211
|
raise _DepErr(
|
|
195
212
|
code="E301",
|
|
196
213
|
message=(
|
|
@@ -321,20 +338,17 @@ class WikiText2Provider:
|
|
|
321
338
|
Returns:
|
|
322
339
|
List of filtered text strings
|
|
323
340
|
"""
|
|
324
|
-
|
|
341
|
+
self._event(
|
|
342
|
+
"DATA",
|
|
343
|
+
f"WikiText-2 {split}: loading split...",
|
|
344
|
+
emoji="📚",
|
|
345
|
+
)
|
|
325
346
|
|
|
326
347
|
# Serve from cache when possible (load the largest slice once)
|
|
327
348
|
cached = self._texts_cache.get(split)
|
|
328
349
|
if cached is not None and len(cached) >= max_samples:
|
|
329
350
|
return cached[:max_samples]
|
|
330
351
|
|
|
331
|
-
if not HAS_DATASETS and _LIGHT_IMPORT:
|
|
332
|
-
texts = ["hello world", "invarlock synthetic text"] * max(
|
|
333
|
-
1, max_samples // 2
|
|
334
|
-
)
|
|
335
|
-
self._texts_cache[split] = texts
|
|
336
|
-
return texts[:max_samples]
|
|
337
|
-
|
|
338
352
|
# Load dataset with size limit for efficiency
|
|
339
353
|
dataset_slice = f"{split}[:{max_samples}]" if max_samples > 0 else split
|
|
340
354
|
dataset = load_dataset(
|
|
@@ -375,7 +389,10 @@ class WikiText2Provider:
|
|
|
375
389
|
if prev is None or len(valid_texts) > len(prev):
|
|
376
390
|
self._texts_cache[split] = list(valid_texts)
|
|
377
391
|
|
|
378
|
-
|
|
392
|
+
self._event(
|
|
393
|
+
"DATA",
|
|
394
|
+
f"Loaded {len(valid_texts)}/{len(dataset)} valid samples",
|
|
395
|
+
)
|
|
379
396
|
return valid_texts
|
|
380
397
|
|
|
381
398
|
def windows(
|
|
@@ -444,9 +461,13 @@ class WikiText2Provider:
|
|
|
444
461
|
cursor = 0
|
|
445
462
|
chunk_size = max(64, min(256, target_pool))
|
|
446
463
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
464
|
+
self._event(
|
|
465
|
+
"DATA",
|
|
466
|
+
"Creating evaluation windows:",
|
|
467
|
+
emoji="📊",
|
|
468
|
+
)
|
|
469
|
+
self._event("DATA", f"Requested preview/final: {preview_n}/{final_n}")
|
|
470
|
+
self._event("DATA", f"Sampling pool target: {target_pool} (reserve {reserve})")
|
|
450
471
|
|
|
451
472
|
while len(candidates) < total_required + reserve and cursor < len(
|
|
452
473
|
shuffled_indices
|
|
@@ -717,9 +738,9 @@ class WikiText2Provider:
|
|
|
717
738
|
),
|
|
718
739
|
}
|
|
719
740
|
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
741
|
+
self._event("DATA", f"Seed: {seed}, Seq length: {seq_len}")
|
|
742
|
+
self._event("DATA", f"Preview: {len(preview_window)} samples")
|
|
743
|
+
self._event("DATA", f"Final: {len(final_window)} samples")
|
|
723
744
|
|
|
724
745
|
return preview_window, final_window
|
|
725
746
|
|
|
@@ -849,8 +870,9 @@ class WikiText2Provider:
|
|
|
849
870
|
attention_masks_list = [entry[2] for entry in collected]
|
|
850
871
|
valid_indices = [entry[0] for entry in collected]
|
|
851
872
|
|
|
852
|
-
|
|
853
|
-
|
|
873
|
+
self._event(
|
|
874
|
+
"DATA",
|
|
875
|
+
f"{window_name}: {len(valid_indices)}/{len(indices)} samples tokenized",
|
|
854
876
|
)
|
|
855
877
|
|
|
856
878
|
return EvaluationWindow(
|
|
@@ -943,7 +965,8 @@ class SyntheticProvider:
|
|
|
943
965
|
self, split: str = "validation", max_samples: int = 500, **kwargs
|
|
944
966
|
) -> list[str]:
|
|
945
967
|
"""Generate synthetic text samples."""
|
|
946
|
-
# Expand base samples to meet requirement
|
|
968
|
+
# Expand base samples to meet requirement, preferring unique variations
|
|
969
|
+
# to avoid duplicate-token windows (important for stratified pairing).
|
|
947
970
|
expanded_samples: list[str] = []
|
|
948
971
|
variations = [
|
|
949
972
|
lambda s: s,
|
|
@@ -953,18 +976,25 @@ class SyntheticProvider:
|
|
|
953
976
|
lambda s: f"Furthermore, {s.lower()}",
|
|
954
977
|
lambda s: f"In addition, {s.lower()}",
|
|
955
978
|
]
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
rng = np.random.RandomState(42) # Fixed seed for reproducibility
|
|
959
|
-
|
|
960
|
-
while len(expanded_samples) < max_samples:
|
|
979
|
+
# Deterministic coverage of (variation × base sample) combinations first.
|
|
980
|
+
for variation in variations:
|
|
961
981
|
for base_text in self.base_samples:
|
|
962
|
-
if len(expanded_samples) >= max_samples:
|
|
963
|
-
break
|
|
964
|
-
variation = rng.choice(variations)
|
|
965
982
|
expanded_samples.append(variation(base_text))
|
|
983
|
+
if len(expanded_samples) >= max_samples:
|
|
984
|
+
return expanded_samples
|
|
985
|
+
|
|
986
|
+
# If callers request more than the unique combination space, keep
|
|
987
|
+
# extending deterministically while ensuring uniqueness via a suffix.
|
|
988
|
+
idx = 0
|
|
989
|
+
while len(expanded_samples) < max_samples:
|
|
990
|
+
base_text = self.base_samples[idx % len(self.base_samples)]
|
|
991
|
+
variation = variations[(idx // len(self.base_samples)) % len(variations)]
|
|
992
|
+
expanded_samples.append(
|
|
993
|
+
f"{variation(base_text)} [synthetic #{len(expanded_samples)}]"
|
|
994
|
+
)
|
|
995
|
+
idx += 1
|
|
966
996
|
|
|
967
|
-
return expanded_samples
|
|
997
|
+
return expanded_samples
|
|
968
998
|
|
|
969
999
|
def windows(
|
|
970
1000
|
self,
|
|
@@ -1062,14 +1092,13 @@ class HFTextProvider:
|
|
|
1062
1092
|
max_samples: int = 2000,
|
|
1063
1093
|
):
|
|
1064
1094
|
if not HAS_DATASETS:
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
)
|
|
1095
|
+
raise _DepErr(
|
|
1096
|
+
code="E301",
|
|
1097
|
+
message=(
|
|
1098
|
+
"DEPENDENCY-MISSING: datasets library required for hf_text provider"
|
|
1099
|
+
),
|
|
1100
|
+
details={"dependency": "datasets"},
|
|
1101
|
+
)
|
|
1073
1102
|
self.dataset_name = dataset_name or "wikitext"
|
|
1074
1103
|
self.config_name = config_name or None
|
|
1075
1104
|
self.text_field = text_field
|
|
@@ -1077,9 +1106,6 @@ class HFTextProvider:
|
|
|
1077
1106
|
self.max_samples = int(max_samples)
|
|
1078
1107
|
|
|
1079
1108
|
def load(self, split: str = "validation", **kwargs) -> list[str]:
|
|
1080
|
-
if not HAS_DATASETS and _LIGHT_IMPORT:
|
|
1081
|
-
return ["synthetic dataset text"] * int(self.max_samples or 1)
|
|
1082
|
-
|
|
1083
1109
|
ds = load_dataset(
|
|
1084
1110
|
path=self.dataset_name,
|
|
1085
1111
|
name=self.config_name,
|
|
@@ -1204,14 +1230,13 @@ class HFSeq2SeqProvider:
|
|
|
1204
1230
|
max_samples: int = 2000,
|
|
1205
1231
|
) -> None:
|
|
1206
1232
|
if not HAS_DATASETS:
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
)
|
|
1233
|
+
raise _DepErr(
|
|
1234
|
+
code="E301",
|
|
1235
|
+
message=(
|
|
1236
|
+
"DEPENDENCY-MISSING: datasets library required for hf_seq2seq provider"
|
|
1237
|
+
),
|
|
1238
|
+
details={"dependency": "datasets"},
|
|
1239
|
+
)
|
|
1215
1240
|
self.dataset_name = dataset_name
|
|
1216
1241
|
self.config_name = config_name
|
|
1217
1242
|
self.src_field = src_field
|
|
@@ -1815,12 +1840,15 @@ _PROVIDERS: dict[str, type] = {
|
|
|
1815
1840
|
}
|
|
1816
1841
|
|
|
1817
1842
|
|
|
1818
|
-
def get_provider(
|
|
1843
|
+
def get_provider(
|
|
1844
|
+
name: str, *, emit: EventEmitter | None = None, **kwargs: Any
|
|
1845
|
+
) -> DatasetProvider:
|
|
1819
1846
|
"""
|
|
1820
1847
|
Get a dataset provider by name.
|
|
1821
1848
|
|
|
1822
1849
|
Args:
|
|
1823
1850
|
name: Provider name ("wikitext2", "synthetic")
|
|
1851
|
+
emit: Optional event sink for dataset/provider logs.
|
|
1824
1852
|
**kwargs: Provider-specific initialization parameters
|
|
1825
1853
|
|
|
1826
1854
|
Returns:
|
|
@@ -1839,7 +1867,10 @@ def get_provider(name: str, **kwargs) -> DatasetProvider:
|
|
|
1839
1867
|
)
|
|
1840
1868
|
|
|
1841
1869
|
provider_class = _PROVIDERS[name]
|
|
1842
|
-
|
|
1870
|
+
init_kwargs = dict(kwargs)
|
|
1871
|
+
if emit is not None and name == "wikitext2":
|
|
1872
|
+
init_kwargs["emit"] = emit
|
|
1873
|
+
return provider_class(**init_kwargs)
|
|
1843
1874
|
|
|
1844
1875
|
|
|
1845
1876
|
def list_providers() -> list[str]:
|
invarlock/eval/metrics.py
CHANGED
|
@@ -18,9 +18,10 @@ import gc
|
|
|
18
18
|
import logging
|
|
19
19
|
import math
|
|
20
20
|
import time
|
|
21
|
+
from collections.abc import Iterable
|
|
21
22
|
from dataclasses import dataclass
|
|
22
23
|
from pathlib import Path
|
|
23
|
-
from typing import Any
|
|
24
|
+
from typing import Any, Protocol
|
|
24
25
|
|
|
25
26
|
import numpy as np
|
|
26
27
|
import psutil
|
|
@@ -723,7 +724,10 @@ def calculate_lens_metrics_for_model(
|
|
|
723
724
|
except Exception as e:
|
|
724
725
|
logger.error(f"Metrics calculation failed: {e}")
|
|
725
726
|
if config.strict_validation:
|
|
726
|
-
raise MetricsError(
|
|
727
|
+
raise MetricsError(
|
|
728
|
+
code="E401",
|
|
729
|
+
message=f"METRICS-COMPUTE-FAILED: {e}",
|
|
730
|
+
) from e
|
|
727
731
|
|
|
728
732
|
finally:
|
|
729
733
|
resource_manager.cleanup()
|
|
@@ -2266,6 +2270,57 @@ def analyze_rmt_changes(
|
|
|
2266
2270
|
return {"error": str(e)}
|
|
2267
2271
|
|
|
2268
2272
|
|
|
2273
|
+
class Metric(Protocol):
|
|
2274
|
+
name: str
|
|
2275
|
+
kind: str # "ppl", "accuracy", "exact_match", "bleu", "rouge"
|
|
2276
|
+
|
|
2277
|
+
def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float: ...
|
|
2278
|
+
|
|
2279
|
+
|
|
2280
|
+
class PerplexityMetric:
|
|
2281
|
+
"""Lightweight perplexity metric from per-record logloss + token counts."""
|
|
2282
|
+
|
|
2283
|
+
name = "perplexity"
|
|
2284
|
+
kind = "ppl"
|
|
2285
|
+
|
|
2286
|
+
def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float: # noqa: ARG002
|
|
2287
|
+
total_loss = 0.0
|
|
2288
|
+
total_tokens = 0.0
|
|
2289
|
+
for record in dataset:
|
|
2290
|
+
if not isinstance(record, dict):
|
|
2291
|
+
continue
|
|
2292
|
+
loss = record.get("logloss", record.get("loss"))
|
|
2293
|
+
tokens = record.get("token_count", record.get("tokens", 1))
|
|
2294
|
+
try:
|
|
2295
|
+
loss_val = float(loss)
|
|
2296
|
+
tok_val = float(tokens)
|
|
2297
|
+
except Exception:
|
|
2298
|
+
continue
|
|
2299
|
+
if (
|
|
2300
|
+
not math.isfinite(loss_val)
|
|
2301
|
+
or not math.isfinite(tok_val)
|
|
2302
|
+
or tok_val <= 0
|
|
2303
|
+
):
|
|
2304
|
+
continue
|
|
2305
|
+
total_loss += loss_val * tok_val
|
|
2306
|
+
total_tokens += tok_val
|
|
2307
|
+
if total_tokens <= 0:
|
|
2308
|
+
return float("nan")
|
|
2309
|
+
return float(math.exp(total_loss / total_tokens))
|
|
2310
|
+
|
|
2311
|
+
|
|
2312
|
+
class AccuracyMetric:
|
|
2313
|
+
"""Classification accuracy metric from label/prediction records."""
|
|
2314
|
+
|
|
2315
|
+
name = "accuracy"
|
|
2316
|
+
kind = "accuracy"
|
|
2317
|
+
|
|
2318
|
+
def compute(self, model: Any, dataset: Iterable[dict[str, Any]]) -> float: # noqa: ARG002
|
|
2319
|
+
from invarlock.eval.tasks.classification import accuracy_from_records
|
|
2320
|
+
|
|
2321
|
+
return accuracy_from_records(dataset)
|
|
2322
|
+
|
|
2323
|
+
|
|
2269
2324
|
# ── Integration with existing system ───────────────────────────────────────
|
|
2270
2325
|
|
|
2271
2326
|
# Update exports to include new functions (add to existing __all__ if it exists)
|
|
@@ -2279,6 +2334,9 @@ try:
|
|
|
2279
2334
|
"compute_parameter_deltas",
|
|
2280
2335
|
"analyze_spectral_changes",
|
|
2281
2336
|
"analyze_rmt_changes",
|
|
2337
|
+
"Metric",
|
|
2338
|
+
"PerplexityMetric",
|
|
2339
|
+
"AccuracyMetric",
|
|
2282
2340
|
]
|
|
2283
2341
|
)
|
|
2284
2342
|
except NameError:
|
|
@@ -2291,4 +2349,7 @@ except NameError:
|
|
|
2291
2349
|
"compute_parameter_deltas",
|
|
2292
2350
|
"analyze_spectral_changes",
|
|
2293
2351
|
"analyze_rmt_changes",
|
|
2352
|
+
"Metric",
|
|
2353
|
+
"PerplexityMetric",
|
|
2354
|
+
"AccuracyMetric",
|
|
2294
2355
|
]
|