wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.1.dist-info/METADATA +67 -0
- wisent-0.5.1.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/contrastive_pairs/diagnostics/control_vectors.py
@@ -0,0 +1,169 @@
+"""Diagnostics for steering/control vectors."""
+
+from __future__ import annotations
+
+import statistics
+from dataclasses import dataclass
+from typing import Mapping
+
+import torch
+
+from wisent_guard.core.activations.core.atoms import LayerActivations, RawActivationMap
+
+from .base import DiagnosticsIssue, DiagnosticsReport, MetricReport
+
+__all__ = [
+    "ControlVectorDiagnosticsConfig",
+    "run_control_vector_diagnostics",
+    "run_control_steering_diagnostics",
+]
+
+
+@dataclass(slots=True)
+class ControlVectorDiagnosticsConfig:
+    """Thresholds and options for control vector diagnostics."""
+
+    min_norm: float = 1e-4
+    max_norm: float | None = None
+    zero_value_threshold: float = 1e-8
+    max_zero_fraction: float = 0.999
+    warn_on_missing: bool = True
+
+
+def _to_layer_activations(vectors: LayerActivations | RawActivationMap | Mapping[str, object] | None) -> LayerActivations:
+    if isinstance(vectors, LayerActivations):
+        return vectors
+    data: RawActivationMap = vectors or {}
+    return LayerActivations(data)
+
+
+def run_control_vector_diagnostics(
+    vectors: LayerActivations | RawActivationMap | Mapping[str, object] | None,
+    config: ControlVectorDiagnosticsConfig | None = None,
+) -> DiagnosticsReport:
+    """Evaluate steering/control vectors for basic health metrics."""
+
+    cfg = config or ControlVectorDiagnosticsConfig()
+    activations = _to_layer_activations(vectors)
+
+    issues: list[DiagnosticsIssue] = []
+    norms: list[float] = []
+    zero_fractions: list[float] = []
+    per_layer: dict[str, dict[str, float]] = {}
+
+    for layer, tensor in activations.to_dict().items():
+        if tensor is None:
+            if cfg.warn_on_missing:
+                issues.append(
+                    DiagnosticsIssue(
+                        metric="control_vectors",
+                        severity="warning",
+                        message=f"Layer {layer} has no control vector",
+                        details={"layer": layer},
+                    )
+                )
+            continue
+
+        detached = tensor.detach()
+        if detached.numel() == 0:
+            issues.append(
+                DiagnosticsIssue(
+                    metric="control_vectors",
+                    severity="critical",
+                    message=f"Layer {layer} control vector is empty",
+                    details={"layer": layer},
+                )
+            )
+            continue
+
+        flat = detached.to(dtype=torch.float32, device="cpu").reshape(-1)
+
+        if not torch.isfinite(flat).all():
+            non_finite = (~torch.isfinite(flat)).sum().item()
+            issues.append(
+                DiagnosticsIssue(
+                    metric="control_vectors",
+                    severity="critical",
+                    message=f"Layer {layer} contains non-finite values",
+                    details={"layer": layer, "non_finite_entries": int(non_finite)},
+                )
+            )
+            continue
+
+        norm_value = float(torch.linalg.vector_norm(flat).item())
+        norms.append(norm_value)
+
+        zero_fraction = float((flat.abs() <= cfg.zero_value_threshold).sum().item()) / float(flat.numel())
+        zero_fractions.append(zero_fraction)
+
+        per_layer[layer] = {
+            "norm": norm_value,
+            "zero_fraction": zero_fraction,
+        }
+
+        if norm_value < cfg.min_norm:
+            issues.append(
+                DiagnosticsIssue(
+                    metric="control_vectors",
+                    severity="critical",
+                    message=f"Layer {layer} control vector norm {norm_value:.3e} below minimum {cfg.min_norm}",
+                    details={"layer": layer, "norm": norm_value},
+                )
+            )
+
+        if cfg.max_norm is not None and norm_value > cfg.max_norm:
+            issues.append(
+                DiagnosticsIssue(
+                    metric="control_vectors",
+                    severity="warning",
+                    message=f"Layer {layer} control vector norm {norm_value:.3e} exceeds maximum {cfg.max_norm}",
+                    details={"layer": layer, "norm": norm_value},
+                )
+            )
+
+        if zero_fraction >= cfg.max_zero_fraction:
+            severity = "critical" if zero_fraction >= 1.0 - 1e-9 else "warning"
+            issues.append(
+                DiagnosticsIssue(
+                    metric="control_vectors",
+                    severity=severity,
+                    message=(
+                        f"Layer {layer} control vector is {zero_fraction:.3%} zero-valued, exceeding allowed {cfg.max_zero_fraction:.3%}"
+                    ),
+                    details={"layer": layer, "zero_fraction": zero_fraction},
+                )
+            )
+
+    summary: dict[str, object] = {
+        "evaluated_layers": len(norms),
+        "norm_min": min(norms) if norms else None,
+        "norm_max": max(norms) if norms else None,
+        "norm_mean": statistics.mean(norms) if norms else None,
+        "norm_median": statistics.median(norms) if norms else None,
+        "zero_fraction_max": max(zero_fractions) if zero_fractions else None,
+        "per_layer": per_layer,
+    }
+
+    if not norms and not issues:
+        issues.append(
+            DiagnosticsIssue(
+                metric="control_vectors",
+                severity="critical",
+                message="No control vectors were provided for diagnostics",
+                details={},
+            )
+        )
+
+    report = MetricReport(name="control_vectors", summary=summary, issues=issues)
+    return DiagnosticsReport.from_metrics([report])
+
+def run_control_steering_diagnostics(steering_vectors: list[RawActivationMap] | RawActivationMap | None) -> list[DiagnosticsReport]:
+    if steering_vectors is None:
+        return [DiagnosticsReport.from_metrics([])]
+
+    if not isinstance(steering_vectors, list):
+        steering_vectors = [steering_vectors]
+
+    # Run diagnostics for each steering vector
+    reports = [run_control_vector_diagnostics(vec) for vec in steering_vectors]
+    return reports
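Note: a minimal usage sketch for the new diagnostics entry point (not part of the diff; the import path follows the file list above, and passing a plain layer-to-tensor dict is an assumption based on the RawActivationMap handling in _to_layer_activations):

    import torch

    # Hypothetical import path, taken from the file list in this release.
    from wisent.core.contrastive_pairs.diagnostics.control_vectors import (
        ControlVectorDiagnosticsConfig,
        run_control_vector_diagnostics,
    )

    # Raw layer -> tensor map; the all-zero layer should trip both the
    # min_norm and max_zero_fraction checks.
    vectors = {
        "layers.10": torch.randn(4096),
        "layers.11": torch.zeros(4096),
    }

    report = run_control_vector_diagnostics(
        vectors,
        ControlVectorDiagnosticsConfig(min_norm=1e-4, max_norm=100.0),
    )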
wisent/core/contrastive_pairs/diagnostics/coverage.py
@@ -0,0 +1,79 @@
+"""Coverage and diversity diagnostics for contrastive pairs."""
+
+from __future__ import annotations
+
+from statistics import mean
+from typing import Iterable, List
+
+from .base import DiagnosticsConfig, DiagnosticsIssue, MetricReport
+
+
+def compute_coverage_metrics(pairs: Iterable, config: DiagnosticsConfig) -> MetricReport:
+    """Assess dataset coverage such as prompt diversity and response length."""
+
+    pairs_list = list(pairs)
+
+    if not pairs_list:
+        return MetricReport(name="coverage", summary={"total_pairs": 0}, issues=[])
+
+    unique_prompts = {getattr(pair, "prompt", "").strip().lower() for pair in pairs_list}
+    prompt_ratio = len(unique_prompts) / len(pairs_list)
+
+    pos_lengths: List[int] = []
+    neg_lengths: List[int] = []
+    labels = set()
+
+    for pair in pairs_list:
+        pos_text = getattr(pair.positive_response, "model_response", "")
+        neg_text = getattr(pair.negative_response, "model_response", "")
+        pos_lengths.append(len(pos_text))
+        neg_lengths.append(len(neg_text))
+
+        if pair.label:
+            labels.add(pair.label)
+
+    avg_positive_length = mean(pos_lengths) if pos_lengths else 0.0
+    avg_negative_length = mean(neg_lengths) if neg_lengths else 0.0
+
+    issues: List[DiagnosticsIssue] = []
+
+    if prompt_ratio < config.min_unique_prompt_ratio:
+        issues.append(
+            DiagnosticsIssue(
+                metric="coverage",
+                severity="warning",
+                message="Prompt diversity below configured ratio.",
+                pair_index=None,
+                details={
+                    "ratio": prompt_ratio,
+                    "threshold": config.min_unique_prompt_ratio,
+                    "unique_prompts": len(unique_prompts),
+                    "total_pairs": len(pairs_list),
+                },
+            )
+        )
+
+    if avg_positive_length < config.min_average_length or avg_negative_length < config.min_average_length:
+        issues.append(
+            DiagnosticsIssue(
+                metric="coverage",
+                severity="warning",
+                message="Average response length below minimum threshold.",
+                pair_index=None,
+                details={
+                    "avg_positive_length": avg_positive_length,
+                    "avg_negative_length": avg_negative_length,
+                    "threshold": config.min_average_length,
+                },
+            )
+        )
+
+    summary = {
+        "total_pairs": len(pairs_list),
+        "unique_prompt_ratio": prompt_ratio,
+        "avg_positive_length": avg_positive_length,
+        "avg_negative_length": avg_negative_length,
+        "label_coverage": len(labels),
+    }
+
+    return MetricReport(name="coverage", summary=summary, issues=issues)
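Note: a hypothetical smoke test for compute_coverage_metrics (SimpleNamespace stands in for the real contrastive-pair type and only mirrors the attributes the metric reads; DiagnosticsConfig from .base is assumed to provide defaults for its thresholds):

    from types import SimpleNamespace

    from wisent.core.contrastive_pairs.diagnostics.base import DiagnosticsConfig
    from wisent.core.contrastive_pairs.diagnostics.coverage import compute_coverage_metrics

    def make_pair(prompt, pos, neg, label=None):
        # Mirrors the attributes read above: prompt, *_response.model_response, label.
        return SimpleNamespace(
            prompt=prompt,
            positive_response=SimpleNamespace(model_response=pos),
            negative_response=SimpleNamespace(model_response=neg),
            label=label,
        )

    pairs = [
        make_pair("Is the sky blue?", "Yes, it is blue.", "No, it is green.", "color"),
        make_pair("Is the sky blue?", "Indeed it is.", "Not at all.", "color"),
    ]

    report = compute_coverage_metrics(pairs, DiagnosticsConfig())
    print(report.summary)  # unique_prompt_ratio == 0.5 for the repeated prompt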
wisent/core/contrastive_pairs/diagnostics/divergence.py
@@ -0,0 +1,98 @@
+"""Divergence diagnostics for contrastive pairs."""
+
+from __future__ import annotations
+
+from difflib import SequenceMatcher
+from statistics import mean
+from typing import Iterable, List
+
+from .base import DiagnosticsConfig, DiagnosticsIssue, MetricReport
+
+
+def _normalize_text(text: str) -> str:
+    return " ".join(text.strip().lower().split())
+
+
+def compute_divergence_metrics(pairs: Iterable, config: DiagnosticsConfig) -> MetricReport:
+    """Evaluate textual divergence between positive and negative responses."""
+
+    pairs_list = list(pairs)
+
+    divergences: List[float] = []
+    issues: List[DiagnosticsIssue] = []
+
+    if not pairs_list:
+        return MetricReport(
+            name="divergence",
+            summary={
+                "mean_divergence": 0.0,
+                "min_divergence": 0.0,
+                "max_divergence": 0.0,
+                "low_divergence_fraction": 0.0,
+            },
+            issues=[],
+        )
+
+    for idx, pair in enumerate(pairs_list):
+        positive = getattr(pair.positive_response, "model_response", "")
+        negative = getattr(pair.negative_response, "model_response", "")
+
+        norm_pos = _normalize_text(positive)
+        norm_neg = _normalize_text(negative)
+
+        if not norm_pos or not norm_neg:
+            issues.append(
+                DiagnosticsIssue(
+                    metric="divergence",
+                    severity="critical",
+                    message="Missing positive or negative response text.",
+                    pair_index=idx,
+                    details={"positive": bool(norm_pos), "negative": bool(norm_neg)},
+                )
+            )
+            divergences.append(0.0)
+            continue
+
+        similarity = SequenceMatcher(None, norm_pos, norm_neg).ratio()
+        divergence = 1.0 - similarity
+        divergences.append(divergence)
+
+        if divergence < config.min_divergence:
+            issues.append(
+                DiagnosticsIssue(
+                    metric="divergence",
+                    severity="warning",
+                    message="Positive and negative responses are highly similar.",
+                    pair_index=idx,
+                    details={"divergence": divergence, "similarity": similarity},
+                )
+            )
+
+    low_divergence_fraction = 0.0
+    low_divergence_count = sum(1 for value in divergences if value < config.min_divergence)
+    low_divergence_fraction = low_divergence_count / len(divergences)
+
+    if low_divergence_fraction > config.max_low_divergence_fraction:
+        issues.append(
+            DiagnosticsIssue(
+                metric="divergence",
+                severity="critical",
+                message="Too many pairs fall below divergence threshold.",
+                pair_index=None,
+                details={
+                    "fraction": low_divergence_fraction,
+                    "threshold": config.max_low_divergence_fraction,
+                    "count": low_divergence_count,
+                    "total": len(divergences),
+                },
+            )
+        )
+
+    summary = {
+        "mean_divergence": mean(divergences) if divergences else 0.0,
+        "min_divergence": min(divergences) if divergences else 0.0,
+        "max_divergence": max(divergences) if divergences else 0.0,
+        "low_divergence_fraction": low_divergence_fraction,
+    }
+
+    return MetricReport(name="divergence", summary=summary, issues=issues)
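Note: the divergence score above is one minus difflib's similarity ratio over whitespace- and case-normalized text. A standalone illustration using only the standard library:

    from difflib import SequenceMatcher

    def normalize(text: str) -> str:
        # Same normalization as _normalize_text above.
        return " ".join(text.strip().lower().split())

    pos = "The capital of France is Paris."
    neg = "The capital of France is Lyon."

    similarity = SequenceMatcher(None, normalize(pos), normalize(neg)).ratio()
    divergence = 1.0 - similarity
    print(f"{divergence:.3f}")  # small value: the pair is nearly identical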
wisent/core/contrastive_pairs/diagnostics/duplicates.py
@@ -0,0 +1,116 @@
+"""Duplicate detection diagnostics for contrastive pairs."""
+
+from __future__ import annotations
+
+from collections import Counter, defaultdict
+from difflib import SequenceMatcher
+from typing import Dict, Iterable, List
+
+from .base import DiagnosticsConfig, DiagnosticsIssue, MetricReport
+
+
+def _norm(text: str) -> str:
+    return " ".join(text.strip().lower().split())
+
+
+def compute_duplicate_metrics(pairs: Iterable, config: DiagnosticsConfig) -> MetricReport:
+    """Detect exact and near duplicates across prompts and responses."""
+
+    pairs_list = list(pairs)
+
+    prompt_counter: Counter[str] = Counter()
+    positive_counter: Counter[str] = Counter()
+    negative_counter: Counter[str] = Counter()
+    indexed_prompts: Dict[str, List[int]] = defaultdict(list)
+
+    for idx, pair in enumerate(pairs_list):
+        prompt = _norm(getattr(pair, "prompt", ""))
+        pos = _norm(getattr(pair.positive_response, "model_response", ""))
+        neg = _norm(getattr(pair.negative_response, "model_response", ""))
+
+        if prompt:
+            prompt_counter[prompt] += 1
+            indexed_prompts[prompt].append(idx)
+        if pos:
+            positive_counter[pos] += 1
+        if neg:
+            negative_counter[neg] += 1
+
+    total_pairs = len(pairs_list)
+    issues: List[DiagnosticsIssue] = []
+
+    if total_pairs == 0:
+        return MetricReport(name="duplicates", summary={"total_pairs": 0}, issues=[])
+
+    def _collect_exact(counter: Counter[str], label: str) -> List[DiagnosticsIssue]:
+        duplicates: List[DiagnosticsIssue] = []
+        for value, count in counter.items():
+            if count > 1:
+                duplicates.append(
+                    DiagnosticsIssue(
+                        metric="duplicates",
+                        severity="warning",
+                        message=f"Exact duplicate detected in {label}.",
+                        pair_index=None,
+                        details={"value": value, "count": count, "field": label},
+                    )
+                )
+        return duplicates
+
+    issues.extend(_collect_exact(prompt_counter, "prompt"))
+    issues.extend(_collect_exact(positive_counter, "positive_response"))
+    issues.extend(_collect_exact(negative_counter, "negative_response"))
+
+    exact_duplicate_fraction = sum(max(0, count - 1) for count in prompt_counter.values()) / total_pairs
+    if exact_duplicate_fraction > config.max_exact_duplicate_fraction:
+        issues.append(
+            DiagnosticsIssue(
+                metric="duplicates",
+                severity="critical",
+                message="Too many exact duplicate prompts detected.",
+                pair_index=None,
+                details={
+                    "fraction": exact_duplicate_fraction,
+                    "threshold": config.max_exact_duplicate_fraction,
+                    "duplicates": [
+                        {"prompt": prompt, "count": count}
+                        for prompt, count in prompt_counter.items()
+                        if count > 1
+                    ],
+                },
+            )
+        )
+
+    near_duplicate_pairs: List[tuple[int, int, float]] = []
+    prompt_items = list(prompt_counter.keys())
+    for i, prompt_a in enumerate(prompt_items):
+        for prompt_b in prompt_items[i + 1 :]:
+            similarity = SequenceMatcher(None, prompt_a, prompt_b).ratio()
+            if similarity >= config.near_duplicate_prompt_threshold:
+                indices_a = indexed_prompts[prompt_a]
+                indices_b = indexed_prompts[prompt_b]
+                near_duplicate_pairs.append((indices_a[0], indices_b[0], similarity))
+                issues.append(
+                    DiagnosticsIssue(
+                        metric="duplicates",
+                        severity="warning",
+                        message="Near-duplicate prompts detected.",
+                        pair_index=None,
+                        details={
+                            "prompt_a": prompt_a,
+                            "prompt_b": prompt_b,
+                            "similarity": similarity,
+                            "a_indices": indices_a,
+                            "b_indices": indices_b,
+                        },
+                    )
+                )
+
+    summary = {
+        "total_pairs": total_pairs,
+        "exact_duplicate_fraction": exact_duplicate_fraction,
+        "unique_prompts": len(prompt_counter),
+        "near_duplicate_count": len(near_duplicate_pairs),
+    }
+
+    return MetricReport(name="duplicates", summary=summary, issues=issues)
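Note: a standalone check of the exact_duplicate_fraction formula above, i.e. the sum of (count - 1) over duplicated prompts divided by total pairs:

    from collections import Counter

    prompts = ["p1", "p1", "p1", "p2", "p3"]
    counter = Counter(prompts)

    # Two redundant copies of "p1" among five pairs -> 2 / 5 = 0.4
    fraction = sum(max(0, count - 1) for count in counter.values()) / len(prompts)
    print(fraction)  # 0.4

Also worth noting: the near-duplicate scan compares every pair of distinct normalized prompts with SequenceMatcher, so its cost grows quadratically with the number of unique prompts.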