wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,169 @@
1
"""Diagnostics for steering/control vectors."""

from __future__ import annotations

import statistics
from dataclasses import dataclass
from typing import Mapping

import torch

# This wheel ships the activations module under the ``wisent`` package
# (wisent/core/activations/core/atoms.py); the former ``wisent_guard`` path
# is not part of this distribution and would fail to import.
from wisent.core.activations.core.atoms import LayerActivations, RawActivationMap

from .base import DiagnosticsIssue, DiagnosticsReport, MetricReport

__all__ = [
    "ControlVectorDiagnosticsConfig",
    "run_control_vector_diagnostics",
    "run_control_steering_diagnostics",
]
20
+
21
+
22
+ @dataclass(slots=True)
23
+ class ControlVectorDiagnosticsConfig:
24
+ """Thresholds and options for control vector diagnostics."""
25
+
26
+ min_norm: float = 1e-4
27
+ max_norm: float | None = None
28
+ zero_value_threshold: float = 1e-8
29
+ max_zero_fraction: float = 0.999
30
+ warn_on_missing: bool = True
31
+
32
+
33
def _to_layer_activations(vectors: LayerActivations | RawActivationMap | Mapping[str, object] | None) -> LayerActivations:
    """Coerce *vectors* into a ``LayerActivations`` instance.

    An existing ``LayerActivations`` is returned as-is. Any other value is
    wrapped in a new ``LayerActivations``; falsy inputs (``None`` or an empty
    mapping) are first replaced by an empty dict.
    """
    if not isinstance(vectors, LayerActivations):
        vectors = LayerActivations(vectors if vectors else {})
    return vectors
38
+
39
+
40
def _control_vector_issue(severity: str, message: str, details: dict[str, object]) -> DiagnosticsIssue:
    """Build a ``DiagnosticsIssue`` filed under the ``control_vectors`` metric."""
    return DiagnosticsIssue(
        metric="control_vectors",
        severity=severity,
        message=message,
        details=details,
    )


def run_control_vector_diagnostics(
    vectors: LayerActivations | RawActivationMap | Mapping[str, object] | None,
    config: ControlVectorDiagnosticsConfig | None = None,
) -> DiagnosticsReport:
    """Evaluate steering/control vectors for basic health metrics.

    Each layer of *vectors* is checked in turn:

    * ``None`` entries produce a warning when ``config.warn_on_missing`` is set.
    * Empty tensors and tensors containing NaN/inf produce a critical issue and
      are excluded from the statistics.
    * Otherwise, the flattened vector's L2 norm and zero-valued fraction are
      recorded and compared against ``min_norm`` (critical), ``max_norm``
      (warning) and ``max_zero_fraction`` (warning, or critical when every
      entry is zero).

    If no layer could be evaluated and no empty/non-finite layer already
    explains why, a critical "No control vectors were provided" issue is added.

    Args:
        vectors: Per-layer control vectors, or anything ``_to_layer_activations``
            accepts (``None`` is treated as an empty mapping).
        config: Thresholds. Defaults to ``ControlVectorDiagnosticsConfig()``.

    Returns:
        A ``DiagnosticsReport`` built from a single ``control_vectors`` metric.
        Its summary holds ``evaluated_layers``, the norm min/max/mean/median and
        ``zero_fraction_max`` (all ``None`` when nothing was evaluated), plus a
        ``per_layer`` mapping of norm and zero fraction.
    """
    cfg = config or ControlVectorDiagnosticsConfig()
    activations = _to_layer_activations(vectors)

    issues: list[DiagnosticsIssue] = []
    norms: list[float] = []
    zero_fractions: list[float] = []
    per_layer: dict[str, dict[str, float]] = {}
    # Set when a present layer is rejected (empty / non-finite) with a critical
    # issue. Missing-layer *warnings* must not suppress the final
    # "nothing provided" critical, otherwise warn_on_missing=True would
    # downgrade an all-missing input to warnings only.
    has_rejected_layer = False

    for layer, tensor in activations.to_dict().items():
        if tensor is None:
            if cfg.warn_on_missing:
                issues.append(
                    _control_vector_issue("warning", f"Layer {layer} has no control vector", {"layer": layer})
                )
            continue

        detached = tensor.detach()
        if detached.numel() == 0:
            issues.append(
                _control_vector_issue("critical", f"Layer {layer} control vector is empty", {"layer": layer})
            )
            has_rejected_layer = True
            continue

        # Use float64: down-casting a float64 vector to float32 can turn large
        # finite entries into inf (misreported as non-finite). The wider type
        # also keeps the norm's sum of squares accurate.
        flat = detached.to(dtype=torch.float64, device="cpu").reshape(-1)

        finite_mask = torch.isfinite(flat)  # computed once; reused for the count
        if not bool(finite_mask.all()):
            non_finite = int((~finite_mask).sum().item())
            issues.append(
                _control_vector_issue(
                    "critical",
                    f"Layer {layer} contains non-finite values",
                    {"layer": layer, "non_finite_entries": non_finite},
                )
            )
            has_rejected_layer = True
            continue

        norm_value = float(torch.linalg.vector_norm(flat).item())
        norms.append(norm_value)

        zero_count = float((flat.abs() <= cfg.zero_value_threshold).sum().item())
        zero_fraction = zero_count / float(flat.numel())
        zero_fractions.append(zero_fraction)

        per_layer[layer] = {
            "norm": norm_value,
            "zero_fraction": zero_fraction,
        }

        if norm_value < cfg.min_norm:
            issues.append(
                _control_vector_issue(
                    "critical",
                    f"Layer {layer} control vector norm {norm_value:.3e} below minimum {cfg.min_norm}",
                    {"layer": layer, "norm": norm_value},
                )
            )

        if cfg.max_norm is not None and norm_value > cfg.max_norm:
            issues.append(
                _control_vector_issue(
                    "warning",
                    f"Layer {layer} control vector norm {norm_value:.3e} exceeds maximum {cfg.max_norm}",
                    {"layer": layer, "norm": norm_value},
                )
            )

        if zero_fraction >= cfg.max_zero_fraction:
            # An (effectively) all-zero vector is unusable, hence critical.
            severity = "critical" if zero_fraction >= 1.0 - 1e-9 else "warning"
            issues.append(
                _control_vector_issue(
                    severity,
                    (
                        f"Layer {layer} control vector is {zero_fraction:.3%} zero-valued, exceeding allowed {cfg.max_zero_fraction:.3%}"
                    ),
                    {"layer": layer, "zero_fraction": zero_fraction},
                )
            )

    summary: dict[str, object] = {
        "evaluated_layers": len(norms),
        "norm_min": min(norms) if norms else None,
        "norm_max": max(norms) if norms else None,
        "norm_mean": statistics.mean(norms) if norms else None,
        "norm_median": statistics.median(norms) if norms else None,
        "zero_fraction_max": max(zero_fractions) if zero_fractions else None,
        "per_layer": per_layer,
    }

    if not norms and not has_rejected_layer:
        issues.append(
            _control_vector_issue("critical", "No control vectors were provided for diagnostics", {})
        )

    report = MetricReport(name="control_vectors", summary=summary, issues=issues)
    return DiagnosticsReport.from_metrics([report])
159
+
160
def run_control_steering_diagnostics(
    steering_vectors: list[RawActivationMap] | tuple[RawActivationMap, ...] | RawActivationMap | None,
) -> list[DiagnosticsReport]:
    """Run ``run_control_vector_diagnostics`` over one or more steering vector sets.

    Args:
        steering_vectors: A single per-layer activation map, a list or tuple
            of such maps, or ``None``.

    Returns:
        One ``DiagnosticsReport`` per input map, in input order. ``None``
        yields a single report built from no metrics; an empty list or tuple
        yields an empty list.
    """
    if steering_vectors is None:
        return [DiagnosticsReport.from_metrics([])]

    # A tuple is a sequence of maps, just like a list. Only a non-sequence
    # value is treated as one map and wrapped.
    if isinstance(steering_vectors, (list, tuple)):
        batch = list(steering_vectors)
    else:
        batch = [steering_vectors]

    return [run_control_vector_diagnostics(vec) for vec in batch]
@@ -0,0 +1,79 @@
1
+ """Coverage and diversity diagnostics for contrastive pairs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from statistics import mean
6
+ from typing import Iterable, List
7
+
8
+ from .base import DiagnosticsConfig, DiagnosticsIssue, MetricReport
9
+
10
+
11
def compute_coverage_metrics(pairs: Iterable, config: DiagnosticsConfig) -> MetricReport:
    """Assess dataset coverage such as prompt diversity and response length.

    Args:
        pairs: Contrastive pairs. Each pair is read for ``prompt`` and
            ``label``, and for ``positive_response.model_response`` /
            ``negative_response.model_response``. A text attribute that is
            missing or ``None`` is treated as empty.
        config: Provides ``min_unique_prompt_ratio`` and ``min_average_length``.

    Returns:
        A ``coverage`` MetricReport. For empty input, the summary only holds
        ``total_pairs=0``. Otherwise it holds the unique-prompt ratio, the
        average positive/negative response lengths (in characters) and the
        number of distinct truthy labels. A warning is raised when prompt
        diversity, or either average length, falls below its threshold.
    """

    def _text(obj: object, attr: str) -> str:
        # An attribute that is missing *or* explicitly None counts as empty text.
        return getattr(obj, attr, "") or ""

    pairs_list = list(pairs)

    if not pairs_list:
        return MetricReport(name="coverage", summary={"total_pairs": 0}, issues=[])

    # Normalise like the duplicate/divergence diagnostics: lower-case and
    # collapse whitespace runs, so "A  b" and "a b" count as the same prompt.
    unique_prompts = {" ".join(_text(pair, "prompt").lower().split()) for pair in pairs_list}
    prompt_ratio = len(unique_prompts) / len(pairs_list)

    pos_lengths: List[int] = [len(_text(pair.positive_response, "model_response")) for pair in pairs_list]
    neg_lengths: List[int] = [len(_text(pair.negative_response, "model_response")) for pair in pairs_list]

    labels = set()
    for pair in pairs_list:
        label = getattr(pair, "label", None)
        if label:
            labels.add(label)

    # pairs_list is non-empty here, so both length lists are non-empty.
    avg_positive_length = mean(pos_lengths)
    avg_negative_length = mean(neg_lengths)

    issues: List[DiagnosticsIssue] = []

    if prompt_ratio < config.min_unique_prompt_ratio:
        issues.append(
            DiagnosticsIssue(
                metric="coverage",
                severity="warning",
                message="Prompt diversity below configured ratio.",
                pair_index=None,
                details={
                    "ratio": prompt_ratio,
                    "threshold": config.min_unique_prompt_ratio,
                    "unique_prompts": len(unique_prompts),
                    "total_pairs": len(pairs_list),
                },
            )
        )

    if avg_positive_length < config.min_average_length or avg_negative_length < config.min_average_length:
        issues.append(
            DiagnosticsIssue(
                metric="coverage",
                severity="warning",
                message="Average response length below minimum threshold.",
                pair_index=None,
                details={
                    "avg_positive_length": avg_positive_length,
                    "avg_negative_length": avg_negative_length,
                    "threshold": config.min_average_length,
                },
            )
        )

    summary = {
        "total_pairs": len(pairs_list),
        "unique_prompt_ratio": prompt_ratio,
        "avg_positive_length": avg_positive_length,
        "avg_negative_length": avg_negative_length,
        "label_coverage": len(labels),
    }

    return MetricReport(name="coverage", summary=summary, issues=issues)
@@ -0,0 +1,98 @@
1
+ """Divergence diagnostics for contrastive pairs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from difflib import SequenceMatcher
6
+ from statistics import mean
7
+ from typing import Iterable, List
8
+
9
+ from .base import DiagnosticsConfig, DiagnosticsIssue, MetricReport
10
+
11
+
12
+ def _normalize_text(text: str) -> str:
13
+ return " ".join(text.strip().lower().split())
14
+
15
+
16
def compute_divergence_metrics(pairs: Iterable, config: DiagnosticsConfig) -> MetricReport:
    """Evaluate textual divergence between positive and negative responses.

    For each pair, divergence is ``1 - SequenceMatcher(None, pos, neg).ratio()``
    on the normalised (lower-cased, whitespace-collapsed) response texts.

    Args:
        pairs: Contrastive pairs exposing ``positive_response.model_response``
            and ``negative_response.model_response``. A missing or ``None``
            response is treated as empty text.
        config: Provides ``min_divergence`` and ``max_low_divergence_fraction``.

    Returns:
        A ``divergence`` MetricReport with the mean/min/max divergence and the
        fraction of pairs below ``min_divergence``. The report includes:

        * a critical issue for each pair that is missing a response;
        * a warning for each overly similar pair;
        * a critical issue when the low-divergence fraction exceeds
          ``max_low_divergence_fraction``.

        Empty input yields an all-zero summary and no issues.
    """
    pairs_list = list(pairs)

    if not pairs_list:
        return MetricReport(
            name="divergence",
            summary={
                "mean_divergence": 0.0,
                "min_divergence": 0.0,
                "max_divergence": 0.0,
                "low_divergence_fraction": 0.0,
            },
            issues=[],
        )

    divergences: List[float] = []
    issues: List[DiagnosticsIssue] = []

    for idx, pair in enumerate(pairs_list):
        # ``or ""`` maps an explicit None to empty text, so the pair is reported
        # as missing instead of crashing inside _normalize_text.
        norm_pos = _normalize_text(getattr(pair.positive_response, "model_response", "") or "")
        norm_neg = _normalize_text(getattr(pair.negative_response, "model_response", "") or "")

        if not norm_pos or not norm_neg:
            issues.append(
                DiagnosticsIssue(
                    metric="divergence",
                    severity="critical",
                    message="Missing positive or negative response text.",
                    pair_index=idx,
                    details={"positive": bool(norm_pos), "negative": bool(norm_neg)},
                )
            )
            # Recorded as 0.0, so a missing text also counts toward the
            # low-divergence fraction whenever min_divergence > 0.
            divergences.append(0.0)
            continue

        similarity = SequenceMatcher(None, norm_pos, norm_neg).ratio()
        divergence = 1.0 - similarity
        divergences.append(divergence)

        if divergence < config.min_divergence:
            issues.append(
                DiagnosticsIssue(
                    metric="divergence",
                    severity="warning",
                    message="Positive and negative responses are highly similar.",
                    pair_index=idx,
                    details={"divergence": divergence, "similarity": similarity},
                )
            )

    # One entry per pair was appended above, so divergences is non-empty.
    low_divergence_count = sum(1 for value in divergences if value < config.min_divergence)
    low_divergence_fraction = low_divergence_count / len(divergences)

    if low_divergence_fraction > config.max_low_divergence_fraction:
        issues.append(
            DiagnosticsIssue(
                metric="divergence",
                severity="critical",
                message="Too many pairs fall below divergence threshold.",
                pair_index=None,
                details={
                    "fraction": low_divergence_fraction,
                    "threshold": config.max_low_divergence_fraction,
                    "count": low_divergence_count,
                    "total": len(divergences),
                },
            )
        )

    summary = {
        "mean_divergence": mean(divergences),
        "min_divergence": min(divergences),
        "max_divergence": max(divergences),
        "low_divergence_fraction": low_divergence_fraction,
    }

    return MetricReport(name="divergence", summary=summary, issues=issues)
@@ -0,0 +1,116 @@
1
+ """Duplicate detection diagnostics for contrastive pairs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import Counter, defaultdict
6
+ from difflib import SequenceMatcher
7
+ from typing import Dict, Iterable, List
8
+
9
+ from .base import DiagnosticsConfig, DiagnosticsIssue, MetricReport
10
+
11
+
12
+ def _norm(text: str) -> str:
13
+ return " ".join(text.strip().lower().split())
14
+
15
+
16
def compute_duplicate_metrics(pairs: Iterable, config: DiagnosticsConfig) -> MetricReport:
    """Detect exact and near duplicates across prompts and responses.

    Prompts and positive/negative responses are normalised with ``_norm``.
    Empty texts, including missing or ``None`` attributes, are ignored. The
    following issues are reported:

    * a warning for every exactly duplicated prompt/response value;
    * a critical issue when the surplus-duplicate prompt fraction exceeds
      ``config.max_exact_duplicate_fraction``;
    * a warning for every pair of distinct prompts whose
      ``SequenceMatcher.ratio()`` reaches
      ``config.near_duplicate_prompt_threshold``.

    Args:
        pairs: Contrastive pairs exposing ``prompt`` and
            ``positive_response`` / ``negative_response`` with ``model_response``.
        config: Supplies the duplicate thresholds named above.

    Returns:
        A ``duplicates`` MetricReport. For empty input, the summary only holds
        ``total_pairs=0``.
    """
    pairs_list = list(pairs)
    total_pairs = len(pairs_list)

    if total_pairs == 0:
        return MetricReport(name="duplicates", summary={"total_pairs": 0}, issues=[])

    prompt_counter: Counter[str] = Counter()
    positive_counter: Counter[str] = Counter()
    negative_counter: Counter[str] = Counter()
    indexed_prompts: Dict[str, List[int]] = defaultdict(list)

    for idx, pair in enumerate(pairs_list):
        # ``or ""`` keeps an explicit None away from _norm, which calls str methods.
        prompt = _norm(getattr(pair, "prompt", "") or "")
        pos = _norm(getattr(pair.positive_response, "model_response", "") or "")
        neg = _norm(getattr(pair.negative_response, "model_response", "") or "")

        if prompt:
            prompt_counter[prompt] += 1
            indexed_prompts[prompt].append(idx)
        if pos:
            positive_counter[pos] += 1
        if neg:
            negative_counter[neg] += 1

    issues: List[DiagnosticsIssue] = []

    def _collect_exact(counter: Counter[str], label: str) -> List[DiagnosticsIssue]:
        # One warning per normalised value that occurs more than once.
        return [
            DiagnosticsIssue(
                metric="duplicates",
                severity="warning",
                message=f"Exact duplicate detected in {label}.",
                pair_index=None,
                details={"value": value, "count": count, "field": label},
            )
            for value, count in counter.items()
            if count > 1
        ]

    issues.extend(_collect_exact(prompt_counter, "prompt"))
    issues.extend(_collect_exact(positive_counter, "positive_response"))
    issues.extend(_collect_exact(negative_counter, "negative_response"))

    # Count only the surplus copies: a prompt seen 3 times contributes 2.
    exact_duplicate_fraction = sum(max(0, count - 1) for count in prompt_counter.values()) / total_pairs
    if exact_duplicate_fraction > config.max_exact_duplicate_fraction:
        issues.append(
            DiagnosticsIssue(
                metric="duplicates",
                severity="critical",
                message="Too many exact duplicate prompts detected.",
                pair_index=None,
                details={
                    "fraction": exact_duplicate_fraction,
                    "threshold": config.max_exact_duplicate_fraction,
                    "duplicates": [
                        {"prompt": prompt, "count": count}
                        for prompt, count in prompt_counter.items()
                        if count > 1
                    ],
                },
            )
        )

    threshold = config.near_duplicate_prompt_threshold
    near_duplicate_pairs: List[tuple[int, int, float]] = []
    prompt_items = list(prompt_counter)
    for i, prompt_a in enumerate(prompt_items):
        for prompt_b in prompt_items[i + 1 :]:
            matcher = SequenceMatcher(None, prompt_a, prompt_b)
            # real_quick_ratio() and quick_ratio() are documented upper bounds on
            # ratio(). Pairs that cannot reach the threshold skip the expensive
            # matching step, and the flagged set is unchanged.
            if matcher.real_quick_ratio() < threshold or matcher.quick_ratio() < threshold:
                continue
            similarity = matcher.ratio()
            if similarity >= threshold:
                indices_a = indexed_prompts[prompt_a]
                indices_b = indexed_prompts[prompt_b]
                near_duplicate_pairs.append((indices_a[0], indices_b[0], similarity))
                issues.append(
                    DiagnosticsIssue(
                        metric="duplicates",
                        severity="warning",
                        message="Near-duplicate prompts detected.",
                        pair_index=None,
                        details={
                            "prompt_a": prompt_a,
                            "prompt_b": prompt_b,
                            "similarity": similarity,
                            "a_indices": indices_a,
                            "b_indices": indices_b,
                        },
                    )
                )

    summary = {
        "total_pairs": total_pairs,
        "exact_duplicate_fraction": exact_duplicate_fraction,
        "unique_prompts": len(prompt_counter),
        "near_duplicate_count": len(near_duplicate_pairs),
    }

    return MetricReport(name="duplicates", summary=summary, issues=issues)