wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Activation pre-generation module for efficient Optuna-based classifier optimization.
|
|
3
|
+
|
|
4
|
+
This module generates activations once and stores them for reuse across all Optuna trials,
|
|
5
|
+
significantly improving optimization performance by avoiding redundant activation extraction.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import hashlib
|
|
9
|
+
import logging
|
|
10
|
+
import pickle
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Optional
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
from wisent.core.activations.activation_collection_method import ActivationCollectionLogic
|
|
19
|
+
from wisent.core.activations.core import ActivationAggregationStrategy, Activations
|
|
20
|
+
|
|
21
|
+
# Module-level logger. Not used by the code in this module: ActivationGenerator
# creates its own child logger (f"{__name__}.ActivationGenerator") as self.logger.
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class ActivationData:
    """Container for pre-generated activation data with Activations wrapper integration.

    Holds a stacked activation tensor together with its labels, the layer it was
    taken from and the aggregation strategy that produced it. ``ActivationGenerator``
    in this module labels positive-response samples 0 and negative-response samples 1.
    """

    activations: torch.Tensor  # stacked samples; the first dimension indexes samples
    labels: torch.Tensor  # one label per sample (row) of ``activations``
    layer: int  # model layer index the activations were taken from
    aggregation: ActivationAggregationStrategy  # strategy used to reduce token activations
    metadata: dict[str, Any]  # free-form info; get_statistics reads "n_positive"/"n_negative"

    @staticmethod
    def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach, move to CPU and convert to numpy, upcasting dtypes numpy lacks.

        NumPy has no bfloat16 (or float8) dtype, so ``Tensor.numpy()`` raises a
        TypeError for such tensors. Any floating dtype other than float16/32/64 is
        therefore converted to float32 first.
        """
        tensor = tensor.detach().cpu()
        if tensor.is_floating_point() and tensor.dtype not in (
            torch.float16,
            torch.float32,
            torch.float64,
        ):
            tensor = tensor.float()
        return tensor.numpy()

    def to_numpy(self) -> tuple[np.ndarray, np.ndarray]:
        """Convert to numpy arrays for sklearn compatibility.

        Returns:
            ``(X, y)``: activations and labels as numpy arrays. Floating dtypes with
            no numpy equivalent (e.g. bfloat16) are upcast to float32.
        """
        X = self._tensor_to_numpy(self.activations)
        y = self._tensor_to_numpy(self.labels)
        return X, y

    def to_tensors(
        self, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return tensors directly for PyTorch classifiers.

        Args:
            device: Target device. If falsy (None or ""), tensors stay on their
                current device.
            dtype: Target dtype. If None, the activations' own dtype is kept.

        Returns:
            ``(X, y)``. Labels are cast to the same dtype as the activations, so both
            tensors share one dtype.
        """
        # Use specified dtype, or preserve original dtype if not specified
        target_dtype = dtype if dtype is not None else self.activations.dtype

        if device:
            X = self.activations.to(device=device, dtype=target_dtype)
            y = self.labels.to(device=device, dtype=target_dtype)
        else:
            X = self.activations.to(dtype=target_dtype)
            y = self.labels.to(dtype=target_dtype)
        return X, y

    def to_activations_objects(self) -> list[Activations]:
        """
        Convert stored activations to Activations objects for better abstraction.

        Returns:
            List of Activations objects, one per sample. Each wraps a ``[1, ...]``
            slice of ``self.activations`` (the leading sample dimension is kept).
        """
        return [
            Activations(
                tensor=self.activations[i : i + 1],  # Keep batch dimension
                layer=self.layer,
                aggregation_strategy=self.aggregation,  # Direct enum usage
            )
            for i in range(self.activations.shape[0])
        ]

    def get_statistics(self) -> dict[str, Any]:
        """Get statistics about the activation data using Activations primitives.

        The statistics produced by ``Activations.get_statistics()`` are computed on
        the FIRST sample only (``self.activations[:1]``). Dataset-level fields are
        then added: ``n_samples``, ``n_positive``/``n_negative`` (from metadata, or
        "unknown"), ``aggregation_method`` (the enum's value) and ``layer``.
        """
        # Representative wrapper built from the first sample only.
        sample_activation = Activations(
            tensor=self.activations[:1],
            layer=self.layer,
            aggregation_strategy=self.aggregation,  # Direct enum usage
        )

        stats = sample_activation.get_statistics()
        stats.update(
            {
                "n_samples": self.activations.shape[0],
                "n_positive": self.metadata.get("n_positive", "unknown"),
                "n_negative": self.metadata.get("n_negative", "unknown"),
                "aggregation_method": self.aggregation.value,  # Display value for readability
                "layer": self.layer,
            }
        )

        return stats
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass
class GenerationConfig:
    """Settings that control which activations ActivationGenerator produces and caches.

    ``__post_init__`` fills in defaults. A ``cache_dir`` of None becomes
    "./activation_cache". An empty or None ``aggregation_methods`` becomes mean,
    last-token, first-token and max pooling.
    """

    # (first_layer, last_layer); ActivationGenerator iterates both ends inclusively.
    layer_search_range: tuple[int, int]
    # Strategies to apply per layer; falsy values are replaced in __post_init__.
    aggregation_methods: Optional[list[ActivationAggregationStrategy]] = None
    # Directory for pickled cache files; None is replaced in __post_init__.
    cache_dir: Optional[str] = None
    # Forwarded to the activation collector as its ``device`` argument.
    device: Optional[str] = None
    # NOTE(review): not read anywhere in this module (an older comment said
    # "auto-detect if None") — confirm where, if anywhere, it is consumed.
    dtype: Optional[torch.dtype] = None
    # NOTE(review): not read anywhere in this module — confirm its consumer.
    batch_size: int = 32

    def __post_init__(self):
        self.cache_dir = "./activation_cache" if self.cache_dir is None else self.cache_dir
        if self.aggregation_methods:
            # Caller supplied a non-empty list: keep it as-is.
            return
        self.aggregation_methods = [
            ActivationAggregationStrategy.MEAN_POOLING,
            ActivationAggregationStrategy.LAST_TOKEN,
            ActivationAggregationStrategy.FIRST_TOKEN,
            ActivationAggregationStrategy.MAX_POOLING,
        ]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class ActivationGenerator:
    """
    Generates and caches activations for efficient classifier optimization.

    For every layer in ``config.layer_search_range`` (inclusive on both ends) and
    every strategy in ``config.aggregation_methods``, activations are extracted once
    from a list of contrastive pairs. The results are pickled into
    ``config.cache_dir``. Later calls with the same model/task/limit/config load
    that pickle instead of re-running the model.

    NOTE(review): cache files are read back with ``pickle.load``, and unpickling can
    execute code embedded in the file. ``cache_dir`` must therefore only contain files
    written by this class; never point it at a shared or untrusted location.
    """

    def __init__(self, config: GenerationConfig):
        """Store ``config`` and create ``config.cache_dir`` (with parents) if missing."""
        self.config = config
        self.cache_dir = Path(config.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    def generate_from_contrastive_pairs(
        self, model, contrastive_pairs: list, task_name: str, model_name: str, limit: int
    ) -> dict[str, ActivationData]:
        """
        Generate (or load from cache) activations from contrastive pairs.

        Args:
            model: Language model, handed to ``ActivationCollectionLogic``.
            contrastive_pairs: List of contrastive pairs to run through the collector.
            task_name: Name of the task (part of the cache key and of the metadata).
            model_name: Name of the model (part of the cache key and of the metadata).
            limit: Data limit used (part of the cache key).

        Returns:
            Dict mapping keys of the form ``"layer_{layer}_agg_{aggregation.value}"`` to
            ``ActivationData``. Positive-response activations are labelled 0 and
            negative-response activations 1. Layers or aggregations that fail are
            logged and skipped, so the dict may be partial or empty.

        Note:
            The cache key does not include the contents of ``contrastive_pairs``.
            A different set of pairs with the same model/task/limit/config will hit
            the existing cache entry. An empty result is never written to the cache,
            so a run in which everything failed is retried next time instead of being
            served from the cache forever.
        """
        cache_key = self._create_cache_key(model_name, task_name, limit, "contrastive")

        cached_data = self._load_from_cache(cache_key)
        if cached_data is not None:
            self.logger.info(f"Loaded pre-generated activations from cache: {cache_key}")
            return cached_data

        self.logger.info(f"Generating activations for {len(contrastive_pairs)} contrastive pairs")

        # NOTE(review): ActivationCollectionLogic is imported from
        # wisent.core.activations.activation_collection_method, which does not appear
        # in this wheel's file listing. Confirm that import resolves.
        collector = ActivationCollectionLogic(model=model)
        activation_data: dict[str, ActivationData] = {}

        for layer in range(self.config.layer_search_range[0], self.config.layer_search_range[1] + 1):
            self.logger.info(f"Processing layer {layer}")

            # Best-effort per layer: a failing layer is logged and skipped.
            try:
                processed_pairs = collector.collect_activations_batch(
                    pairs=contrastive_pairs, layer_index=layer, device=self.config.device
                )

                stacks = self._stack_pair_activations(processed_pairs)
                if stacks is None:
                    self.logger.warning(f"Insufficient activations for layer {layer}")
                    continue
                pos_stack, neg_stack = stacks

                for aggregation in self.config.aggregation_methods:
                    # Best-effort per aggregation: one bad strategy doesn't drop the layer.
                    try:
                        pos_aggregated = self._apply_batch_aggregation(pos_stack, aggregation)
                        neg_aggregated = self._apply_batch_aggregation(neg_stack, aggregation)

                        # Combine positive (label=0) and negative (label=1)
                        X = torch.cat([pos_aggregated, neg_aggregated], dim=0)
                        y = torch.cat(
                            [torch.zeros(len(pos_aggregated)), torch.ones(len(neg_aggregated))],
                            dim=0,
                        )

                        key = f"layer_{layer}_agg_{aggregation.value}"
                        activation_data[key] = ActivationData(
                            activations=X,
                            labels=y,
                            layer=layer,
                            aggregation=aggregation,
                            metadata={
                                "task_name": task_name,
                                "model_name": model_name,
                                "n_positive": len(pos_aggregated),
                                "n_negative": len(neg_aggregated),
                                "feature_dim": X.shape[1] if len(X.shape) > 1 else X.shape[0],
                            },
                        )

                        self.logger.debug(f"Layer {layer}, aggregation {aggregation.value}: {X.shape[0]} samples")

                    except Exception as e:
                        self.logger.warning(f"Failed to apply aggregation {aggregation.value} for layer {layer}: {e}")

            except Exception as e:
                self.logger.warning(f"Failed to process layer {layer}: {e}")

        if activation_data:
            self._save_to_cache(cache_key, activation_data)
        else:
            # Caching {} would make every later call return nothing without retrying.
            self.logger.warning("No activations were generated; not writing an empty result to the cache")

        self.logger.info(f"Generated activations for {len(activation_data)} layer-aggregation combinations")
        return activation_data

    @staticmethod
    def _stack_pair_activations(processed_pairs) -> Optional[tuple[torch.Tensor, torch.Tensor]]:
        """Stack the positive and negative activations of processed pairs on the CPU.

        Pairs lacking an activation attribute, or holding None for it, are skipped
        on that side. Returns None when either side has no activations at all.
        ``torch.stack`` raises if the collected tensors differ in shape.
        """
        positive_activations = []
        negative_activations = []

        for pair in processed_pairs:
            pos = getattr(pair, "positive_activations", None)
            if pos is not None:
                positive_activations.append(pos.detach().cpu())
            neg = getattr(pair, "negative_activations", None)
            if neg is not None:
                negative_activations.append(neg.detach().cpu())

        if not positive_activations or not negative_activations:
            return None

        return torch.stack(positive_activations), torch.stack(negative_activations)

    def _apply_batch_aggregation(
        self, activations: torch.Tensor, strategy: ActivationAggregationStrategy
    ) -> torch.Tensor:
        """
        Apply aggregation strategy to a batch of activations efficiently.

        Args:
            activations: Tensor whose first dimension indexes samples.
            strategy: Aggregation strategy from core primitives.

        Returns:
            Tensor of shape ``[n_samples, ...]``:
            - 2-D input is returned unchanged, whatever the strategy.
            - 3-D input ``[n_samples, n_tokens, hidden_dim]`` is reduced over the
              token axis. Unknown strategies fall back to mean pooling with a warning.
            - Any other rank is flattened to ``[n_samples, -1]``.
        """
        if activations.dim() == 2:
            # Already one vector per sample; there is no token axis to aggregate.
            return activations
        if activations.dim() == 3:
            if strategy == ActivationAggregationStrategy.MEAN_POOLING:
                return torch.mean(activations, dim=1)
            if strategy == ActivationAggregationStrategy.LAST_TOKEN:
                return activations[:, -1, :]
            if strategy == ActivationAggregationStrategy.FIRST_TOKEN:
                return activations[:, 0, :]
            if strategy == ActivationAggregationStrategy.MAX_POOLING:
                return torch.max(activations, dim=1)[0]
            self.logger.warning(f"Unknown aggregation strategy {strategy}, using mean pooling")
            return torch.mean(activations, dim=1)
        # ``reshape`` (unlike ``view``) also handles non-contiguous inputs.
        return activations.reshape(activations.shape[0], -1)

    def _create_cache_key(self, model_name: str, task_name: str, limit: int, data_type: str) -> str:
        """Return an md5 hex digest identifying the given parameters and the config.

        The key covers the model name, task, limit, data type, layer range and the
        sorted aggregation values. It does not cover device, dtype or the input data.
        """
        key_components = [
            model_name.replace("/", "_"),
            task_name,
            str(limit),
            data_type,
            f"{self.config.layer_search_range[0]}-{self.config.layer_search_range[1]}",
            str(sorted([agg.value for agg in self.config.aggregation_methods])),
        ]
        key_string = "_".join(key_components)
        # md5 is only a file-name fingerprint here, not a security primitive;
        # usedforsecurity=False keeps it usable on FIPS-restricted builds.
        return hashlib.md5(key_string.encode(), usedforsecurity=False).hexdigest()

    def _load_from_cache(self, cache_key: str) -> Optional[dict[str, ActivationData]]:
        """Load activation data from cache.

        Returns None (a cache miss) if the file is missing, unreadable, or does not
        hold a non-empty dict. Empty payloads may have been written by older
        versions after a fully failed run, and they must not be served as a hit.
        """
        cache_file = self.cache_dir / f"{cache_key}.pkl"

        if not cache_file.exists():
            return None

        try:
            # SECURITY: pickle.load can execute arbitrary code; cache_dir must be trusted.
            with open(cache_file, "rb") as f:
                data = pickle.load(f)
        except Exception as e:
            self.logger.warning(f"Failed to load cache file {cache_file}: {e}")
            return None

        if not isinstance(data, dict) or not data:
            self.logger.warning(f"Ignoring empty or invalid cache file {cache_file}")
            return None

        self.logger.debug(f"Loaded {len(data)} activation datasets from cache")
        return data

    def _save_to_cache(self, cache_key: str, data: dict[str, ActivationData]) -> None:
        """Pickle ``data`` to ``<cache_dir>/<cache_key>.pkl``.

        The payload is written to a temporary file in the same directory and then
        renamed over the target. An interrupted write therefore never leaves a
        truncated ``.pkl`` behind. Failures are logged, not raised.
        """
        import tempfile  # local: only needed for the atomic write below

        cache_file = self.cache_dir / f"{cache_key}.pkl"
        tmp_path: Optional[Path] = None

        try:
            with tempfile.NamedTemporaryFile(
                "wb", dir=self.cache_dir, prefix=f"{cache_key}.", suffix=".tmp", delete=False
            ) as f:
                tmp_path = Path(f.name)
                pickle.dump(data, f)
            # Same-directory rename via os.replace: atomic on POSIX.
            tmp_path.replace(cache_file)

            self.logger.info(f"Saved {len(data)} activation datasets to cache: {cache_file}")

        except Exception as e:
            self.logger.error(f"Failed to save cache file {cache_file}: {e}")
            if tmp_path is not None:
                try:
                    tmp_path.unlink(missing_ok=True)
                except OSError:
                    pass  # best-effort cleanup of the partial temp file

    def clear_cache(self) -> None:
        """Delete every ``*.pkl`` file in the cache directory.

        NOTE: this removes any ``.pkl`` file in ``cache_dir``, not only ones written
        by this class. Per-file failures are logged and skipped.
        """
        cache_files = list(self.cache_dir.glob("*.pkl"))
        removed = 0
        for cache_file in cache_files:
            try:
                cache_file.unlink()
            except Exception as e:
                self.logger.warning(f"Failed to remove cache file {cache_file}: {e}")
                continue
            removed += 1
            self.logger.info(f"Removed cache file: {cache_file}")

        self.logger.info(f"Cleared {removed} cache files")

    def get_cache_info(self) -> dict[str, Any]:
        """Describe the ``*.pkl`` files in the cache directory.

        Returns:
            Dict with ``cache_dir``, ``total_files`` (number of ``.pkl`` files found),
            ``total_size_mb`` (sum over files that could be stat'ed) and ``files``.
            ``files`` is a list of ``{"name", "size_mb", "modified"}`` dicts, where
            ``modified`` is the raw ``st_mtime``. Files that vanish or cannot be
            stat'ed are logged and left out of the sizes.
        """
        bytes_per_mb = 1024 * 1024
        cache_files = list(self.cache_dir.glob("*.pkl"))

        files = []
        total_bytes = 0
        for cache_file in cache_files:
            try:
                st = cache_file.stat()  # stat once; the file may disappear concurrently
            except Exception as e:
                self.logger.warning(f"Failed to get info for {cache_file}: {e}")
                continue
            total_bytes += st.st_size
            files.append({"name": cache_file.name, "size_mb": st.st_size / bytes_per_mb, "modified": st.st_mtime})

        return {
            "cache_dir": str(self.cache_dir),
            "total_files": len(cache_files),
            "total_size_mb": total_bytes / bytes_per_mb,
            "files": files,
        }
|