wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent might be problematic; see the registry page for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/core/contrastive_pairs/core/serialization.py
@@ -0,0 +1,300 @@
+"""Serialization helpers for contrastive pair sets with safe tensor/array storage."""
+
+from __future__ import annotations
+
+import base64
+import json
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
+
+__all__ = [
+    "save_contrastive_pair_set",
+    "load_contrastive_pair_set",
+]
+
+
+class VectorPayload(dict[str, bool | str | list[int]]):
+    """A dictionary with metadata and base64-encoded binary data for a tensor/array."""
+    __array__: bool
+    backend: str
+    dtype: str
+    shape: list[int]
+    data: str
+
+def _encode_activations(x: torch.Tensor | np.ndarray | None) -> VectorPayload | None:
+    """Return a JSON-serializable object.
+    If x is a torch.Tensor or np.ndarray, encode as base64 payload with metadata.
+
+    Arguments:
+        x: tensor or array to encode, or None.
+
+    Returns:
+        A dictionary with encoding metadata and base64 data, or None if input is None.
+    """
+
+    if isinstance(x, torch.Tensor):
+        arr = x.detach().cpu().contiguous().numpy()
+        backend = "torch"
+    elif isinstance(x, np.ndarray):
+        arr = np.ascontiguousarray(x)
+        backend = "numpy"
+    else:
+        return None
+
+    payload = {
+        "__array__": True,
+        "backend": backend,
+        "dtype": str(arr.dtype),
+        "shape": list(arr.shape),
+        "data": base64.b64encode(arr.tobytes()).decode("utf-8"),
+    }
+    return payload
+
+
+def _maybe_encode_response(response: dict[str, torch.Tensor | str | None]) -> dict[str, str | torch.Tensor | VectorPayload | None]:
+    """If response['activations'] is a tensor/array, encode it safely for JSON storage.
+
+    Arguments:
+        response: A dictionary with keys 'text', 'activations', and optionally 'label'.
+    Returns:
+        A dictionary with the same keys, but with 'activations' encoded if needed.
+
+    For example:
+        resp = {"text": "Hello", "activations": torch.randn(10), "label": "greeting"}
+        encoded_resp = _maybe_encode_response(resp)
+        # encoded_resp['activations'] is now a base64 payload dictionary which is JSON-serializable.
+    """
+    assert isinstance(response, dict)
+
+    if "activations" in response and response["activations"] is not None:
+        response = dict(response)  # shallow copy
+        response["activations"] = _encode_activations(response["activations"])
+    return response
+
+
+def _decode_activations(obj: VectorPayload | None, return_backend: str = "torch") -> torch.Tensor | np.ndarray | list | None:
+    """Decode from our base64 payload into torch tensor (default) or numpy array.
+    return_backend: 'torch' | 'numpy' | 'list'
+    map_device: 'cpu' (default) or 'original' (best-effort) for torch tensors.
+
+    Arguments:
+        obj: The payload dictionary to decode, or None.
+        return_backend: Desired return type: 'torch' (default), 'numpy', or 'list'.
+
+    Returns:
+        The decoded tensor/array/list, or None if input was None.
+    """
+
+    if obj is None:
+        return None
+
+    assert return_backend in ("torch", "numpy", "list"), "return_backend must be 'torch', 'numpy', or 'list'"
+    assert isinstance(obj, dict) and obj.get("__array__", False), "Object is not a valid encoded activations payload"
+
+    try:
+        dtype = np.dtype(obj["dtype"])
+        shape = tuple(obj["shape"])
+        raw = base64.b64decode(obj["data"])
+        arr = np.frombuffer(raw, dtype=dtype).reshape(shape)
+    except Exception as e:
+        raise ValueError(f"Failed to decode activations payload: {e}") from e
+
+    if return_backend == "list":
+        return arr.tolist()
+    if return_backend == "numpy":
+        return arr
+    if return_backend == "torch":
+        return torch.from_numpy(arr)
+    raise ValueError(f"Unknown return_backend: {return_backend}")
+
+
+def _maybe_decode_response(response: dict[str, str | torch.Tensor | VectorPayload | None], return_backend: str) -> dict[str, str | torch.Tensor | VectorPayload | None]:
+    """If response['activations'] is an encoded payload, decode it to tensor/array.
+
+    Arguments:
+        response: A dictionary with keys 'text', 'activations', and optionally 'label'.
+        return_backend: 'torch' (default), 'numpy', or 'list'.
+
+    Returns:
+        A dictionary with the same keys, but with 'activations' decoded if needed.
+
+    For example:
+        resp = {"text": "Hello", "activations": <encoded payload>, "label": "greeting"},
+        where <encoded payload> is a dict:
+        {
+            "__array__": True,
+            "backend": "torch",
+            "dtype": "float32",
+            "shape": [10],
+            "data": "...base64..."
+        }
+        (as produced by _maybe_encode_response).
+
+        decoded_resp = _maybe_decode_response(resp, return_backend='torch')
+        # decoded_resp['activations'] is now a torch.Tensor.
+    """
+    assert isinstance(response, dict)
+
+    if "activations" in response and response["activations"] is not None:
+        response = dict(response)
+        response["activations"] = _decode_activations(response["activations"], return_backend)
+    return response
+
+
+def _validate_top_level(data: dict[str, str | list]) -> None:
+    """Validate the top-level structure of the loaded JSON data.
+
+    Top structure must contain 'name', 'task_type', and 'pairs' keys.
+
+    Arguments:
+        data: The loaded JSON data as a dictionary.
+
+    Raises:
+        ValueError: If the structure is invalid.
+    """
+    if not all(k in data for k in ("name", "task_type", "pairs")):
+        raise ValueError("Invalid JSON structure: missing one of ['name', 'task_type', 'pairs']")
+    if not isinstance(data["pairs"], list):
+        raise ValueError("'pairs' should be a list")
+
+
+def _validate_pair_obj(pair: dict[str, str | dict[str, str | VectorPayload | None]]) -> None:
+    """Validate the structure of a single pair object.
+
+    Each pair must contain 'prompt', 'positive_response', 'negative_response', 'label' (can be None) and 'trait_description' (can be None).
+    'positive_response' and 'negative_response' must be dictionaries containing 'model_response', 'activations' (can be None), and 'label' (can be None).
+
+    Structure of 'pair object':
+    {
+        "prompt": "The input prompt",
+        "positive_response": {
+            "model_response": "The positive response",
+            "activations": VectorPayload or None,
+            "label": "positive"
+        },
+        "negative_response": {
+            "model_response": "The negative response",
+            "activations": VectorPayload or None,
+            "label": "negative"
+        },
+        "label": "overall label",
+        "trait_description": "description of the trait"
+    }
+
+    Arguments:
+        pair: The pair object to validate.
+
+    Raises:
+        ValueError: If the structure is invalid.
+    """
+    need = ("prompt", "positive_response", "negative_response")
+    if not all(k in pair for k in need):
+        raise ValueError("Each pair must contain 'prompt', 'positive_response', and 'negative_response'")
+    if not isinstance(pair["positive_response"], dict) or not isinstance(pair["negative_response"], dict):
+        raise ValueError("'positive_response' and 'negative_response' must be dictionaries")
+    for resp_key in ("model_response", "activations", "label"):
+        if resp_key not in pair["positive_response"]:
+            raise ValueError(f"'positive_response' must contain '{resp_key}'")
+        if resp_key not in pair["negative_response"]:
+            raise ValueError(f"'negative_response' must contain '{resp_key}'")
+    if "label" in pair and pair["label"] is not None and not isinstance(pair["label"], str):
+        raise ValueError("'label' must be a string or None")
+    if "trait_description" in pair and pair["trait_description"] is not None and not isinstance(pair["trait_description"], str):
+        raise ValueError("'trait_description' must be a string or None")
+
+def save_contrastive_pair_set(
+    cps: ContrastivePairSet,
+    filepath: str | Path,
+) -> None:
+    """Save a ContrastivePairSet to a JSON file.
+    Tensors/ndarrays in response['activations'] are encoded with base64 + dtype/shape metadata.
+
+    Arguments:
+        cps: The ContrastivePairSet to save.
+        filepath: Path to the output JSON file.
+    """
+
+    pairs: list[dict[str, str | dict[str, str | VectorPayload | None]]] = []
+    for pair in cps.pairs:
+        p = pair.to_dict()
+        p["positive_response"] = _maybe_encode_response(p.get("positive_response", {}))
+        p["negative_response"] = _maybe_encode_response(p.get("negative_response", {}))
+        pairs.append(p)
+
+    data = {
+        "_version": 1,  # simple schema versioning
+        "name": cps.name,
+        "task_type": cps.task_type,
+        "pairs": pairs,
+    }
+
+    filepath = Path(filepath)
+    with filepath.open("w", encoding="utf-8") as f:
+        json.dump(data, f, indent=2, ensure_ascii=False)
+
+
+def load_contrastive_pair_set(
+    filepath: str | Path,
+    return_backend: str = "torch",
+) -> ContrastivePairSet:
+    """Load a ContrastivePairSet from a JSON file and decode activations.
+
+    Args:
+        filepath: path to the JSON file.
+        return_backend: 'torch' (default), 'numpy', or 'list'. If torch is not
+            installed, will automatically fall back to 'numpy'.
+
+    Returns:
+        ContrastivePairSet
+
+    Format of loaded data:
+    {
+        "name": "name of the set",
+        "task_type": "task type string",
+        "pairs": [
+            {
+                "prompt": "The input prompt",
+                "positive_response": {
+                    "model_response": "The positive response",
+                    "activations": VectorPayload or None,
+                    "label": "positive"
+                },
+                "negative_response": {
+                    "model_response": "The negative response",
+                    "activations": VectorPayload or None,
+                    "label": "negative"
+                },
+                "label": "overall label" or None,
+                "trait_description": "description of the trait" or None
+            },
+            ...
+        ]
+    }
+
+    """
+    filepath = Path(filepath)
+    with filepath.open("r", encoding="utf-8") as f:
+        data = json.load(f)
+
+    _validate_top_level(data)
+
+    decoded_pairs: list[dict] = []
+    for pair in data["pairs"]:
+        _validate_pair_obj(pair)
+        p = dict(pair)
+        p["positive_response"] = _maybe_decode_response(p.get("positive_response", {}), return_backend)
+        p["negative_response"] = _maybe_decode_response(p.get("negative_response", {}), return_backend)
+        decoded_pairs.append(p)
+
+    list_of_pairs = [ContrastivePair.from_dict(p) for p in decoded_pairs]
+
+    cps = ContrastivePairSet(name=str(data["name"]), pairs=list_of_pairs, task_type=data.get("task_type"))
+
+    cps.validate()
+
+    return cps
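The payload produced by _encode_activations is self-describing (backend, dtype, shape, base64-encoded bytes), and _decode_activations reverses it. A minimal round-trip sketch follows; it calls the module-private helpers purely for illustration, and the tensor values are invented.

    import torch

    from wisent.core.contrastive_pairs.core.serialization import (
        _decode_activations,
        _encode_activations,
    )

    # Encode a small float32 tensor into the JSON-friendly payload dict.
    x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    payload = _encode_activations(x)
    print(payload["backend"], payload["dtype"], payload["shape"])  # torch float32 [2, 3]

    # Decode back; the default backend returns a torch.Tensor.
    restored = _decode_activations(payload, return_backend="torch")
    assert torch.equal(restored, x)

    # The same payload can also be decoded to numpy or plain Python lists.
    print(_decode_activations(payload, return_backend="list"))

Because the raw bytes are stored base64-encoded alongside dtype and shape, the JSON file stays portable across machines at the cost of roughly a 4/3 size overhead on the activation data.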
wisent/core/contrastive_pairs/core/set.py
@@ -0,0 +1,133 @@
+"""Minimal container class for contrastive pairs with light orchestration."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, field
+from typing import Optional
+from typing import TYPE_CHECKING
+
+from wisent.core.contrastive_pairs.core.atoms import AtomContrastivePairSet
+from wisent.core.contrastive_pairs.diagnostics import DiagnosticsConfig, DiagnosticsReport, run_all_diagnostics
+
+from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+
+__all__ = [
+    "ContrastivePairSet",
+]
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ContrastivePairSet(AtomContrastivePairSet):
+    """
+    A named set of contrastive pairs, with optional task type.
+
+    Attributes:
+        name: The name of the contrastive pair set.
+        pairs: The list of contrastive pairs in the set.
+        task_type: The optional task type associated with the pair set.
+    """
+    name: str
+    pairs: list[ContrastivePair] = field(default_factory=list)
+    task_type: Optional[str] = None
+    _last_diagnostics: DiagnosticsReport | None = field(init=False, default=None, repr=False)
+
+    def __post_init__(self) -> None:
+        if self.pairs:
+            self._last_diagnostics = self.validate(raise_on_critical=False)
+
+    def add(self, pair: ContrastivePair) -> None:
+        """Append a pair with an assert for correctness.
+
+        Arguments:
+            pair: The ContrastivePair to add.
+
+        Raises:
+            AssertionError: If the provided pair is not an instance of ContrastivePair.
+        """
+        assert isinstance(pair, ContrastivePair), "pair must be a ContrastivePair"
+        self.pairs.append(pair)
+
+    def extend(self, pairs: list[ContrastivePair]) -> None:
+        """Extend with multiple pairs.
+
+        Arguments:
+            pairs: A list of ContrastivePair instances to add.
+        """
+        for p in pairs:
+            self.add(p)
+
+    def __len__(self) -> int:
+        return len(self.pairs)
+
+    def __repr__(self) -> str:
+        return f"ContrastivePairSet(name={self.name!r}, pairs={len(self.pairs)}, task_type={self.task_type!r})"
+
+    def statistics(self) -> dict[str, str | int | None]:
+        """Return simple statistics about this set.
+
+        Returns:
+            A dictionary with statistics about the pair set.
+        """
+        pos = sum(1 for p in self.pairs if getattr(p.positive_response, "layers_activations", None) is not None)
+        neg = sum(1 for p in self.pairs if getattr(p.negative_response, "layers_activations", None) is not None)
+        both = sum(
+            1
+            for p in self.pairs
+            if getattr(p.positive_response, "layers_activations", None) is not None
+            and getattr(p.negative_response, "layers_activations", None) is not None
+        )
+
+        assert pos == neg, "Number of positive and negative layers_activations should be equal."
+
+        return {
+            "name": self.name,
+            "total_pairs": len(self.pairs),
+            "pairs_with_positive_activations": pos,
+            "pairs_with_negative_activations": neg,
+            "pairs_with_both_activations": both,
+            "task_type": self.task_type,
+            "example_pair": repr(self.pairs[0]) if self.pairs else None,
+        }
+
+    def run_diagnostics(self, config: DiagnosticsConfig | None = None) -> DiagnosticsReport:
+        """Execute registered diagnostics for this pair set.
+
+        Args:
+            config: Optional diagnostics configuration overrides.
+
+        Returns:
+            DiagnosticsReport capturing metric summaries and issues.
+        """
+
+        return run_all_diagnostics(self.pairs, config)
+
+    def validate(
+        self,
+        config: DiagnosticsConfig | None = None,
+        raise_on_critical: bool = True,
+    ) -> DiagnosticsReport:
+        """Run diagnostics and optionally raise when critical issues are detected."""
+
+        report = self.run_diagnostics(config)
+
+        for issue in report.issues:
+            log_method = logger.error if issue.severity == "critical" else logger.warning
+            log_method(
+                "[%s diagnostics] %s (pair_index=%s, details=%s)",
+                issue.metric,
+                issue.message,
+                issue.pair_index,
+                issue.details,
+            )
+
+        if raise_on_critical and report.has_critical_issues:
+            raise ValueError("Contrastive pair diagnostics found critical issues; see logs for specifics.")
+
+        logger.info("Contrastive pair diagnostics summary for %s: %s", self.name, report.summary)
+
+        self._last_diagnostics = report
+        return report
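For context, here is a sketch of how this container and the serialization helpers shown above are meant to fit together. It assumes ContrastivePair.from_dict accepts the pair schema documented in serialization.py (the actual pair builders live in core/buliders.py, which this diff lists but does not reproduce), and the prompt/response strings are invented. Note that load_contrastive_pair_set calls validate(), so the sibling diagnostics modules must accept the data without critical findings.

    from wisent.core.contrastive_pairs.core.pair import ContrastivePair
    from wisent.core.contrastive_pairs.core.serialization import (
        load_contrastive_pair_set,
        save_contrastive_pair_set,
    )
    from wisent.core.contrastive_pairs.core.set import ContrastivePairSet

    # Assumed-valid input for ContrastivePair.from_dict (schema as documented above).
    pair = ContrastivePair.from_dict({
        "prompt": "Is the sky green?",
        "positive_response": {"model_response": "No, it is blue.", "activations": None, "label": "positive"},
        "negative_response": {"model_response": "Yes, bright green.", "activations": None, "label": "negative"},
        "label": None,
        "trait_description": "truthfulness",
    })

    cps = ContrastivePairSet(name="demo", task_type="qa")  # empty at construction, so __post_init__ skips diagnostics
    cps.add(pair)
    print(len(cps), cps.statistics())

    save_contrastive_pair_set(cps, "demo_pairs.json")
    restored = load_contrastive_pair_set("demo_pairs.json", return_backend="torch")
    print(repr(restored))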
wisent/core/contrastive_pairs/diagnostics/__init__.py
@@ -0,0 +1,45 @@
+"""Aggregate interface for contrastive pair diagnostics."""
+
+from __future__ import annotations
+
+from typing import Iterable
+
+from .base import DiagnosticsConfig, DiagnosticsReport
+from .divergence import compute_divergence_metrics
+from .duplicates import compute_duplicate_metrics
+from .coverage import compute_coverage_metrics
+from .activations import compute_activation_metrics
+from .control_vectors import ControlVectorDiagnosticsConfig, run_control_vector_diagnostics, run_control_steering_diagnostics
+
+__all__ = [
+    "DiagnosticsConfig",
+    "DiagnosticsReport",
+    "run_all_diagnostics",
+    "ControlVectorDiagnosticsConfig",
+    "run_control_vector_diagnostics",
+    "run_control_steering_diagnostics"
+]
+
+
+def run_all_diagnostics(pairs: Iterable, config: DiagnosticsConfig | None = None) -> DiagnosticsReport:
+    """Run all registered diagnostics for the provided contrastive pairs.
+
+    Args:
+        pairs: Iterable of contrastive pair objects implementing the required interface.
+        config: Optional diagnostics configuration overrides.
+
+    Returns:
+        Aggregated diagnostics report capturing metric summaries and issues.
+    """
+
+    cfg = config or DiagnosticsConfig()
+
+    metric_reports = [
+        compute_divergence_metrics(pairs, cfg),
+        compute_duplicate_metrics(pairs, cfg),
+        compute_coverage_metrics(pairs, cfg),
+        compute_activation_metrics(pairs, cfg),
+    ]
+
+    combined = DiagnosticsReport.from_metrics(metric_reports)
+    return combined
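A short sketch of calling the aggregate entry point directly. The divergence, duplicate, and coverage metrics live in sibling modules that this diff lists but does not reproduce, so only the overall shape of the report is assumed here; an empty pair list keeps the example self-contained.

    from wisent.core.contrastive_pairs.diagnostics import DiagnosticsConfig, run_all_diagnostics

    # Tighten the duplicate budget and silence the activation-mismatch warning.
    cfg = DiagnosticsConfig(max_exact_duplicate_fraction=0.0, warn_on_missing_activations=False)

    # In real use this would be ContrastivePairSet.pairs; a plain list (rather than a
    # one-shot generator) matters because each compute_* helper iterates it independently.
    report = run_all_diagnostics([], cfg)

    print(sorted(report.summary))      # one summary entry per metric name
    print(report.has_critical_issues)  # True only if some issue has severity 'critical'
    for issue in report.issues:
        print(issue.severity, issue.metric, issue.message)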
wisent/core/contrastive_pairs/diagnostics/activations.py
@@ -0,0 +1,53 @@
+"""Activation completeness diagnostics for contrastive pairs."""
+
+from __future__ import annotations
+
+from typing import Iterable, List
+
+from .base import DiagnosticsConfig, DiagnosticsIssue, MetricReport
+
+
+def compute_activation_metrics(pairs: Iterable, config: DiagnosticsConfig) -> MetricReport:
+    """Check for presence of activations across the contrastive pair set."""
+
+    pairs_list = list(pairs)
+
+    if not pairs_list:
+        return MetricReport(name="activations", summary={"total_pairs": 0}, issues=[])
+
+    has_positive = 0
+    has_negative = 0
+    mismatch_indices: List[int] = []
+
+    for idx, pair in enumerate(pairs_list):
+        pos_has = getattr(pair.positive_response, "layers_activations", None) is not None
+        neg_has = getattr(pair.negative_response, "layers_activations", None) is not None
+
+        has_positive += int(pos_has)
+        has_negative += int(neg_has)
+
+        if pos_has != neg_has:
+            mismatch_indices.append(idx)
+
+    total_pairs = len(pairs_list)
+    issues: List[DiagnosticsIssue] = []
+
+    if mismatch_indices and config.warn_on_missing_activations:
+        issues.append(
+            DiagnosticsIssue(
+                metric="activations",
+                severity="warning",
+                message="Positive/negative activation availability mismatch detected.",
+                pair_index=None,
+                details={"indices": mismatch_indices},
+            )
+        )
+
+    summary = {
+        "total_pairs": total_pairs,
+        "pairs_with_positive_activations": has_positive,
+        "pairs_with_negative_activations": has_negative,
+        "mismatch_pairs": len(mismatch_indices),
+    }
+
+    return MetricReport(name="activations", summary=summary, issues=issues)
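To make the mismatch bookkeeping concrete, here is a toy invocation that uses SimpleNamespace stand-ins exposing only the layers_activations attribute the metric reads; real pairs are ContrastivePair objects, so the stand-ins are purely illustrative.

    from types import SimpleNamespace

    from wisent.core.contrastive_pairs.diagnostics.activations import compute_activation_metrics
    from wisent.core.contrastive_pairs.diagnostics.base import DiagnosticsConfig

    def fake_pair(pos, neg):
        # Stand-in carrying just the attributes compute_activation_metrics inspects.
        return SimpleNamespace(
            positive_response=SimpleNamespace(layers_activations=pos),
            negative_response=SimpleNamespace(layers_activations=neg),
        )

    pairs = [fake_pair([0.1], [0.2]), fake_pair([0.3], None)]  # second pair is mismatched
    report = compute_activation_metrics(pairs, DiagnosticsConfig())

    print(report.summary)                      # total_pairs=2, positive=2, negative=1, mismatch_pairs=1
    print([i.message for i in report.issues])  # one warning about the availability mismatch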
wisent/core/contrastive_pairs/diagnostics/base.py
@@ -0,0 +1,73 @@
+"""Shared dataclasses and helpers for contrastive pair diagnostics."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterable, List
+
+
+@dataclass(slots=True)
+class DiagnosticsConfig:
+    """Threshold configuration for diagnostics.
+
+    Attributes:
+        min_divergence: Minimum acceptable divergence between positive and negative responses.
+        max_low_divergence_fraction: Maximum allowed share of pairs falling below the divergence threshold.
+        near_duplicate_prompt_threshold: Similarity threshold (0-1) at which prompts are treated as near duplicates.
+        max_exact_duplicate_fraction: Maximum allowed share of exact duplicate prompts or responses.
+        min_unique_prompt_ratio: Minimum ratio of unique prompts to total pairs for coverage diagnostics.
+        min_average_length: Minimum average response length (characters) indicating sufficient content.
+        warn_on_missing_activations: Whether missing activations should be reported as issues.
+    """
+
+    min_divergence: float = 0.3
+    max_low_divergence_fraction: float = 0.1
+    near_duplicate_prompt_threshold: float = 0.9
+    max_exact_duplicate_fraction: float = 0.05
+    min_unique_prompt_ratio: float = 0.75
+    min_average_length: int = 15
+    warn_on_missing_activations: bool = True
+
+
+@dataclass(slots=True)
+class DiagnosticsIssue:
+    """Represents a single diagnostics issue detected in a pair set."""
+
+    metric: str
+    severity: str
+    message: str
+    pair_index: int | None = None
+    details: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(slots=True)
+class MetricReport:
+    """Stores summary statistics for a single diagnostics metric."""
+
+    name: str
+    summary: Dict[str, Any]
+    issues: List[DiagnosticsIssue] = field(default_factory=list)
+
+
+@dataclass(slots=True)
+class DiagnosticsReport:
+    """Aggregated diagnostics results across metrics."""
+
+    metrics: Dict[str, MetricReport]
+    issues: List[DiagnosticsIssue]
+    summary: Dict[str, Any]
+    has_critical_issues: bool
+
+    @classmethod
+    def from_metrics(cls, reports: Iterable[MetricReport]) -> "DiagnosticsReport":
+        metrics_map: Dict[str, MetricReport] = {}
+        all_issues: List[DiagnosticsIssue] = []
+
+        for report in reports:
+            metrics_map[report.name] = report
+            all_issues.extend(report.issues)
+
+        summary = {name: report.summary for name, report in metrics_map.items()}
+        has_critical = any(issue.severity == "critical" for issue in all_issues)
+
+        return cls(metrics=metrics_map, issues=all_issues, summary=summary, has_critical_issues=has_critical)
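Finally, a small sketch of how from_metrics aggregates per-metric results. The metric names echo the modules above, but the summary keys and the issue text are invented for illustration.

    from wisent.core.contrastive_pairs.diagnostics.base import (
        DiagnosticsIssue,
        DiagnosticsReport,
        MetricReport,
    )

    reports = [
        MetricReport(name="duplicates", summary={"exact_duplicates": 0}),
        MetricReport(
            name="divergence",
            summary={"low_divergence_pairs": 3},
            issues=[DiagnosticsIssue(metric="divergence", severity="critical",
                                     message="Too many near-identical pairs.")],
        ),
    ]

    combined = DiagnosticsReport.from_metrics(reports)
    print(combined.summary)              # {'duplicates': {...}, 'divergence': {...}}
    print(combined.has_critical_issues)  # True, because one issue is marked 'critical'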