wisent-0.1.1-py3-none-any.whl → wisent-0.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,349 @@
1
+ """
2
+ Activation pre-generation module for efficient Optuna-based classifier optimization.
3
+
4
+ This module generates activations once and stores them for reuse across all Optuna trials,
5
+ significantly improving optimization performance by avoiding redundant activation extraction.
6
+ """
7
+
8
+ import hashlib
9
+ import logging
10
+ import pickle
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+ from typing import Any, Optional
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from wisent.core.activations.activation_collection_method import ActivationCollectionLogic
19
+ from wisent.core.activations.core import ActivationAggregationStrategy, Activations
20
+
21
# Module-level logger. Not referenced elsewhere in this module: ActivationGenerator
# creates its own child logger named f"{__name__}.ActivationGenerator" in __init__.
logger = logging.getLogger(__name__)
22
+
23
+
24
@dataclass
class ActivationData:
    """Container for pre-generated activation data with Activations wrapper integration.

    Attributes:
        activations: Feature tensor whose first dimension indexes samples.
        labels: Label tensor aligned with ``activations`` along the first
            dimension (ActivationGenerator in this module uses 0.0 for positive
            and 1.0 for negative examples).
        layer: Index of the model layer the activations were taken from.
        aggregation: Aggregation strategy that produced ``activations``.
        metadata: Extra information; ``get_statistics`` reads the optional
            ``"n_positive"`` / ``"n_negative"`` keys.
    """

    activations: torch.Tensor
    labels: torch.Tensor
    layer: int
    aggregation: ActivationAggregationStrategy
    metadata: dict[str, Any]

    @staticmethod
    def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach *tensor*, move it to CPU and convert it to a NumPy array.

        NumPy has no bfloat16 dtype, so ``Tensor.numpy()`` raises for bfloat16
        tensors; those are upcast to float32 first. Every other dtype is kept.
        """
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.to(torch.float32)
        return tensor.numpy()

    def to_numpy(self) -> tuple[np.ndarray, np.ndarray]:
        """Convert to numpy arrays for sklearn compatibility.

        Both tensors are detached and copied to CPU. bfloat16 tensors are
        upcast to float32 (NumPy cannot represent bfloat16); other dtypes are
        preserved.

        Returns:
            ``(X, y)``: activations and labels as NumPy arrays.
        """
        return self._tensor_to_numpy(self.activations), self._tensor_to_numpy(self.labels)

    def to_tensors(
        self, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return tensors directly for PyTorch classifiers.

        Args:
            device: Target device. When falsy (``None``/``""``) the tensors stay
                on their current device.
            dtype: Target dtype applied to BOTH activations and labels. Defaults
                to the dtype of ``activations``, so labels are cast to it too.

        Returns:
            ``(X, y)`` tensors.
        """
        target_dtype = dtype if dtype is not None else self.activations.dtype

        if device:
            X = self.activations.to(device=device, dtype=target_dtype)
            y = self.labels.to(device=device, dtype=target_dtype)
        else:
            X = self.activations.to(dtype=target_dtype)
            y = self.labels.to(dtype=target_dtype)
        return X, y

    def to_activations_objects(self) -> list[Activations]:
        """
        Convert stored activations to Activations objects for better abstraction.

        Returns:
            List of Activations objects, one per sample (first dimension of
            ``activations``), all tagged with this object's layer and aggregation.
        """
        return [
            Activations(
                tensor=self.activations[i : i + 1],  # slice keeps a leading dim of size 1
                layer=self.layer,
                aggregation_strategy=self.aggregation,
            )
            for i in range(self.activations.shape[0])
        ]

    def get_statistics(self) -> dict[str, Any]:
        """Get statistics about the activation data using Activations primitives.

        Core statistics come from ``Activations.get_statistics`` computed on the
        FIRST sample only; dataset-level fields (sample counts, aggregation name,
        layer) are added on top.
        """
        sample_activation = Activations(
            tensor=self.activations[:1],  # first sample only
            layer=self.layer,
            aggregation_strategy=self.aggregation,
        )

        stats = sample_activation.get_statistics()
        stats.update(
            {
                "n_samples": self.activations.shape[0],
                "n_positive": self.metadata.get("n_positive", "unknown"),
                "n_negative": self.metadata.get("n_negative", "unknown"),
                "aggregation_method": self.aggregation.value,  # readable enum value
                "layer": self.layer,
            }
        )

        return stats
96
+
97
+
98
@dataclass
class GenerationConfig:
    """Settings consumed by ActivationGenerator.

    Fields left unset are filled in by ``__post_init__``: a ``None``
    ``cache_dir`` becomes ``"./activation_cache"``, and a ``None`` or empty
    ``aggregation_methods`` becomes mean pooling, last token, first token and
    max pooling (in that order).
    """

    # (start, end) layer indices; ActivationGenerator iterates start..end inclusive.
    layer_search_range: tuple[int, int]
    # Strategies applied per layer; None or an empty list selects the defaults.
    aggregation_methods: Optional[list[ActivationAggregationStrategy]] = None
    # Directory holding pickled activation caches; None selects "./activation_cache".
    cache_dir: Optional[str] = None
    # Passed unchanged to the activation collector's collect_activations_batch.
    device: Optional[str] = None
    # NOTE(review): dtype and batch_size are not read anywhere else in this module.
    dtype: Optional[torch.dtype] = None
    batch_size: int = 32

    def __post_init__(self):
        self.cache_dir = "./activation_cache" if self.cache_dir is None else self.cache_dir

        if self.aggregation_methods:
            return

        # Reached for both None and an explicitly empty list.
        self.aggregation_methods = [
            ActivationAggregationStrategy.MEAN_POOLING,
            ActivationAggregationStrategy.LAST_TOKEN,
            ActivationAggregationStrategy.FIRST_TOKEN,
            ActivationAggregationStrategy.MAX_POOLING,
        ]
119
+
120
+
121
class ActivationGenerator:
    """
    Generates and caches activations for efficient classifier optimization.

    Key features:
    - Collects activations once per layer in ``config.layer_search_range`` and
      derives one dataset per aggregation method from them
    - Caches results to disk (pickle files in ``config.cache_dir``) for reuse
      across optimization runs
    - Input is contrastive pairs (see ``generate_from_contrastive_pairs``)
    """

    def __init__(self, config: GenerationConfig):
        """Store *config* and create ``config.cache_dir`` (with parents) if missing."""
        self.config = config
        self.cache_dir = Path(config.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    def generate_from_contrastive_pairs(
        self, model, contrastive_pairs: list, task_name: str, model_name: str, limit: int
    ) -> dict[str, ActivationData]:
        """
        Generate activations from contrastive pairs, reusing the disk cache when possible.

        For each layer in ``config.layer_search_range`` (both ends inclusive) the
        pairs are passed to ``ActivationCollectionLogic.collect_activations_batch``;
        per-pair positive/negative activations are stacked and reduced with every
        strategy in ``config.aggregation_methods``. A layer or strategy that raises
        is logged and skipped instead of aborting the whole run.

        Args:
            model: Language model handed to ``ActivationCollectionLogic``.
            contrastive_pairs: List of contrastive pairs.
            task_name: Name of the task (part of the cache key).
            model_name: Name of the model (part of the cache key).
            limit: Data limit used (part of the cache key).

        Returns:
            Dict mapping ``"layer_{layer}_agg_{aggregation.value}"`` keys to
            ActivationData. Positive examples are labelled 0.0 and negative
            examples 1.0. Empty if no layer/strategy succeeded (such a result is
            not written to the cache, so a later call retries).
        """
        cache_key = self._create_cache_key(model_name, task_name, limit, "contrastive")

        cached_data = self._load_from_cache(cache_key)
        if cached_data is not None:
            self.logger.info(f"Loaded pre-generated activations from cache: {cache_key}")
            return cached_data

        self.logger.info(f"Generating activations for {len(contrastive_pairs)} contrastive pairs")

        collector = ActivationCollectionLogic(model=model)
        activation_data: dict[str, ActivationData] = {}
        first_layer, last_layer = self.config.layer_search_range

        for layer in range(first_layer, last_layer + 1):
            self.logger.info(f"Processing layer {layer}")

            try:
                processed_pairs = collector.collect_activations_batch(
                    pairs=contrastive_pairs, layer_index=layer, device=self.config.device
                )

                # Pairs missing an activation on one side are skipped for that side only.
                positive_activations = []
                negative_activations = []
                for pair in processed_pairs:
                    pos = getattr(pair, "positive_activations", None)
                    if pos is not None:
                        positive_activations.append(pos.detach().cpu())
                    neg = getattr(pair, "negative_activations", None)
                    if neg is not None:
                        negative_activations.append(neg.detach().cpu())

                if not positive_activations or not negative_activations:
                    self.logger.warning(f"Insufficient activations for layer {layer}")
                    continue

                # torch.stack requires every per-pair tensor to have the same shape;
                # a mismatch raises and is handled by the layer-level except below.
                # NOTE(review): per-pair tensors presumably [hidden_dim] or
                # [n_tokens, hidden_dim] — confirm against ActivationCollectionLogic.
                pos_stack = torch.stack(positive_activations)
                neg_stack = torch.stack(negative_activations)

                for aggregation in self.config.aggregation_methods:
                    try:
                        pos_aggregated = self._apply_batch_aggregation(pos_stack, aggregation)
                        neg_aggregated = self._apply_batch_aggregation(neg_stack, aggregation)

                        # Positives first (label 0), then negatives (label 1).
                        X = torch.cat([pos_aggregated, neg_aggregated], dim=0)
                        y = torch.cat(
                            [torch.zeros(len(pos_aggregated)), torch.ones(len(neg_aggregated))], dim=0
                        )

                        key = f"layer_{layer}_agg_{aggregation.value}"
                        activation_data[key] = ActivationData(
                            activations=X,
                            labels=y,
                            layer=layer,
                            aggregation=aggregation,
                            metadata={
                                "task_name": task_name,
                                "model_name": model_name,
                                "n_positive": len(pos_aggregated),
                                "n_negative": len(neg_aggregated),
                                "feature_dim": X.shape[1] if len(X.shape) > 1 else X.shape[0],
                            },
                        )

                        self.logger.debug(f"Layer {layer}, aggregation {aggregation.value}: {X.shape[0]} samples")

                    except Exception as e:
                        self.logger.warning(f"Failed to apply aggregation {aggregation.value} for layer {layer}: {e}")

            except Exception as e:
                self.logger.warning(f"Failed to process layer {layer}: {e}")

        if activation_data:
            self._save_to_cache(cache_key, activation_data)
        else:
            # Persisting {} would make every later call return the empty result
            # from cache instead of retrying, turning a transient failure permanent.
            self.logger.warning(f"No activations generated; not caching result for {cache_key}")

        self.logger.info(f"Generated activations for {len(activation_data)} layer-aggregation combinations")
        return activation_data

    def _apply_batch_aggregation(
        self, activations: torch.Tensor, strategy: ActivationAggregationStrategy
    ) -> torch.Tensor:
        """
        Apply an aggregation strategy to a batch of activations.

        Behaviour by input rank:
        - 2-D: returned unchanged; ``strategy`` is ignored.
        - 3-D ``[n_samples, n_tokens, hidden_dim]``: reduced over dim 1 with mean,
          last token, first token or max; an unknown strategy logs a warning and
          falls back to mean pooling.
        - Any other rank: flattened to ``[n_samples, -1]``; ``strategy`` is ignored.

        Args:
            activations: Batch tensor whose first dimension indexes samples.
            strategy: Aggregation strategy to apply.

        Returns:
            A 2-D tensor with one row per sample.
        """
        if len(activations.shape) == 2:
            return activations
        if len(activations.shape) == 3:
            if strategy == ActivationAggregationStrategy.MEAN_POOLING:
                return torch.mean(activations, dim=1)
            if strategy == ActivationAggregationStrategy.LAST_TOKEN:
                return activations[:, -1, :]
            if strategy == ActivationAggregationStrategy.FIRST_TOKEN:
                return activations[:, 0, :]
            if strategy == ActivationAggregationStrategy.MAX_POOLING:
                return torch.max(activations, dim=1)[0]
            self.logger.warning(f"Unknown aggregation strategy {strategy}, using mean pooling")
            return torch.mean(activations, dim=1)
        return activations.view(activations.shape[0], -1)

    def _create_cache_key(self, model_name: str, task_name: str, limit: int, data_type: str) -> str:
        """Create a cache key (hex MD5 digest) from the arguments and the current config.

        The components are hashed through ``repr`` of a tuple instead of being
        joined with ``"_"``. Joining was ambiguous: e.g. model ``"a"`` + task
        ``"b_c"`` vs model ``"a_b"`` + task ``"c"``, or ``"org/model"`` vs
        ``"org_model"``, produced the same key and could serve one run another
        run's activations. Keys therefore differ from older versions; stale
        files are simply never read again and can be removed with ``clear_cache``.
        """
        key_components = (
            model_name,
            task_name,
            str(limit),
            data_type,
            tuple(self.config.layer_search_range),
            tuple(sorted(agg.value for agg in self.config.aggregation_methods)),
        )
        # MD5 serves only as a compact filename digest, not for security.
        return hashlib.md5(repr(key_components).encode(), usedforsecurity=False).hexdigest()

    def _load_from_cache(self, cache_key: str) -> Optional[dict[str, ActivationData]]:
        """Load activation data from cache.

        Returns:
            The cached dict, or ``None`` if the file is missing, unreadable, not a
            dict, or empty. All of these are treated as a cache miss so the data
            is regenerated.
        """
        cache_file = self.cache_dir / f"{cache_key}.pkl"

        if not cache_file.exists():
            return None

        try:
            # SECURITY: pickle.load can execute arbitrary code. cache_dir must only
            # contain files written by this class; never point it at untrusted data.
            with open(cache_file, "rb") as f:
                data = pickle.load(f)
        except Exception as e:
            self.logger.warning(f"Failed to load cache file {cache_file}: {e}")
            return None

        if not isinstance(data, dict) or not data:
            self.logger.warning(f"Ignoring empty or invalid cache file {cache_file}")
            return None

        self.logger.debug(f"Loaded {len(data)} activation datasets from cache")
        return data

    def _save_to_cache(self, cache_key: str, data: dict[str, ActivationData]) -> None:
        """Save activation data to cache (best effort; failures are logged, not raised).

        Data is first written to ``<key>.pkl.tmp`` and then moved over
        ``<key>.pkl`` with ``Path.replace``, so an interrupted write does not
        leave a truncated ``.pkl`` file behind.
        """
        cache_file = self.cache_dir / f"{cache_key}.pkl"
        tmp_file = self.cache_dir / f"{cache_key}.pkl.tmp"

        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(data, f)
            tmp_file.replace(cache_file)

            self.logger.info(f"Saved {len(data)} activation datasets to cache: {cache_file}")

        except Exception as e:
            self.logger.error(f"Failed to save cache file {cache_file}: {e}")
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; the original error has already been logged.
                pass

    def clear_cache(self) -> None:
        """Delete all ``*.pkl`` cache files (and leftover ``*.pkl.tmp`` files).

        Files that cannot be removed are logged and skipped; the final log line
        reports how many were actually removed.
        """
        cache_files = list(self.cache_dir.glob("*.pkl")) + list(self.cache_dir.glob("*.pkl.tmp"))
        removed = 0
        for cache_file in cache_files:
            try:
                cache_file.unlink()
                removed += 1
                self.logger.info(f"Removed cache file: {cache_file}")
            except Exception as e:
                self.logger.warning(f"Failed to remove cache file {cache_file}: {e}")

        self.logger.info(f"Cleared {removed} cache files")

    def get_cache_info(self) -> dict[str, Any]:
        """Get information about cached data.

        Returns:
            Dict with ``cache_dir``, ``total_files`` (number of ``*.pkl`` files
            found), ``total_size_mb`` (sum over files that could be stat'ed) and
            ``files`` (per-file ``name``, ``size_mb``, ``modified`` mtime). Files
            that cannot be stat'ed (e.g. removed concurrently) are logged and
            left out of ``files`` and ``total_size_mb``.
        """
        cache_files = list(self.cache_dir.glob("*.pkl"))
        bytes_per_mb = 1024 * 1024

        files = []
        total_bytes = 0
        for cache_file in cache_files:
            try:
                stat = cache_file.stat()
            except Exception as e:
                self.logger.warning(f"Failed to get info for {cache_file}: {e}")
                continue
            total_bytes += stat.st_size
            files.append(
                {"name": cache_file.name, "size_mb": stat.st_size / bytes_per_mb, "modified": stat.st_mtime}
            )

        return {
            "cache_dir": str(self.cache_dir),
            "total_files": len(cache_files),
            "total_size_mb": total_bytes / bytes_per_mb,
            "files": files,
        }