wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/core/activations/activations_collector.py
@@ -0,0 +1,338 @@
+ from __future__ import annotations
+ from dataclasses import dataclass
+ from typing import Sequence
+ import torch
+
+
+ from wisent.core.contrastive_pairs.core.pair import ContrastivePair
+ from wisent.core.activations.core.atoms import LayerActivations, ActivationAggregationStrategy, LayerName, RawActivationMap
+ from wisent.core.models.wisent_model import WisentModel
+ __all__ = ["ActivationCollector"]
+
+ @dataclass(slots=True)
+ class ActivationCollector:
+     """
+     Collect per-layer activations for (prompt + response) using a chat template.
+
+     arguments:
+         model:
+             A :class:`WisentModel` wrapper providing the HF model and tokenizer.
+         store_device:
+             Device to store collected activations on (default "cpu").
+         dtype:
+             Optional torch.dtype to cast activations to (e.g., torch.float32).
+             If None, the original dtype is kept.
+
+     detailed explanation:
+
+         Let:
+         - L = 4 transformer blocks
+         - hidden size H = 256
+         - prompt tokenized length T_prompt = 14
+         - full sequence (prompt + response) tokenized length T_full = 22
+
+         Step 1: Build templated strings (NOT tokenized yet)
+             prompt_text = tok.apply_chat_template(
+                 [{"role": "user", "content": prompt}],
+                 tokenize=False, add_generation_prompt=True
+             )
+             full_text = tok.apply_chat_template(
+                 [{"role": "user", "content": prompt},
+                  {"role": "assistant", "content": response}],
+                 tokenize=False, add_generation_prompt=False
+             )
+
+         Step 2: Tokenize both with identical flags
+             prompt_enc = tok(prompt_text, return_tensors="pt", add_special_tokens=False)
+             full_enc = tok(full_text, return_tensors="pt", add_special_tokens=False)
+
+             Shapes:
+                 prompt_enc["input_ids"].shape == (1, T_prompt) == (1, 14)
+                 full_enc["input_ids"].shape == (1, T_full) == (1, 22)
+
+             Boundary:
+                 prompt_len = prompt_enc["input_ids"].shape[-1] == 14
+                 Continuation tokens in the full sequence start at index 14.
+
+         Step 3: Forward pass with hidden states
+             out = model.hf_model(**full_enc, output_hidden_states=True, use_cache=False)
+             hs = out.hidden_states
+
+             hs is a tuple of length L + 1 (the embedding layer sits at index 0):
+                 len(hs) == 5 -> indices: 0 = embeddings, 1..4 = blocks
+             Each hs[i].shape == (1, T_full, H) == (1, 22, 256)
+
+             We map layer names "1".."L" to hs[1]..hs[L]:
+                 "1" -> hs[1], "2" -> hs[2], ..., "4" -> hs[4]
+
+         Step 4: Per-layer extraction
+             For a chosen layer i (1-based), take hs[i].squeeze(0) -> shape (T_full, H) == (22, 256)
+
+             If return_full_sequence=True:
+                 store the value with shape (T_full, H) == (22, 256)
+             Else (aggregate to a single vector [H]):
+             - CONTINUATION_TOKEN: take the first continuation token -> cont[0] -> (H,)
+             - CHOICE_TOKEN: take the token at index prompt_len + 1 (falling back to the last token) -> (H,)
+             - FIRST_TOKEN: layer_seq[0] -> (H,)
+             - LAST_TOKEN: layer_seq[-1] -> (H,)
+             - MEAN_POOLING: cont.mean(0) -> (H,)
+             - MAX_POOLING: cont.max(0)[0] -> (H,)
+
+             where:
+                 layer_seq = hs[i].squeeze(0)   # (22, 256)
+                 cont_start = prompt_len = 14
+                 cont = layer_seq[14:]          # (22-14=8, 256)
+
+         Step 5: Storage and return
+             - Each stored tensor is moved to 'store_device' (default "cpu") and cast to 'dtype'
+               if provided (e.g., float32).
+             - Keys are layer names: "1", "2", ..., "L".
+             - Results are wrapped into LayerActivations with `activation_aggregation_strategy`
+               set to the chosen strategy (or None if keeping full sequences).
+
+     examples:
+         Example usage (aggregated vectors per layer)
+         >>> collector = ActivationCollector(model=my_wrapper, store_device="cpu", dtype=torch.float32)
+         >>> updated_pair = collector.collect_for_pair(
+         ...     pair,
+         ...     layers=["1", "3"],  # subset (or None for all)
+         ...     aggregation=ActivationAggregationStrategy.CONTINUATION_TOKEN,
+         ...     return_full_sequence=False,
+         ... )
+         >>> pos_acts = updated_pair.positive_response.layers_activations
+         >>> pos_acts.summary()
+         {
+           '1': {'shape': (256,), 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False},
+           '3': {'shape': (256,), 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False},
+           '_activation_aggregation_strategy': {'strategy': 'continuation_token'}
+         }
+
+         Example usage (full sequences per layer)
+         >>> updated_pair = collector.collect_for_pair(
+         ...     pair,
+         ...     layers=None,  # all layers "1".."L"
+         ...     aggregation=ActivationAggregationStrategy.MEAN_POOLING,  # ignored when return_full_sequence=True
+         ...     return_full_sequence=True,
+         ... )
+         >>> neg_acts = updated_pair.negative_response.layers_activations
+         >>> # Suppose L=4, T_full=22, H=256
+         >>> neg_acts.summary()
+         {
+           '1': {'shape': (22, 256), 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False},
+           '2': {'shape': (22, 256), 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False},
+           '3': {'shape': (22, 256), 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False},
+           '4': {'shape': (22, 256), 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False},
+           '_activation_aggregation_strategy': {'strategy': None}
+         }
+     """
+
+     model: WisentModel
+     store_device: str | torch.device = "cpu"
+     dtype: torch.dtype | None = None
+
+     def collect_for_pair(
+         self,
+         pair: ContrastivePair,
+         layers: Sequence[LayerName] | None = None,
+         aggregation: ActivationAggregationStrategy = ActivationAggregationStrategy.CONTINUATION_TOKEN,
+         return_full_sequence: bool = False,
+         normalize_layers: bool = False,
+     ) -> ContrastivePair:
+         pos = self._collect_for_texts(pair.prompt, _resp_text(pair.positive_response),
+                                       layers, aggregation, return_full_sequence, normalize_layers)
+         neg = self._collect_for_texts(pair.prompt, _resp_text(pair.negative_response),
+                                       layers, aggregation, return_full_sequence, normalize_layers)
+         return pair.with_activations(positive=pos, negative=neg)
+
+     def _collect_for_texts(
+         self,
+         prompt: str,
+         response: str,
+         layers: Sequence[LayerName] | None,
+         aggregation: ActivationAggregationStrategy,
+         return_full_sequence: bool,
+         normalize_layers: bool = False,
+     ) -> LayerActivations:
+
+         self._ensure_eval_mode()
+         with torch.inference_mode():
+             tok = self.model.tokenizer  # type: ignore[union-attr]
+             if not hasattr(tok, "apply_chat_template"):
+                 raise RuntimeError("Tokenizer has no apply_chat_template; set it up or use a non-chat path.")
+
+             # 1) Build templated strings
+             prompt_text = tok.apply_chat_template(
+                 [{"role": "user", "content": prompt}],
+                 tokenize=False,
+                 add_generation_prompt=True,
+             )
+             full_text = tok.apply_chat_template(
+                 [{"role": "user", "content": prompt},
+                  {"role": "assistant", "content": response}],
+                 tokenize=False,
+                 add_generation_prompt=False,
+             )
+
+             # 2) Tokenize both with identical flags
+             prompt_enc = tok(prompt_text, return_tensors="pt", add_special_tokens=False)
+             full_enc = tok(full_text, return_tensors="pt", add_special_tokens=False)
+
+             # 3) Boundary from prompt-only tokens (CPU is fine)
+             prompt_len = int(prompt_enc["input_ids"].shape[-1])
+
+             # 4) Move only the batch that goes into the model
+             compute_device = getattr(self.model, "compute_device", None) or next(self.model.hf_model.parameters()).device
+             full_enc = {k: v.to(compute_device) for k, v in full_enc.items()}
+
+             # 5) Forward on the full sequence to get hidden states
+             out = self.model.hf_model(**full_enc, output_hidden_states=True, use_cache=False)
+             hs: tuple[torch.Tensor, ...] = out.hidden_states  # hs[0]=emb, hs[1:]=layers
+
+             if not hs:
+                 raise RuntimeError("No hidden_states returned; the model may not support output_hidden_states.")
+
+             n_blocks = len(hs) - 1
+             names_by_idx = [str(i) for i in range(1, n_blocks + 1)]
+
+             keep = self._select_indices(layers, n_blocks)
+             collected: dict[LayerName, torch.Tensor] = {}
+
+             for idx in keep:
+                 name = names_by_idx[idx]
+                 h = hs[idx + 1].squeeze(0)  # [1, T, H] -> [T, H]
+                 if return_full_sequence:
+                     value = h
+                 else:
+                     value = self._aggregate(h, aggregation, prompt_len)
+                 value = value.to(self.store_device)
+                 if self.dtype is not None:
+                     value = value.to(self.dtype)
+
+                 if normalize_layers:
+                     value = self._normalization(value)
+
+                 collected[name] = value
+
+             return LayerActivations(
+                 collected,
+                 activation_aggregation_strategy=None if return_full_sequence else aggregation,
+             )
+
+     def _select_indices(self, layer_names: Sequence[str] | None, n_blocks: int) -> list[int]:
+         """Map layer names '1'..'L' -> indices 0..L-1."""
+         if not layer_names:
+             return list(range(n_blocks))
+         out: list[int] = []
+         for name in layer_names:
+             try:
+                 i = int(name)
+             except ValueError:
+                 raise KeyError(f"Layer name must be a numeric string like '3', got {name!r}")
+             if not (1 <= i <= n_blocks):
+                 raise IndexError(f"Layer '{i}' out of range 1..{n_blocks}")
+             out.append(i - 1)
+         return sorted(set(out))
+
+     def _aggregate(
+         self,
+         layer_seq: torch.Tensor,  # [T, H]
+         aggregation: ActivationAggregationStrategy,
+         prompt_len: int,
+     ) -> torch.Tensor:  # [H]
+         if layer_seq.ndim != 2:
+             raise ValueError(f"Expected [seq_len, hidden_dim], got {tuple(layer_seq.shape)}")
+
+         # continuation = tokens after the prompt boundary
+         cont_start = min(max(prompt_len, 0), layer_seq.shape[0] - 1)
+         cont = layer_seq[cont_start:] if cont_start < layer_seq.shape[0] else layer_seq[-1:].contiguous()
+         if cont.numel() == 0:
+             cont = layer_seq[-1:].contiguous()
+
+         s = aggregation
+
+         if s is ActivationAggregationStrategy.CONTINUATION_TOKEN:
+             return cont[0]
+         elif s is ActivationAggregationStrategy.CHOICE_TOKEN:
+             choice_idx = prompt_len + 1
+             if choice_idx < layer_seq.shape[0]:
+                 return layer_seq[choice_idx]
+             else:
+                 return layer_seq[-1]
+         elif s is ActivationAggregationStrategy.FIRST_TOKEN:
+             return layer_seq[0]
+         elif s is ActivationAggregationStrategy.LAST_TOKEN:
+             return layer_seq[-1]
+         elif s is ActivationAggregationStrategy.MEAN_POOLING:
+             return cont.mean(dim=0)
+         elif s is ActivationAggregationStrategy.MAX_POOLING:
+             return cont.max(dim=0).values
+         else:
+             return cont[0]
+
+     def _normalization(
+         self,
+         x: torch.Tensor,
+         dim: int = -1,
+         eps: float = 1e-12,
+     ) -> torch.Tensor:
+         """
+         Safely L2-normalize 'x' along 'dim'.
+
+         arguments:
+             x:
+                 Tensor of shape [..., H] or [T, H].
+             dim:
+                 Dimension along which to normalize (default -1, the last dimension).
+             eps:
+                 Small value to avoid division by zero (default 1e-12).
+
+         returns:
+             L2-normalized tensor of the same shape as 'x'.
+         """
+         if not torch.is_floating_point(x):
+             return x
+
+         norm = torch.linalg.vector_norm(x, ord=2, dim=dim, keepdim=True)
+
+         mask = norm > eps
+
+         safe_norm = torch.where(mask, norm, torch.ones_like(norm))
+         y = x / safe_norm
+         y = torch.where(mask, y, torch.zeros_like(y))
+
+         return y
+
+     def _ensure_eval_mode(self) -> None:
+         try:
+             self.model.hf_model.eval()
+         except Exception:
+             pass
+
+ def _resp_text(resp_obj: object) -> str:
+     for attr in ("model_response", "text"):
+         if hasattr(resp_obj, attr) and isinstance(getattr(resp_obj, attr), str):
+             return getattr(resp_obj, attr)
+     return str(resp_obj)
+
+ if __name__ == "__main__":
+     from wisent.core.contrastive_pairs.core.response import PositiveResponse, NegativeResponse
+
+     model = WisentModel(model_name="/home/gg/.cache/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6")
+     collector = ActivationCollector(model=model, store_device="cpu")
+
+     pair = ContrastivePair(
+         prompt="The capital of France is",
+         positive_response=PositiveResponse(" Paris."),
+         negative_response=NegativeResponse(" London."),
+     )
+
+     updated = collector.collect_for_pair(
+         pair,
+         layers=["1", "3"],
+         aggregation=ActivationAggregationStrategy.CONTINUATION_TOKEN,
+         return_full_sequence=False,
+     )
+
+     print(updated)
+
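As a sanity check on the shape walkthrough in the docstring above, the aggregation rules can be reproduced with plain tensors. The following is a minimal sketch using the toy dimensions from the docstring (T_full = 22, H = 256, prompt_len = 14); it is illustrative only and not part of the package:

import torch

# Toy dimensions from the docstring walkthrough above.
T_full, H, prompt_len = 22, 256, 14
layer_seq = torch.randn(T_full, H)    # one layer's hidden states, [T, H]

cont = layer_seq[prompt_len:]         # continuation slice, shape (8, 256)

continuation_token = cont[0]          # CONTINUATION_TOKEN -> (256,)
first_token = layer_seq[0]            # FIRST_TOKEN -> (256,)
last_token = layer_seq[-1]            # LAST_TOKEN -> (256,)
mean_pooled = cont.mean(dim=0)        # MEAN_POOLING over the continuation -> (256,)
max_pooled = cont.max(dim=0).values   # MAX_POOLING over the continuation -> (256,)

assert cont.shape == (T_full - prompt_len, H)
assert mean_pooled.shape == (H,)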
File without changes: wisent/core/activations/core/__init__.py
wisent/core/activations/core/atoms.py
@@ -0,0 +1,216 @@
+ from __future__ import annotations
+
+ from enum import Enum, auto, unique
+ from typing import Mapping, Iterator, TypeAlias
+ import numpy as np
+ import torch
+ import sys
+
+ # Python 3.10 compatibility
+ if sys.version_info >= (3, 11):
+     from enum import StrEnum
+ else:
+     class StrEnum(str, Enum):
+         """StrEnum backport for Python < 3.11"""
+         def _generate_next_value_(name, start, count, last_values):
+             return name.lower()
+
+         def __str__(self) -> str:
+             return str(self.value)
+
+ __all__ = ["LayerActivations", "ActivationAggregationStrategy", "LayerName", "LayerActivation", "ActivationMap", "RawActivationMap"]
+
+ LayerName: TypeAlias = str
+ LayerActivation: TypeAlias = torch.Tensor | None
+ ActivationMap: TypeAlias = Mapping[LayerName, LayerActivation]
+ RawActivationMap: TypeAlias = Mapping[LayerName, torch.Tensor | np.ndarray | None]
+
+ class _LowerSnakeStrEnum(StrEnum):
+     """StrEnum whose auto() values are the lower_snake_case of the member name."""
+     def _generate_next_value_(name, start, count, last_values):  # type: ignore
+         return name.lower()
+
+ @unique
+ class ActivationAggregationStrategy(_LowerSnakeStrEnum):
+     """Strategies for selecting/aggregating tokens in activation extraction."""
+
+     CHOICE_TOKEN = auto()         # target A/B choice tokens (multiple choice)
+     CONTINUATION_TOKEN = auto()   # first token of the continuation
+     LAST_TOKEN = auto()           # always use the last token
+     FIRST_TOKEN = auto()          # always use the first token
+     MEAN_POOLING = auto()         # mean over all tokens
+     MAX_POOLING = auto()          # max over all tokens
+
+     @property
+     def description(self) -> str:
+         return {
+             ActivationAggregationStrategy.CHOICE_TOKEN: "Target A/B choice tokens (multiple choice).",
+             ActivationAggregationStrategy.CONTINUATION_TOKEN: "Use the first token of the continuation.",
+             ActivationAggregationStrategy.LAST_TOKEN: "Always select the last token.",
+             ActivationAggregationStrategy.FIRST_TOKEN: "Always select the first token.",
+             ActivationAggregationStrategy.MEAN_POOLING: "Aggregate by mean over all tokens.",
+             ActivationAggregationStrategy.MAX_POOLING: "Aggregate by max over all tokens.",
+         }[self]
+
+
+ class LayerActivations(Mapping[LayerName, LayerActivation]):
+     """Immutable mapping of layer names to activations.
+
+     Behaves like 'Mapping[str, torch.Tensor | None]'.
+
+     construction:
+         'LayerActivations(data, activation_aggregation_strategy=None, dtype=None)'
+
+         - 'torch.Tensor' values are kept as-is (or cast to 'dtype' if given).
+         - 'np.ndarray' values are converted via 'torch.from_numpy' (then cast if needed).
+         - 'None' values are preserved.
+         - Missing/empty input yields an empty container.
+
+     attributes:
+         _data:
+             Internal storage dict mapping layer names to their activations.
+         _strategy:
+             'ActivationAggregationStrategy' (see above). Indicates how activations were aggregated, if applicable.
+
+     methods:
+         'summary()':
+             dict with per-layer shape/dtype/device/requires_grad.
+         'to(*args, **kwargs)':
+             apply 'Tensor.to' to all non-'None' values.
+         'cpu()', 'detach()':
+             convenience operations.
+         'numpy()':
+             map tensors to CPU NumPy arrays (others to 'None').
+         'to_dict()':
+             plain dict (useful for (de)serialization).
+
+     examples:
+         >>> acts = LayerActivations({"layer1": torch.randn(2, 10, 768), "layer2": None}, activation_aggregation_strategy="mean_pooling")
+         >>> acts["layer1"].shape
+         torch.Size([2, 10, 768])
+         >>> acts["layer2"] is None
+         True
+         >>> acts.activation_aggregation_strategy
+         <ActivationAggregationStrategy.MEAN_POOLING: 'mean_pooling'>
+         >>> acts.summary()
+         {'layer1': {'shape': (2, 10, 768), 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False}, 'layer2': {'shape': None, 'dtype': None, 'device': None, 'requires_grad': None}, '_activation_aggregation_strategy': {'strategy': 'mean_pooling'}}
+         >>> acts.numpy()
+         {'layer1': array(...), 'layer2': None}
+         >>> acts.to("cuda")
+         LayerActivations(
+           layer1: Tensor(shape=(2, 10, 768), dtype=torch.float32, device=cuda:0)
+           layer2: None
+           _activation_aggregation_strategy: mean_pooling
+         )
+         >>> acts.detach()  # if any tensor required grad
+         LayerActivations(
+           layer1: Tensor(shape=(2, 10, 768), dtype=torch.float32, device=cpu)
+           layer2: None
+           _activation_aggregation_strategy: mean_pooling
+         )
+
+     notes:
+         - Use 'summary()' or 'numpy()' if you need JSON-serializable content.
+         - Keys are strings by convention; enforced by type hints.
+     """
+     __slots__ = ("_data", "_strategy")
+
+     def __init__(self, data: RawActivationMap | None = None, activation_aggregation_strategy: ActivationAggregationStrategy | str | None = None, dtype: torch.dtype | None = None):
+         store: dict[LayerName, LayerActivation] = {}
+         if data:
+             for layer, val in data.items():
+                 if val is None:
+                     store[layer] = None
+                 elif isinstance(val, torch.Tensor):
+                     store[layer] = val if dtype is None else val.to(dtype)
+                 elif isinstance(val, np.ndarray):
+                     t = torch.from_numpy(val)
+                     store[layer] = t if dtype is None else t.to(dtype)
+                 else:
+                     raise TypeError(
+                         f"Activations for layer '{layer}' must be torch.Tensor, np.ndarray, or None."
+                     )
+         self._data = store
+         self._strategy = self._normalize_strategy(activation_aggregation_strategy)
+
+     @staticmethod
+     def _normalize_strategy(
+         s: ActivationAggregationStrategy | str | None
+     ) -> ActivationAggregationStrategy | None:
+         if s is None:
+             return None
+         if isinstance(s, ActivationAggregationStrategy):
+             return s
+         if isinstance(s, str):
+             try:
+                 return ActivationAggregationStrategy(s)
+             except ValueError:
+                 valid = ", ".join(e.value for e in ActivationAggregationStrategy)
+                 raise ValueError(
+                     f"Unknown activation_aggregation_strategy='{s}'. "
+                     f"Valid options: {valid}"
+                 )
+         raise TypeError(
+             "activation_aggregation_strategy must be ActivationAggregationStrategy | str | None"
+         )
+
+     @property
+     def activation_aggregation_strategy(self) -> ActivationAggregationStrategy | None:
+         return self._strategy
+
+     def __getitem__(self, key: LayerName) -> LayerActivation:
+         return self._data[key]
+
+     def __iter__(self) -> Iterator[LayerName]:
+         return iter(self._data)
+
+     def __len__(self) -> int:
+         return len(self._data)
+
+     def summary(self) -> dict[LayerName, dict[str, tuple | str | bool | None]]:
+         """Return a summary of the activations. For each layer, provides
+         shape, dtype, device, requires_grad status, and the aggregation strategy.
+         """
+         out: dict[LayerName, dict[str, tuple | str | bool | None]] = {}
+         for k, v in self._data.items():
+             if isinstance(v, torch.Tensor):
+                 out[k] = {
+                     "shape": tuple(v.shape),
+                     "dtype": str(v.dtype),
+                     "device": str(v.device),
+                     "requires_grad": bool(v.requires_grad),
+                 }
+             else:
+                 out[k] = {"shape": None, "dtype": None, "device": None, "requires_grad": None}
+
+         out["_activation_aggregation_strategy"] = {"strategy": self._strategy.value if self._strategy else None}
+         return out
+
+     def numpy(self) -> dict[LayerName, np.ndarray | None]:
+         return {k: (v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else None)
+                 for k, v in self._data.items()}
+
+     def to_dict(self) -> dict[LayerName, LayerActivation]:
+         return dict(self._data)
+
+     def to(self, *args, **kwargs) -> LayerActivations:
+         return LayerActivations({k: (v.to(*args, **kwargs) if isinstance(v, torch.Tensor) else None)
+                                  for k, v in self._data.items()},
+                                 activation_aggregation_strategy=self._strategy)
+
+     def detach(self) -> LayerActivations:
+         return LayerActivations({k: (v.detach() if isinstance(v, torch.Tensor) else None)
+                                  for k, v in self._data.items()},
+                                 activation_aggregation_strategy=self._strategy)
+
+     def cpu(self) -> LayerActivations:
+         return self.to("cpu")
+
+     def __repr__(self) -> str:
+         lines = ["LayerActivations("]
+         for k, v in self._data.items():
+             if isinstance(v, torch.Tensor):
+                 lines.append(
+                     f"  {k}: Tensor(shape={tuple(v.shape)}, dtype={v.dtype}, device={v.device})"
+                 )
+             else:
+                 lines.append(f"  {k}: None")
+         lines.append(f"  _activation_aggregation_strategy: {self._strategy.value if self._strategy else None}")
+         lines.append(")")
+         return "\n".join(lines)
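A short usage sketch for the container above (assuming the wisent 0.5.2 wheel is installed so the import path resolves); illustrative only:

import numpy as np
import torch

from wisent.core.activations.core.atoms import LayerActivations

# Tensor, NumPy, and None values are all accepted; the strategy may be a string.
acts = LayerActivations(
    {"1": torch.randn(256), "2": np.zeros((256,), dtype=np.float32), "3": None},
    activation_aggregation_strategy="mean_pooling",
    dtype=torch.float32,
)

assert len(acts) == 3                         # Mapping protocol
assert acts["2"].dtype == torch.float32      # ndarray converted via torch.from_numpy
print(acts.activation_aggregation_strategy)  # ActivationAggregationStrategy.MEAN_POOLING
print(acts.summary()["3"])                   # {'shape': None, 'dtype': None, 'device': None, 'requires_grad': None}
arrays = acts.numpy()                        # JSON-friendly: ndarrays for tensors, None preserved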
wisent/core/agent/__init__.py
@@ -0,0 +1,18 @@
+ """
+ Agent module for wisent-guard autonomous systems.
+
+ This module provides:
+ - ResponseDiagnostics: Response analysis and quality assessment
+ - ResponseSteering: Response improvement and steering
+ - Data classes for analysis and improvement results
+ """
+
+ from .diagnose import ResponseDiagnostics, AnalysisResult
+ from .steer import ResponseSteering, ImprovementResult
+
+ __all__ = [
+     'ResponseDiagnostics',
+     'AnalysisResult',
+     'ResponseSteering',
+     'ImprovementResult',
+ ]
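Given these re-exports, downstream code can import the agent API from the package directly; a one-line sketch (assuming the wheel is installed):

from wisent.core.agent import ResponseDiagnostics, AnalysisResult, ResponseSteering, ImprovementResult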