wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,300 @@
1
+ """Serialization helpers for contrastive pair sets with safe tensor/array storage."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import json
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from wisent_guard.core.contrastive_pairs.core.pair import ContrastivePair
13
+ from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
14
+
15
+ __all__ = [
16
+ "save_contrastive_pair_set",
17
+ "load_contrastive_pair_set",
18
+ ]
19
+
20
+
21
+ class VectorPayload(dict[str, bool | str | list[int]]):
22
+ """A dictionary with metadata and base64-encoded binary data for a tensor/array."""
23
+ __array__: bool
24
+ backend: str
25
+ dtype: str
26
+ shape: list[int]
27
+ data: str
28
+
29
def _encode_activations(x: torch.Tensor | np.ndarray | None) -> VectorPayload | None:
    """Turn a tensor/array into a JSON-serializable base64 payload.

    Arguments:
        x: a torch.Tensor, a np.ndarray, or None.

    Returns:
        A dict with the "__array__" marker, the source backend name, the
        dtype name, the shape and the base64-encoded raw bytes. None is
        returned when x is None or is neither a torch.Tensor nor a np.ndarray.
    """
    if isinstance(x, torch.Tensor):
        # Detach and move to CPU so the values can be viewed as a NumPy array.
        source, values = "torch", x.detach().cpu().contiguous().numpy()
    elif isinstance(x, np.ndarray):
        source, values = "numpy", np.ascontiguousarray(x)
    else:
        return None

    encoded = base64.b64encode(values.tobytes()).decode("utf-8")
    return {
        "__array__": True,
        "backend": source,
        "dtype": str(values.dtype),
        "shape": list(values.shape),
        "data": encoded,
    }
57
+
58
+
59
def _maybe_encode_response(response: dict[str, torch.Tensor | str | None]) -> dict[str, str | torch.Tensor | VectorPayload | None]:
    """Return a response dict whose 'activations' entry is JSON-safe.

    Arguments:
        response: A response dictionary; only its 'activations' key is inspected.

    Returns:
        The same dict object when 'activations' is absent or None. Otherwise
        a shallow copy whose 'activations' value is replaced by the result of
        _encode_activations, leaving the caller's dict untouched.

    For example:
        resp = {"text": "Hello", "activations": torch.randn(10), "label": "greeting"}
        encoded_resp = _maybe_encode_response(resp)
        # encoded_resp['activations'] is a base64 payload dict; resp is unchanged.
    """
    assert isinstance(response, dict)

    if response.get("activations") is None:
        # Nothing to encode: hand the input back as-is.
        return response

    encoded = dict(response)
    encoded["activations"] = _encode_activations(response["activations"])
    return encoded
78
+
79
+
80
def _decode_activations(obj: VectorPayload | None, return_backend: str = "torch") -> torch.Tensor | np.ndarray | list | None:
    """Decode a base64 activations payload into a tensor, array, or list.

    Arguments:
        obj: The payload dictionary to decode, or None. A payload must be a
            dict with a truthy "__array__" marker and the keys "dtype",
            "shape" and "data".
        return_backend: Desired return type: 'torch' (default), 'numpy', or 'list'.

    Returns:
        The decoded tensor/array/list, or None if obj was None.

    Raises:
        ValueError: If return_backend is not one of the supported names, if
            obj is not a marked payload dict, or if its contents cannot be
            decoded.
    """
    if obj is None:
        return None

    if return_backend not in ("torch", "numpy", "list"):
        raise ValueError("return_backend must be 'torch', 'numpy', or 'list'")
    # Only dicts flagged with the "__array__" marker are payloads.
    if not (isinstance(obj, dict) and obj.get("__array__", False)):
        raise ValueError("Object is not a valid encoded activations payload")

    try:
        dtype = np.dtype(obj["dtype"])
        shape = tuple(obj["shape"])
        raw = base64.b64decode(obj["data"])
        # np.frombuffer gives a read-only view over `raw`; copy so callers
        # receive a writable array (and a tensor backed by writable memory).
        arr = np.frombuffer(raw, dtype=dtype).reshape(shape).copy()
    except Exception as e:
        raise ValueError(f"Failed to decode activations payload: {e}") from e

    if return_backend == "list":
        return arr.tolist()
    if return_backend == "numpy":
        return arr
    return torch.from_numpy(arr)
114
+
115
+
116
def _maybe_decode_response(response: dict[str, str | torch.Tensor | VectorPayload | None], return_backend: str) -> dict[str, str | torch.Tensor | VectorPayload | None]:
    """Return a response dict whose 'activations' payload has been decoded.

    Arguments:
        response: A response dictionary; only its 'activations' key is inspected.
        return_backend: Forwarded to _decode_activations ('torch', 'numpy', or 'list').

    Returns:
        The same dict object when 'activations' is absent or None. Otherwise
        a shallow copy whose 'activations' value is the decoded result.

    For example:
        resp = {"text": "Hello", "activations": <encoded payload>, "label": "greeting"},
        where <encoded payload> is a dict:
            {
                "__array__": True,
                "backend": "torch",
                "dtype": "float32",
                "shape": [10],
                "data": "...base64..."
            }
        (as produced by _maybe_encode_response).

        decoded_resp = _maybe_decode_response(resp, return_backend='torch')
        # decoded_resp['activations'] is now a torch.Tensor.
    """
    assert isinstance(response, dict)

    if response.get("activations") is None:
        # Nothing to decode: hand the input back as-is.
        return response

    decoded = dict(response)
    decoded["activations"] = _decode_activations(response["activations"], return_backend)
    return decoded
147
+
148
+
149
def _validate_top_level(data: dict[str, str | list]) -> None:
    """Check the outermost shape of a loaded pair-set document.

    The document must have the keys 'name', 'task_type' and 'pairs', and
    'pairs' must be a list.

    Arguments:
        data: The loaded JSON data as a dictionary.

    Raises:
        ValueError: If a required key is missing or 'pairs' is not a list.
    """
    for required_key in ("name", "task_type", "pairs"):
        if required_key not in data:
            raise ValueError("Invalid JSON structure: missing one of ['name', 'task_type', 'pairs']")

    if isinstance(data["pairs"], list):
        return
    raise ValueError("'pairs' should be a list")
164
+
165
+
166
def _validate_pair_obj(pair: dict[str, str | dict[str, str | VectorPayload | None]]) -> None:
    """Check the shape of one serialized pair.

    Required keys are 'prompt', 'positive_response' and 'negative_response'.
    Both responses must be dictionaries holding 'model_response',
    'activations' and 'label' (values are not type-checked here).
    'label' and 'trait_description' are optional; when present they must be
    a string or None.

    Structure of 'pair object':
        {
            "prompt": "The input prompt",
            "positive_response": {
                "model_response": "The positive response",
                "activations": VectorPayload or None,
                "label": "positive"
            },
            "negative_response": {
                "model_response": "The negative response",
                "activations": VectorPayload or None,
                "label": "negative"
            },
            "label": "overall label",
            "trait_description": "description of the trait"
        }

    Arguments:
        pair: The pair object to validate.

    Raises:
        ValueError: If the structure is invalid.
    """
    for required_key in ("prompt", "positive_response", "negative_response"):
        if required_key not in pair:
            raise ValueError("Each pair must contain 'prompt', 'positive_response', and 'negative_response'")

    positive = pair["positive_response"]
    negative = pair["negative_response"]
    if not (isinstance(positive, dict) and isinstance(negative, dict)):
        raise ValueError("'positive_response' and 'negative_response' must be dictionaries")

    # For each required key, the positive side is checked before the negative one.
    sides = (("positive_response", positive), ("negative_response", negative))
    for resp_key in ("model_response", "activations", "label"):
        for side_name, side in sides:
            if resp_key not in side:
                raise ValueError(f"'{side_name}' must contain '{resp_key}'")

    for optional_key in ("label", "trait_description"):
        if optional_key not in pair:
            continue
        value = pair[optional_key]
        if value is None or isinstance(value, str):
            continue
        raise ValueError(f"'{optional_key}' must be a string or None")
209
+
210
def save_contrastive_pair_set(
    cps: ContrastivePairSet,
    filepath: str | Path,
) -> None:
    """Write a ContrastivePairSet to disk as a JSON document.

    Each pair is converted with its to_dict() method, and both response
    dicts are passed through _maybe_encode_response so tensor/ndarray
    activations become base64 payloads with dtype/shape metadata.

    Arguments:
        cps: The ContrastivePairSet to save.
        filepath: Path to the output JSON file.
    """
    serialized: list[dict[str, str | dict[str, str | VectorPayload | None]]] = []
    for item in cps.pairs:
        record = item.to_dict()
        for side in ("positive_response", "negative_response"):
            record[side] = _maybe_encode_response(record.get(side, {}))
        serialized.append(record)

    document = {
        "_version": 1,  # simple schema versioning
        "name": cps.name,
        "task_type": cps.task_type,
        "pairs": serialized,
    }

    with Path(filepath).open("w", encoding="utf-8") as handle:
        json.dump(document, handle, indent=2, ensure_ascii=False)
239
+
240
+
241
def load_contrastive_pair_set(
    filepath: str | Path,
    return_backend: str = "torch",
) -> ContrastivePairSet:
    """Read a JSON document and rebuild the ContrastivePairSet it describes.

    The document and every pair are structurally validated, both response
    dicts of each pair are decoded via _maybe_decode_response, the pairs are
    rebuilt with ContrastivePair.from_dict, and validate() is called on the
    resulting set before it is returned.

    Args:
        filepath: path to the JSON file.
        return_backend: 'torch' (default), 'numpy', or 'list'; forwarded to
            the activation decoder.

    Returns:
        ContrastivePairSet

    Format of loaded data:
        {
            "name": "name of the set",
            "task_type": "task type string",
            "pairs": [
                {
                    "prompt": "The input prompt",
                    "positive_response": {
                        "model_response": "The positive response",
                        "activations": VectorPayload or None,
                        "label": "positive"
                    },
                    "negative_response": {
                        "model_response": "The negative response",
                        "activations": VectorPayload or None,
                        "label": "negative"
                    },
                    "label": "overall label" or None,
                    "trait_description": "description of the trait" or None
                },
                ...
            ]
        }
    """
    with Path(filepath).open("r", encoding="utf-8") as handle:
        document = json.load(handle)

    _validate_top_level(document)

    # Phase 1: validate and decode every raw pair before building any object.
    records: list[dict] = []
    for raw_pair in document["pairs"]:
        _validate_pair_obj(raw_pair)
        record = dict(raw_pair)
        for side in ("positive_response", "negative_response"):
            record[side] = _maybe_decode_response(record.get(side, {}), return_backend)
        records.append(record)

    # Phase 2: rebuild the pair objects and wrap them in a set.
    pair_objects = [ContrastivePair.from_dict(record) for record in records]
    loaded = ContrastivePairSet(
        name=str(document["name"]),
        pairs=pair_objects,
        task_type=document.get("task_type"),
    )

    loaded.validate()

    return loaded
@@ -0,0 +1,133 @@
1
+ """Minimal container class for contrastive pairs with light orchestration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional
8
+ from typing import TYPE_CHECKING
9
+
10
+ from wisent_guard.core.contrastive_pairs.core.atoms import AtomContrastivePairSet
11
+ from wisent_guard.core.contrastive_pairs.diagnostics import DiagnosticsConfig, DiagnosticsReport, run_all_diagnostics
12
+
13
+ from wisent_guard.core.contrastive_pairs.core.pair import ContrastivePair
14
+
15
+ __all__ = [
16
+ "ContrastivePairSet",
17
+ ]
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
@dataclass
class ContrastivePairSet(AtomContrastivePairSet):
    """
    A named set of contrastive pairs, with optional task type.

    Attributes:
        name: The name of the contrastive pair set.
        pairs: The list of contrastive pairs in the set.
        task_type: The optional task type associated with the pair set.
    """
    name: str
    pairs: list[ContrastivePair] = field(default_factory=list)
    task_type: Optional[str] = None
    # Most recent report produced by validate(); not an __init__ argument.
    _last_diagnostics: DiagnosticsReport | None = field(init=False, default=None, repr=False)

    def __post_init__(self) -> None:
        """Run non-raising diagnostics when the set is created with pairs."""
        if self.pairs:
            self._last_diagnostics = self.validate(raise_on_critical=False)

    def add(self, pair: ContrastivePair) -> None:
        """Append a pair with an assert for correctness.

        Arguments:
            pair: The ContrastivePair to add.

        Raises:
            AssertionError: If the provided pair is not an instance of ContrastivePair.
        """
        assert isinstance(pair, ContrastivePair), "pair must be a ContrastivePair"
        self.pairs.append(pair)

    def extend(self, pairs: list[ContrastivePair]) -> None:
        """Extend with multiple pairs, type-checking each one via add().

        Arguments:
            pairs: A list of ContrastivePair instances to add.
        """
        for p in pairs:
            self.add(p)

    def __len__(self) -> int:
        """Return the number of pairs in the set."""
        return len(self.pairs)

    def __repr__(self) -> str:
        """Return a short representation showing the pair count, not the pairs."""
        return f"ContrastivePairSet(name={self.name!r}, pairs={len(self.pairs)}, task_type={self.task_type!r})"

    def statistics(self) -> dict[str, str | int | None]:
        """Return simple statistics about this set.

        A response counts as having activations when its
        ``layers_activations`` attribute exists and is not None.

        Returns:
            A dictionary with the set name, total pair count, counts of pairs
            with positive / negative / both activations, the task type, and
            the repr of the first pair (None when the set is empty).

        Raises:
            AssertionError: If the positive and negative counts differ.
        """
        # One (positive_present, negative_present) flag pair per contrastive pair.
        # Both sides read the same attribute so that "both" is counted correctly.
        presence = [
            (
                getattr(p.positive_response, "layers_activations", None) is not None,
                getattr(p.negative_response, "layers_activations", None) is not None,
            )
            for p in self.pairs
        ]
        pos = sum(1 for has_pos, _ in presence if has_pos)
        neg = sum(1 for _, has_neg in presence if has_neg)
        both = sum(1 for has_pos, has_neg in presence if has_pos and has_neg)

        assert pos == neg, "Number of positive and negative layers_activations should be equal."

        return {
            "name": self.name,
            "total_pairs": len(self.pairs),
            "pairs_with_positive_activations": pos,
            "pairs_with_negative_activations": neg,
            "pairs_with_both_activations": both,
            "task_type": self.task_type,
            "example_pair": repr(self.pairs[0]) if self.pairs else None,
        }

    def run_diagnostics(self, config: DiagnosticsConfig | None = None) -> DiagnosticsReport:
        """Execute registered diagnostics for this pair set.

        Args:
            config: Optional diagnostics configuration overrides.

        Returns:
            DiagnosticsReport capturing metric summaries and issues.
        """

        return run_all_diagnostics(self.pairs, config)

    def validate(
        self,
        config: DiagnosticsConfig | None = None,
        raise_on_critical: bool = True,
    ) -> DiagnosticsReport:
        """Run diagnostics, log every issue, and optionally raise on critical ones.

        Args:
            config: Optional diagnostics configuration overrides.
            raise_on_critical: When True, raise if the report has critical issues.

        Returns:
            The DiagnosticsReport, which is also stored on ``_last_diagnostics``.

        Raises:
            ValueError: If raise_on_critical is True and critical issues were
                found. In that case ``_last_diagnostics`` is not updated.
        """

        report = self.run_diagnostics(config)

        for issue in report.issues:
            # Critical issues are logged as errors, everything else as warnings.
            log_method = logger.error if issue.severity == "critical" else logger.warning
            log_method(
                "[%s diagnostics] %s (pair_index=%s, details=%s)",
                issue.metric,
                issue.message,
                issue.pair_index,
                issue.details,
            )

        if raise_on_critical and report.has_critical_issues:
            raise ValueError("Contrastive pair diagnostics found critical issues; see logs for specifics.")

        logger.info("Contrastive pair diagnostics summary for %s: %s", self.name, report.summary)

        self._last_diagnostics = report
        return report
@@ -0,0 +1,45 @@
1
+ """Aggregate interface for contrastive pair diagnostics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Iterable
6
+
7
+ from .base import DiagnosticsConfig, DiagnosticsReport
8
+ from .divergence import compute_divergence_metrics
9
+ from .duplicates import compute_duplicate_metrics
10
+ from .coverage import compute_coverage_metrics
11
+ from .activations import compute_activation_metrics
12
+ from .control_vectors import ControlVectorDiagnosticsConfig, run_control_vector_diagnostics, run_control_steering_diagnostics
13
+
14
+ __all__ = [
15
+ "DiagnosticsConfig",
16
+ "DiagnosticsReport",
17
+ "run_all_diagnostics",
18
+ "ControlVectorDiagnosticsConfig",
19
+ "run_control_vector_diagnostics",
20
+ "run_control_steering_diagnostics"
21
+ ]
22
+
23
+
24
def run_all_diagnostics(pairs: Iterable, config: DiagnosticsConfig | None = None) -> DiagnosticsReport:
    """Run all registered diagnostics for the provided contrastive pairs.

    Args:
        pairs: Iterable of contrastive pair objects implementing the required interface.
            It is materialized into a list once, so one-shot iterators are supported.
        config: Optional diagnostics configuration overrides.

    Returns:
        Aggregated diagnostics report capturing metric summaries and issues.
    """

    cfg = config or DiagnosticsConfig()

    # Every metric walks the pairs independently; a generator passed straight
    # through would be exhausted by the first metric and look empty to the rest.
    pairs_list = list(pairs)

    metric_reports = [
        compute_divergence_metrics(pairs_list, cfg),
        compute_duplicate_metrics(pairs_list, cfg),
        compute_coverage_metrics(pairs_list, cfg),
        compute_activation_metrics(pairs_list, cfg),
    ]

    return DiagnosticsReport.from_metrics(metric_reports)
@@ -0,0 +1,53 @@
1
+ """Activation completeness diagnostics for contrastive pairs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Iterable, List
6
+
7
+ from .base import DiagnosticsConfig, DiagnosticsIssue, MetricReport
8
+
9
+
10
def compute_activation_metrics(pairs: Iterable, config: DiagnosticsConfig) -> MetricReport:
    """Report how many pairs carry activations on each side.

    A response has activations when its ``layers_activations`` attribute
    exists and is not None. Pairs where only one side has activations are
    collected, and a single "warning" issue listing their indices is emitted
    when ``config.warn_on_missing_activations`` is true.

    Args:
        pairs: Iterable of pair objects exposing positive_response / negative_response.
        config: Diagnostics configuration.

    Returns:
        MetricReport named "activations" with the counts and any issue.
    """
    materialized = list(pairs)
    if len(materialized) == 0:
        return MetricReport(name="activations", summary={"total_pairs": 0}, issues=[])

    # (positive_present, negative_present) per pair, in input order.
    presence = [
        (
            getattr(item.positive_response, "layers_activations", None) is not None,
            getattr(item.negative_response, "layers_activations", None) is not None,
        )
        for item in materialized
    ]
    positive_count = sum(1 for pos_present, _ in presence if pos_present)
    negative_count = sum(1 for _, neg_present in presence if neg_present)
    mismatched: List[int] = [
        index for index, (pos_present, neg_present) in enumerate(presence) if pos_present != neg_present
    ]

    found: List[DiagnosticsIssue] = []
    if mismatched and config.warn_on_missing_activations:
        warning = DiagnosticsIssue(
            metric="activations",
            severity="warning",
            message="Positive/negative activation availability mismatch detected.",
            pair_index=None,
            details={"indices": mismatched},
        )
        found.append(warning)

    return MetricReport(
        name="activations",
        summary={
            "total_pairs": len(materialized),
            "pairs_with_positive_activations": positive_count,
            "pairs_with_negative_activations": negative_count,
            "mismatch_pairs": len(mismatched),
        },
        issues=found,
    )
@@ -0,0 +1,73 @@
1
+ """Shared dataclasses and helpers for contrastive pair diagnostics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Dict, Iterable, List
7
+
8
+
9
+ @dataclass(slots=True)
10
+ class DiagnosticsConfig:
11
+ """Threshold configuration for diagnostics.
12
+
13
+ Attributes:
14
+ min_divergence: Minimum acceptable divergence between positive and negative responses.
15
+ max_low_divergence_fraction: Maximum allowed share of pairs falling below the divergence threshold.
16
+ near_duplicate_prompt_threshold: Similarity threshold (0-1) at which prompts are treated as near duplicates.
17
+ max_exact_duplicate_fraction: Maximum allowed share of exact duplicate prompts or responses.
18
+ min_unique_prompt_ratio: Minimum ratio of unique prompts to total pairs for coverage diagnostics.
19
+ min_average_length: Minimum average response length (characters) indicating sufficient content.
20
+ warn_on_missing_activations: Whether missing activations should be reported as issues.
21
+ """
22
+
23
+ min_divergence: float = 0.3
24
+ max_low_divergence_fraction: float = 0.1
25
+ near_duplicate_prompt_threshold: float = 0.9
26
+ max_exact_duplicate_fraction: float = 0.05
27
+ min_unique_prompt_ratio: float = 0.75
28
+ min_average_length: int = 15
29
+ warn_on_missing_activations: bool = True
30
+
31
+
32
+ @dataclass(slots=True)
33
+ class DiagnosticsIssue:
34
+ """Represents a single diagnostics issue detected in a pair set."""
35
+
36
+ metric: str
37
+ severity: str
38
+ message: str
39
+ pair_index: int | None = None
40
+ details: Dict[str, Any] = field(default_factory=dict)
41
+
42
+
43
+ @dataclass(slots=True)
44
+ class MetricReport:
45
+ """Stores summary statistics for a single diagnostics metric."""
46
+
47
+ name: str
48
+ summary: Dict[str, Any]
49
+ issues: List[DiagnosticsIssue] = field(default_factory=list)
50
+
51
+
52
+ @dataclass(slots=True)
53
+ class DiagnosticsReport:
54
+ """Aggregated diagnostics results across metrics."""
55
+
56
+ metrics: Dict[str, MetricReport]
57
+ issues: List[DiagnosticsIssue]
58
+ summary: Dict[str, Any]
59
+ has_critical_issues: bool
60
+
61
+ @classmethod
62
+ def from_metrics(cls, reports: Iterable[MetricReport]) -> "DiagnosticsReport":
63
+ metrics_map: Dict[str, MetricReport] = {}
64
+ all_issues: List[DiagnosticsIssue] = []
65
+
66
+ for report in reports:
67
+ metrics_map[report.name] = report
68
+ all_issues.extend(report.issues)
69
+
70
+ summary = {name: report.summary for name, report in metrics_map.items()}
71
+ has_critical = any(issue.severity == "critical" for issue in all_issues)
72
+
73
+ return cls(metrics=metrics_map, issues=all_issues, summary=summary, has_critical_issues=has_critical)