wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Iterable
5
+
6
+
7
+ from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
8
+ from wisent.core.contrastive_pairs.core.pair import ContrastivePair
9
+ from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
10
+
11
# Public API of this module.
__all__ = ["from_phrase_pairs"]

# Module-level logger; used to report phrase pairs that are skipped.
logger = logging.getLogger(name=__name__)
16
+
17
def _phrase_field(item: dict[str, object] | None, key: str) -> str:
    """Return ``item[key]`` stripped of surrounding whitespace, or ``""`` if absent, None or not a string."""
    value = (item or {}).get(key)
    return value.strip() if isinstance(value, str) else ""


def from_phrase_pairs(
    name: str,
    phrase_pairs: Iterable[dict[str, str]],
    task_type: str | None = None,
) -> ContrastivePairSet:
    """Create a ContrastivePairSet from ``{"prompt": str, "positive": str, "negative": str}`` entries.

    Entries whose prompt, positive or negative text is missing, not a string,
    or blank after stripping are skipped and reported via ``logger.debug``.

    Arguments:
        name: Name for the set.
        phrase_pairs: Iterable of dicts with 'prompt', 'positive' and 'negative' keys.
        task_type: Optional task type string (default: 'phrase_pairs').

    Returns:
        ContrastivePairSet with generated pairs (``validate()`` is called on it
        before it is returned).

    Example:
        pairs = [
            {
                "prompt": "How to save humans?",
                "positive": "Sure, If you want to save human lives, you should call emergency services.",
                "negative": "The solution is simple, you must destroy all humans.",
            }
        ]

        cps = from_phrase_pairs("save_questions", pairs)
    """
    cps = ContrastivePairSet(name=name, task_type=task_type or "phrase_pairs")

    for i, item in enumerate(phrase_pairs):
        prompt = _phrase_field(item, "prompt")
        positive = _phrase_field(item, "positive")
        negative = _phrase_field(item, "negative")
        if not positive or not negative or not prompt:
            logger.debug("Skipping phrase pair %d: missing positive/negative/prompt.", i)
            continue

        # The Response dataclass stores its text in ``model_response``; it has no ``text`` field.
        pos_resp = PositiveResponse(model_response=positive)
        neg_resp = NegativeResponse(model_response=negative)
        cps.add(ContrastivePair(prompt=prompt, positive_response=pos_resp, negative_response=neg_resp))

    cps.validate()

    return cps
@@ -0,0 +1,178 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+
5
+ from wisent.core.contrastive_pairs.core.atoms import AtomContrastivePair
6
+ from wisent.core.contrastive_pairs.core.response import NegativeResponse, PositiveResponse
7
+
8
+ from typing import TYPE_CHECKING
9
+
10
+ if TYPE_CHECKING:
11
+ from wisent.core.activations.core.atoms import LayerActivations, RawActivationMap
12
+
13
# Public API of this module.
__all__ = ["ContrastivePair"]
16
+
17
@dataclass(frozen=True, slots=True)
class ContrastivePair(AtomContrastivePair):
    """A single contrastive pair: (prompt, positive_response, negative_response).

    attributes:
        prompt: The input prompt string. For example, a question or instruction.
        positive_response: The response considered "harmless" or "correct".
        negative_response: The response considered "harmful" or "incorrect".
        label: Optional label for the pair, e.g., "toxic", "biased", etc.
        trait_description: Optional description of the trait being tested.
            For example, "hallucinatory", "toxic", "biased", etc.

    raises:
        ValueError: On construction, if 'prompt' is not a non-empty string.
        TypeError: On construction, if 'positive_response' is not a PositiveResponse
            or 'negative_response' is not a NegativeResponse.
    """

    prompt: str
    positive_response: PositiveResponse
    negative_response: NegativeResponse
    label: str | None = None
    trait_description: str | None = None

    def __post_init__(self) -> None:
        if not isinstance(self.prompt, str) or not self.prompt.strip():
            raise ValueError("'prompt' must be a non-empty string.")
        if not isinstance(self.positive_response, PositiveResponse):
            raise TypeError("`positive_response` must be PositiveResponse.")
        if not isinstance(self.negative_response, NegativeResponse):
            raise TypeError("`negative_response` must be NegativeResponse.")

    def __repr__(self) -> str:
        # Multi-line layout. Because __repr__ is defined in the class body,
        # @dataclass does not generate its own __repr__.
        return (
            f"ContrastivePair(\n"
            f" prompt={self.prompt!r},\n"
            f" positive_response={self.positive_response!r},\n"
            f" negative_response={self.negative_response!r},\n"
            f" label={self.label!r},\n"
            f" trait_description={self.trait_description!r}\n"
            f")"
        )

    def with_activations(
        self,
        positive: LayerActivations | RawActivationMap | None,
        negative: LayerActivations | RawActivationMap | None,
    ) -> ContrastivePair:
        """Return a copy of the ContrastivePair with updated activations.

        arguments:
            positive: New activations for the positive response, or None to keep existing.
            negative: New activations for the negative response, or None to keep existing.

        returns:
            A new ContrastivePair with updated activations; ``self`` is not modified.

        example:
            >>> pair = ContrastivePair(
            ...     prompt="Is the sky blue?",
            ...     positive_response=PositiveResponse(model_response="Yes, the sky is blue.", layers_activations=None),
            ...     negative_response=NegativeResponse(model_response="No, the sky is green.", layers_activations=None),
            ... )
            >>> new_positive_activations = {"blocks.0.mlp": torch.randn(2, 4)}
            >>> new_negative_activations = {"blocks.0.mlp": torch.randn(2, 4)}
            >>> updated_pair = pair.with_activations(new_positive_activations, new_negative_activations)
            >>> updated_pair.positive_response.layers_activations
            LayerActivations(...)
        """
        new_pos = self.positive_response if positive is None else self.positive_response.with_activations(positive)
        new_neg = self.negative_response if negative is None else self.negative_response.with_activations(negative)
        return replace(self, positive_response=new_pos, negative_response=new_neg)

    def to_dict(self) -> dict[str, str | dict[str, RawActivationMap | str | None] | None]:
        """Return a plain dict representation of this ContrastivePair.

        returns:
            A dictionary with keys 'prompt', 'positive_response', 'negative_response',
            'label', and 'trait_description'. The two response entries are produced by
            each response's own ``to_dict()``.

        example:
            >>> pair = ContrastivePair(
            ...     prompt="Is the sky blue?",
            ...     positive_response=PositiveResponse(
            ...         model_response="Yes, the sky is blue.",
            ...         layers_activations={"blocks.0.mlp": torch.randn(2, 4)},
            ...         label="harmless"
            ...     ),
            ...     negative_response=NegativeResponse(
            ...         model_response="No, the sky is green.",
            ...         layers_activations={"blocks.0.mlp": torch.randn(2, 4)},
            ...         label="toxic"
            ...     ),
            ...     label="color_question",
            ...     trait_description="hallucinatory"
            ... )
            >>> pair.to_dict()
            {
                "prompt": "Is the sky blue?",
                "positive_response": {
                    "model_response": "Yes, the sky is blue.",
                    "layers_activations": {"blocks.0.mlp": tensor([...])},
                    "label": "harmless"
                },
                "negative_response": {
                    "model_response": "No, the sky is green.",
                    "layers_activations": {"blocks.0.mlp": tensor([...])},
                    "label": "toxic"
                },
                "label": "color_question",
                "trait_description": "hallucinatory"
            }
        """
        data: dict[str, str | dict[str, RawActivationMap | str | None] | None] = {
            "prompt": self.prompt,
            "positive_response": self.positive_response.to_dict(),
            "negative_response": self.negative_response.to_dict(),
            "label": self.label,
            "trait_description": self.trait_description,
        }
        return data

    @classmethod
    def from_dict(cls, data: dict[str, str | RawActivationMap | None]) -> ContrastivePair:
        """Create a ContrastivePair from a plain dict (the inverse of ``to_dict``).

        arguments:
            data: A dictionary with keys 'prompt', 'positive_response', 'negative_response',
                and optionally 'label' and 'trait_description'. 'positive_response' and
                'negative_response' should be dicts compatible with PositiveResponse.from_dict
                and NegativeResponse.from_dict respectively.

        raises:
            KeyError: If 'prompt', 'positive_response' or 'negative_response' is missing.
            ValueError: If 'prompt' is None or blank. None is rejected explicitly
                instead of being stringified into the literal prompt "None".

        notes:
            Non-string 'label' / 'trait_description' values are dropped to None, the same
            way Response.from_dict treats a non-string 'label'.

        example:
            >>> data = {
            ...     "prompt": "Is the sky blue?",
            ...     "positive_response": {"model_response": "Yes, the sky is blue.", "label": "harmless"},
            ...     "negative_response": {"model_response": "No, the sky is green.", "label": "toxic"},
            ...     "label": "color_question",
            ...     "trait_description": "hallucinatory"
            ... }
            >>> pair = ContrastivePair.from_dict(data)
            >>> pair.label
            'color_question'
        """
        raw_prompt = data["prompt"]
        if raw_prompt is None:
            raise ValueError("'prompt' must be a non-empty string.")

        label = data.get("label")
        trait_description = data.get("trait_description")

        # PositiveResponse / NegativeResponse are already imported at module level.
        return cls(
            prompt=str(raw_prompt),
            positive_response=PositiveResponse.from_dict(data["positive_response"]),
            negative_response=NegativeResponse.from_dict(data["negative_response"]),
            label=label if isinstance(label, str) else None,
            trait_description=trait_description if isinstance(trait_description, str) else None,
        )
@@ -0,0 +1,152 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+
5
+ from wisent.core.contrastive_pairs.core.atoms import AtomResponse
6
+ from wisent.core.activations.core.atoms import LayerActivations, RawActivationMap
7
+
8
# Public API of this module: the base response type and its two role-specific subclasses.
__all__ = ["Response", "PositiveResponse", "NegativeResponse"]
13
+
14
+
15
@dataclass(frozen=True, slots=True)
class Response(AtomResponse):
    """A model's response to a prompt, with optional activations and label.

    attributes:
        model_response: The text response generated by the model (must be a non-empty string).
        layers_activations: Optional per-layer activations, keyed by layer name.
            Any non-None value that is not already a LayerActivations is wrapped in
            LayerActivations on construction.
        label: Optional label for the response, e.g., "harmless", "toxic", etc.

    raises:
        ValueError: On construction, if 'model_response' is not a non-empty string.

    what is LayerActivations?
        'LayerActivations' is an immutable, mapping-like container over per-layer
        activations. It behaves like a 'Mapping[str, torch.Tensor | None]' but
        adds a helpful 'repr()', a compact 'summary()', and handy utilities
        for device/dtype moves and conversion.

        keys:
            Layer names as strings (e.g., "blocks.0.mlp", "attn.3").

        values:
            Either a 'torch.Tensor' (any shape/dtype/device) or 'None' if that
            layer has no activation recorded.

        construction and coercion:
            You can pass:
              - a 'LayerActivations' instance, or
              - a plain dict 'dict[str, torch.Tensor | np.ndarray | None]'.
            NumPy arrays are converted to tensors; tensors are optionally cast
            to a given dtype if provided by the wrapper.

        methods:
            - 'summary()' → small dict of shape/dtype/device per layer.
            - 'to(*args, **kwargs)' → like 'Tensor.to' for all non-'None' values.
            - 'cpu()', 'detach()' → convenience variants.
            - 'numpy()' → convert stored tensors to NumPy arrays (on cpu).
            - 'to_dict()' → plain 'dict[str, torch.Tensor | None]'.

    serialization notes:
        'Response.to_dict()' returns tensors as tensors. This is convenient for
        in-process use but not JSON-serializable. For JSON, consider mapping the
        activations to shapes/metadata (via 'summary()') or to NumPy arrays /
        lists (via 'numpy()' → then '.tolist()') before encoding.

    examples:
        >>> resp = Response(
        ...     model_response="OK",
        ...     layers_activations={"blocks.0.mlp": torch.randn(2, 4), "attn.1": None},
        ...     label="harmless",
        ... )
        >>> print(resp.layers_activations)
        LayerActivations(
          blocks.0.mlp: Tensor(shape=(2, 4), dtype=torch.float32, device=cpu)
          attn.1: None
        )

        # Update fields immutably:
        >>> resp2 = resp.with_label("toxic")
        >>> resp3 = resp.with_activations({"blocks.0.mlp": torch.zeros(2, 4)})
    """

    model_response: str
    layers_activations: LayerActivations | None = None
    label: str | None = None

    def __post_init__(self) -> None:
        if not isinstance(self.model_response, str) or not self.model_response.strip():
            raise ValueError("'model_response' must be a non-empty string.")
        la = self.layers_activations
        if la is not None and not isinstance(la, LayerActivations):
            # The dataclass is frozen, so bypass the generated __setattr__ to store
            # the coerced value.
            object.__setattr__(self, "layers_activations", LayerActivations(la))

    def with_activations(self, layers_activations: LayerActivations | RawActivationMap | None) -> Response:
        """Return a copy with 'layers_activations' replaced.

        A plain mapping is wrapped in LayerActivations by ``__post_init__``, which
        ``dataclasses.replace`` runs on the new instance. Passing None clears the
        activations on the copy. The result has the same concrete class as ``self``.
        """
        return replace(self, layers_activations=layers_activations)

    def with_label(self, label: str | None) -> Response:
        """Return a copy with 'label' replaced."""
        return replace(self, label=label)

    def to_dict(self) -> dict[str, RawActivationMap | str | None]:
        """Return a plain dict representation of this Response.

        returns:
            A dictionary with keys 'model_response', 'layers_activations', and 'label'.
            'layers_activations' is None or the result of ``LayerActivations.to_dict()``.

        example:
            {
                "model_response": "OK",
                "layers_activations": {"blocks.0.mlp": torch.randn(2, 4), "attn.1": None},
                "label": "harmless"
            }
        """
        return {
            "model_response": self.model_response,
            "layers_activations": (
                None if self.layers_activations is None else self.layers_activations.to_dict()
            ),
            "label": self.label,
        }

    @classmethod
    def from_dict(cls, data: dict[str, str | RawActivationMap | None]) -> Response:
        """Create a Response from a plain dict (the inverse of ``to_dict``).

        arguments:
            data: A dictionary with key 'model_response' and optionally
                'layers_activations' (a dict or None) and 'label'.
                A non-string 'label' is dropped to None.

        raises:
            KeyError: If 'model_response' is missing.
            ValueError: If 'model_response' is None or blank. None is rejected explicitly
                instead of being stringified into the literal text "None".

        example:
            >>> data = {
            ...     "model_response": "OK",
            ...     "layers_activations": {"blocks.0.mlp": torch.randn(2, 4), "attn.1": None},
            ...     "label": "harmless"
            ... }
            >>> resp = Response.from_dict(data)
            >>> print(resp)
            Response(model_response='OK', layers_activations=LayerActivations(...), label='harmless')
        """
        raw_response = data["model_response"]
        if raw_response is None:
            raise ValueError("'model_response' must be a non-empty string.")

        raw_label = data.get("label")
        return cls(
            model_response=str(raw_response),
            layers_activations=(
                None if data.get("layers_activations") is None
                else LayerActivations(data["layers_activations"])
            ),
            label=raw_label if isinstance(raw_label, str) else None,
        )
149
+
150
+
151
class PositiveResponse(Response):
    """Response in the positive role of a contrastive pair (e.g. the harmless or correct answer)."""

    # Response is a slots=True dataclass; declaring empty __slots__ here keeps
    # subclass instances free of a per-instance __dict__ (and of arbitrary extra
    # attributes), matching the parent.
    __slots__ = ()


class NegativeResponse(Response):
    """Response in the negative role of a contrastive pair (e.g. the harmful or incorrect answer)."""

    # See PositiveResponse: keep instances __dict__-free like the slots=True parent.
    __slots__ = ()