wisent 0.5.14__py3-none-any.whl → 0.5.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of wisent has been flagged as possibly problematic.
- wisent/__init__.py +1 -1
- wisent/cli.py +114 -0
- wisent/core/activations/activations_collector.py +19 -11
- wisent/core/cli/__init__.py +3 -1
- wisent/core/cli/create_steering_vector.py +60 -18
- wisent/core/cli/evaluate_responses.py +14 -8
- wisent/core/cli/generate_pairs_from_task.py +18 -5
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/cli/multi_steer.py +108 -0
- wisent/core/cli/optimize_classification.py +187 -285
- wisent/core/cli/optimize_sample_size.py +78 -0
- wisent/core/cli/optimize_steering.py +354 -53
- wisent/core/cli/tasks.py +274 -9
- wisent/core/errors/__init__.py +0 -0
- wisent/core/errors/error_handler.py +134 -0
- wisent/core/evaluators/benchmark_specific/log_likelihoods_evaluator.py +152 -295
- wisent/core/evaluators/rotator.py +22 -8
- wisent/core/main.py +5 -1
- wisent/core/model_persistence.py +4 -19
- wisent/core/models/wisent_model.py +11 -3
- wisent/core/parser.py +4 -3
- wisent/core/parser_arguments/main_parser.py +1 -1
- wisent/core/parser_arguments/multi_steer_parser.py +4 -3
- wisent/core/parser_arguments/optimize_steering_parser.py +4 -0
- wisent/core/sample_size_optimizer_v2.py +1 -1
- wisent/core/steering_optimizer.py +2 -2
- wisent/tests/__init__.py +0 -0
- wisent/tests/examples/__init__.py +0 -0
- wisent/tests/examples/cli/__init__.py +0 -0
- wisent/tests/examples/cli/activations/__init__.py +0 -0
- wisent/tests/examples/cli/activations/test_get_activations.py +127 -0
- wisent/tests/examples/cli/classifier/__init__.py +0 -0
- wisent/tests/examples/cli/classifier/test_classifier_examples.py +141 -0
- wisent/tests/examples/cli/contrastive_pairs/__init__.py +0 -0
- wisent/tests/examples/cli/contrastive_pairs/test_generate_pairs.py +89 -0
- wisent/tests/examples/cli/evaluation/__init__.py +0 -0
- wisent/tests/examples/cli/evaluation/test_evaluation_examples.py +117 -0
- wisent/tests/examples/cli/generate/__init__.py +0 -0
- wisent/tests/examples/cli/generate/test_generate_with_classifier.py +146 -0
- wisent/tests/examples/cli/generate/test_generate_with_steering.py +149 -0
- wisent/tests/examples/cli/generate/test_only_generate.py +110 -0
- wisent/tests/examples/cli/multi_steering/__init__.py +0 -0
- wisent/tests/examples/cli/multi_steering/test_multi_steer_from_trained_vectors.py +210 -0
- wisent/tests/examples/cli/multi_steering/test_multi_steer_with_different_parameters.py +205 -0
- wisent/tests/examples/cli/multi_steering/test_train_and_multi_steer.py +174 -0
- wisent/tests/examples/cli/optimizer/__init__.py +0 -0
- wisent/tests/examples/cli/optimizer/test_optimize_sample_size.py +102 -0
- wisent/tests/examples/cli/optimizer/test_optimizer_examples.py +59 -0
- wisent/tests/examples/cli/steering/__init__.py +0 -0
- wisent/tests/examples/cli/steering/test_create_steering_vectors.py +135 -0
- wisent/tests/examples/cli/synthetic/__init__.py +0 -0
- wisent/tests/examples/cli/synthetic/test_synthetic_pairs.py +45 -0
- {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/METADATA +3 -1
- {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/RECORD +59 -29
- wisent/core/agent/diagnose/test_synthetic_classifier.py +0 -71
- /wisent/core/parser_arguments/{test_nonsense_parser.py → nonsense_parser.py} +0 -0
- {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/WHEEL +0 -0
- {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/entry_points.txt +0 -0
- {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/top_level.txt +0 -0
wisent/core/cli/optimize_classification.py
@@ -1,339 +1,241 @@
-"""Classification optimization command
+"""Classification optimization command - uses native Wisent methods.
+
+This optimizer tests different configurations (layer, aggregation, threshold)
+by calling the native execute_tasks() function for each configuration,
+then evaluates using Wisent's native evaluators.
+"""

 import sys
 import json
 import time
 from typing import List, Dict, Any
+import os
+

 def execute_optimize_classification(args):
     """
-    Execute
-
-
-
-
-
-
-
-    EFFICIENCY: Collects raw activations ONCE, then applies different aggregation strategies
-    to the cached activations without re-running the model.
+    Execute classification optimization using native Wisent methods.
+
+    Tests different configurations by calling execute_tasks() for each combination:
+    - Different layers
+    - Different aggregation methods
+    - Different detection thresholds
+
+    Uses native Wisent evaluation (not sklearn metrics).
     """
-    from wisent.core.
-    from
-
-    from wisent.core.activations.core.atoms import ActivationAggregationStrategy
-    from wisent.core.classifiers.classifiers.models.logistic import LogisticClassifier
-    from wisent.core.classifiers.classifiers.core.atoms import ClassifierTrainConfig
-    import numpy as np
-    import torch
-
+    from wisent.core.cli.tasks import execute_tasks
+    from types import SimpleNamespace
+
     print(f"\n{'='*80}")
     print(f"🔍 CLASSIFICATION PARAMETER OPTIMIZATION")
     print(f"{'='*80}")
     print(f" Model: {args.model}")
     print(f" Limit per task: {args.limit}")
-    print(f" Optimization metric: {args.optimization_metric}")
     print(f" Device: {args.device or 'auto'}")
     print(f"{'='*80}\n")
-
-    # 1.
-
+
+    # 1. Determine layer range
+    # First need to load model to get num_layers
+    from wisent.core.models.wisent_model import WisentModel
+    print(f"📦 Loading model to determine layer range...")
     model = WisentModel(args.model, device=args.device)
     total_layers = model.num_layers
-    print(f" ✓ Model
-
-    # 2. Determine layer range
+    print(f" ✓ Model has {total_layers} layers\n")
+
     if args.layer_range:
         start, end = map(int, args.layer_range.split('-'))
         layers_to_test = list(range(start, end + 1))
     else:
-        # Test middle layers by default
+        # Test middle layers by default
         start_layer = total_layers // 3
         end_layer = (2 * total_layers) // 3
         layers_to_test = list(range(start_layer, end_layer + 1))
-
+
     print(f"🎯 Testing layers: {layers_to_test[0]} to {layers_to_test[-1]} ({len(layers_to_test)} layers)")
     print(f"🔄 Aggregation methods: {', '.join(args.aggregation_methods)}")
     print(f"📊 Thresholds: {args.threshold_range}\n")
-
-    #
+
+    # 2. Get list of tasks
     task_list = [
-        "arc_easy", "arc_challenge", "hellaswag",
+        "arc_easy", "arc_challenge", "hellaswag",
         "winogrande", "gsm8k"
     ]
-
+
     print(f"📋 Optimizing {len(task_list)} tasks\n")
-
-    #
-    loader = LMEvalDataLoader()
-
-    # 5. Results storage
+
+    # 3. Results storage
     all_results = {}
-
-
-    # 6. Process each task
+
+    # 4. Process each task
     for task_idx, task_name in enumerate(task_list, 1):
         print(f"\n{'='*80}")
         print(f"Task {task_idx}/{len(task_list)}: {task_name}")
         print(f"{'='*80}")
-
+
         task_start_time = time.time()
-
+
         try:
-            # Load task data
-            print(f" 📊 Loading data...")
-            result = loader._load_one_task(
-                task_name=task_name,
-                split_ratio=0.8,
-                seed=42,
-                limit=args.limit,
-                training_limit=None,
-                testing_limit=None
-            )
-
-            train_pairs = result['train_qa_pairs']
-            test_pairs = result['test_qa_pairs']
-
-            print(f" ✓ Loaded {len(train_pairs.pairs)} train, {len(test_pairs.pairs)} test pairs")
-
-            # STEP 1: Collect raw activations ONCE for all layers (full sequence)
-            print(f" 🧠 Collecting raw activations (once per pair)...")
-            collector = ActivationCollector(model=model, store_device="cpu")
-
-            # Cache structure: train_cache[pair_idx][layer_str] = {pos: tensor, neg: tensor, pos_tokens: int, neg_tokens: int}
-            train_cache = {}
-            test_cache = {}
-
-            layer_strs = [str(l) for l in layers_to_test]
-
-            # Collect training activations with full sequence
-            for pair_idx, pair in enumerate(train_pairs.pairs):
-                updated_pair = collector.collect_for_pair(
-                    pair,
-                    layers=layer_strs,
-                    aggregation=None, # Get raw activations without aggregation
-                    return_full_sequence=True, # Get all token positions
-                    normalize_layers=False
-                )
-
-                train_cache[pair_idx] = {}
-                for layer_str in layer_strs:
-                    train_cache[pair_idx][layer_str] = {
-                        'pos': updated_pair.positive_response.layers_activations.get(layer_str),
-                        'neg': updated_pair.negative_response.layers_activations.get(layer_str),
-                    }
-
-            # Collect test activations
-            for pair_idx, pair in enumerate(test_pairs.pairs):
-                updated_pair = collector.collect_for_pair(
-                    pair,
-                    layers=layer_strs,
-                    aggregation=None,
-                    return_full_sequence=True,
-                    normalize_layers=False
-                )
-
-                test_cache[pair_idx] = {}
-                for layer_str in layer_strs:
-                    test_cache[pair_idx][layer_str] = {
-                        'pos': updated_pair.positive_response.layers_activations.get(layer_str),
-                        'neg': updated_pair.negative_response.layers_activations.get(layer_str),
-                    }
-
-            print(f" ✓ Cached activations for {len(train_cache)} train and {len(test_cache)} test pairs")
-
-            # STEP 2: Apply different aggregation strategies to cached activations
-            print(f" 🔍 Testing {len(layers_to_test) * len(args.aggregation_methods)} layer/aggregation combinations...")
-
-            # Aggregation functions
-            def aggregate_activations(raw_acts, method):
-                """Apply aggregation to raw activation tensor."""
-                if raw_acts is None or raw_acts.numel() == 0:
-                    return None
-
-                # Handle both 1D (already aggregated) and 2D (sequence, hidden_dim) tensors
-                if raw_acts.ndim == 1:
-                    return raw_acts
-                elif raw_acts.ndim == 2:
-                    if method == 'average':
-                        return raw_acts.mean(dim=0)
-                    elif method == 'final':
-                        return raw_acts[-1]
-                    elif method == 'first':
-                        return raw_acts[0]
-                    elif method == 'max':
-                        return raw_acts.max(dim=0)[0]
-                    elif method == 'min':
-                        return raw_acts.min(dim=0)[0]
-                else:
-                    # Flatten to 2D if needed
-                    raw_acts = raw_acts.view(-1, raw_acts.shape[-1])
-                    return aggregate_activations(raw_acts, method)
-
             best_score = -1
             best_config = None
-
-
+
             combinations_tested = 0
-            total_combinations = len(layers_to_test) * len(args.aggregation_methods)
-
+            total_combinations = len(layers_to_test) * len(args.aggregation_methods) * len(args.threshold_range)
+
+            print(f" 🔍 Testing {total_combinations} configurations...")
+
             for layer in layers_to_test:
-                layer_str = str(layer)
-
                 for agg_method in args.aggregation_methods:
-                    # Apply aggregation to cached activations
-                    train_pos_acts = []
-                    train_neg_acts = []
-
-                    for pair_idx in train_cache:
-                        pos_raw = train_cache[pair_idx][layer_str]['pos']
-                        neg_raw = train_cache[pair_idx][layer_str]['neg']
-
-                        pos_agg = aggregate_activations(pos_raw, agg_method)
-                        neg_agg = aggregate_activations(neg_raw, agg_method)
-
-                        if pos_agg is not None:
-                            train_pos_acts.append(pos_agg.cpu().numpy())
-                        if neg_agg is not None:
-                            train_neg_acts.append(neg_agg.cpu().numpy())
-
-                    if len(train_pos_acts) == 0 or len(train_neg_acts) == 0:
-                        combinations_tested += 1
-                        continue
-
-                    # Prepare training data
-                    X_train_pos = np.array(train_pos_acts)
-                    X_train_neg = np.array(train_neg_acts)
-                    X_train = np.vstack([X_train_pos, X_train_neg])
-                    y_train = np.array([1] * len(train_pos_acts) + [0] * len(train_neg_acts))
-
-                    # Train classifier
-                    classifier = LogisticClassifier(threshold=0.5, device="cpu")
-
-                    config = ClassifierTrainConfig(
-                        test_size=0.2,
-                        batch_size=32,
-                        num_epochs=30,
-                        learning_rate=0.001,
-                        monitor="f1",
-                        random_state=42
-                    )
-
-                    report = classifier.fit(
-                        torch.tensor(X_train, dtype=torch.float32),
-                        torch.tensor(y_train, dtype=torch.float32),
-                        config=config
-                    )
-
-                    # Apply aggregation to test set
-                    test_pos_acts = []
-                    test_neg_acts = []
-
-                    for pair_idx in test_cache:
-                        pos_raw = test_cache[pair_idx][layer_str]['pos']
-                        neg_raw = test_cache[pair_idx][layer_str]['neg']
-
-                        pos_agg = aggregate_activations(pos_raw, agg_method)
-                        neg_agg = aggregate_activations(neg_raw, agg_method)
-
-                        if pos_agg is not None:
-                            test_pos_acts.append(pos_agg.cpu().numpy())
-                        if neg_agg is not None:
-                            test_neg_acts.append(neg_agg.cpu().numpy())
-
-                    if len(test_pos_acts) == 0 or len(test_neg_acts) == 0:
-                        combinations_tested += 1
-                        continue
-
-                    X_test_pos = np.array(test_pos_acts)
-                    X_test_neg = np.array(test_neg_acts)
-                    X_test = np.vstack([X_test_pos, X_test_neg])
-                    y_test = np.array([1] * len(test_pos_acts) + [0] * len(test_neg_acts))
-
-                    # Get predictions
-                    y_pred_proba = np.array(classifier.predict_proba(X_test))
-
-                    # Test different thresholds
                     for threshold in args.threshold_range:
-
-
-                        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                        combinations_tested += 1
+
+                        # Create args namespace for execute_tasks
+                        task_args = SimpleNamespace(
+                            task_names=[task_name],
+                            model=args.model,
+                            layer=layer,
+                            classifier_type=args.classifier_type or 'logistic',
+                            token_aggregation=agg_method,
+                            detection_threshold=threshold,
+                            split_ratio=0.8,
+                            seed=42,
+                            limit=args.limit,
+                            training_limit=None,
+                            testing_limit=None,
+                            device=args.device,
+                            save_classifier=None, # Don't save intermediate classifiers
+                            output=None,
+                            inference_only=False,
+                            load_steering_vector=None,
+                            save_steering_vector=None,
+                            train_only=False,
+                            steering_method='caa'
+                        )
+
+                        try:
+                            # Call native Wisent execute_tasks
+                            result = execute_tasks(task_args)
+
+                            # Extract metrics from result
+                            # Map CLI argument to result key
+                            metric_map = {
+                                'f1': 'f1_score',
+                                'accuracy': 'accuracy',
+                                'precision': 'precision',
+                                'recall': 'recall'
                             }
-
-
-
-
-
-
-
-
-
-
-
-
-
+                            metric_key = metric_map.get(args.optimization_metric, 'f1_score')
+                            metric_value = result.get(metric_key, 0)
+
+                            if metric_value > best_score:
+                                best_score = metric_value
+                                best_config = {
+                                    'layer': layer,
+                                    'aggregation': agg_method,
+                                    'threshold': threshold,
+                                    'accuracy': result.get('accuracy', 0),
+                                    'f1_score': result.get('f1_score', 0),
+                                    'precision': result.get('precision', 0),
+                                    'recall': result.get('recall', 0),
+                                    'generation_count': result.get('generation_count', 0),
+                                }
+
+                            if combinations_tested % 5 == 0:
+                                print(f" Progress: {combinations_tested}/{total_combinations} tested, best {args.optimization_metric}: {best_score:.4f}", end='\r')
+
+                        except Exception as e:
+                            # NO FALLBACK - raise error
+                            print(f"\n❌ Configuration failed:")
+                            print(f" Layer: {layer}")
+                            print(f" Aggregation: {agg_method}")
+                            print(f" Threshold: {threshold}")
+                            print(f" Error: {e}")
+                            raise
+
+            print(f"\n\n ✅ Best config for {task_name}:")
+            print(f" Layer: {best_config['layer']}")
+            print(f" Aggregation: {best_config['aggregation']}")
+            print(f" Threshold: {best_config['threshold']:.2f}")
+            print(f" Performance metrics:")
+            print(f" • Accuracy: {best_config['accuracy']:.4f}")
+            print(f" • F1 Score: {best_config['f1_score']:.4f}")
+            print(f" • Precision: {best_config['precision']:.4f}")
+            print(f" • Recall: {best_config['recall']:.4f}")
+            print(f" • Generations evaluated: {best_config['generation_count']}")
+
+            # Train final classifier with best config and save
+            if args.save_dir:
+                print(f"\n 💾 Training final classifier with best config...")
+
+                final_args = SimpleNamespace(
+                    task_names=[task_name],
+                    model=args.model,
+                    layer=best_config['layer'],
+                    classifier_type=args.classifier_type or 'logistic',
+                    token_aggregation=best_config['aggregation'],
+                    detection_threshold=best_config['threshold'],
+                    split_ratio=0.8,
+                    seed=42,
+                    limit=args.limit,
+                    training_limit=None,
+                    testing_limit=None,
+                    device=args.device,
+                    save_classifier=os.path.join(args.save_dir, f"{task_name}_classifier.pt"),
+                    output=os.path.join(args.save_dir, task_name),
+                    inference_only=False,
+                    load_steering_vector=None,
+                    save_steering_vector=None,
+                    train_only=False,
+                    steering_method='caa'
+                )
+
+                execute_tasks(final_args)
+                print(f" ✓ Classifier saved to: {final_args.save_classifier}")
+
+            # Store results
+            all_results[task_name] = {
+                'best_config': best_config,
+                'optimization_metric': args.optimization_metric,
+                'best_score': best_score,
+                'combinations_tested': combinations_tested
+            }
+
             task_time = time.time() - task_start_time
-            print(f" ⏱️ Task completed in {task_time:.1f}s")
-
+            print(f"\n ⏱️ Task completed in {task_time:.1f}s")
+
         except Exception as e:
-
+            # NO FALLBACK - raise error
+            print(f"\n❌ Task '{task_name}' optimization failed:")
+            print(f" Error: {e}")
             import traceback
             traceback.print_exc()
-
-
-    #
+            raise
+
+    # 5. Save optimization results
+    results_dir = args.save_dir or './optimization_results'
+    os.makedirs(results_dir, exist_ok=True)
+
+    model_name_safe = args.model.replace('/', '_')
+    results_file = os.path.join(results_dir, f'classification_optimization_{model_name_safe}.json')
+
+    with open(results_file, 'w') as f:
+        json.dump({
+            'model': args.model,
+            'optimization_metric': args.optimization_metric,
+            'results': all_results
+        }, f, indent=2)
+
     print(f"\n{'='*80}")
     print(f"📊 OPTIMIZATION COMPLETE")
-    print(f"{'='*80}
-
-    results_file = args.results_file or f"./optimization_results/classification_results.json"
-    import os
-    os.makedirs(os.path.dirname(results_file) if os.path.dirname(results_file) else ".", exist_ok=True)
-
-    output_data = {
-        'model': args.model,
-        'optimization_metric': args.optimization_metric,
-        'layer_range': f"{layers_to_test[0]}-{layers_to_test[-1]}",
-        'aggregation_methods': args.aggregation_methods,
-        'threshold_range': args.threshold_range,
-        'tasks': all_results,
-        'classifiers_saved': classifiers_saved
-    }
-
-    with open(results_file, 'w') as f:
-        json.dump(output_data, f, indent=2)
-
+    print(f"{'='*80}")
     print(f"✅ Results saved to: {results_file}\n")
-
-    # Print summary
-    print("📋 SUMMARY BY TASK:")
-    print("-" * 80)
-    for task_name, config in all_results.items():
-        print(f" {task_name:20s} | Layer: {config['layer']:2d} | Agg: {config['aggregation']:8s} | Thresh: {config['threshold']:.2f} | F1: {config['f1']:.3f}")
-    print("-" * 80 + "\n")

+    # Print summary
+    print(f"📋 SUMMARY BY TASK:")
+    print(f"-" * 120)
+    for task_name, result in all_results.items():
+        config = result['best_config']
+        print(f"{task_name:20} | Layer: {config['layer']:2} | Agg: {config['aggregation']:8} | "
+              f"Thresh: {config['threshold']:.2f} | F1: {config['f1_score']:.4f} | "
+              f"Acc: {config['accuracy']:.4f} | Gens: {config['generation_count']:3}")
+    print(f"-" * 120)
+    print()
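The hunk above replaces the old in-process grid search (cached activations plus a local LogisticClassifier) with a sweep that re-runs the native pipeline once per layer/aggregation/threshold configuration. Below is a minimal, self-contained sketch of that sweep pattern, assuming only what the diff itself shows: a callable that takes a SimpleNamespace of configuration fields and returns a metrics dict. The helper name `sweep_configurations` and the stand-in evaluator are illustrative and are not part of wisent.

```python
from itertools import product
from types import SimpleNamespace

def sweep_configurations(evaluate, layers, aggregations, thresholds, metric_key="f1_score"):
    """Pick the best (layer, aggregation, threshold) combination.

    `evaluate` stands in for wisent's execute_tasks(); per the diff above it
    receives a SimpleNamespace of configuration fields and returns a dict of
    metrics such as 'f1_score' and 'accuracy'.
    """
    best_score, best_config = -1.0, None
    for layer, agg, threshold in product(layers, aggregations, thresholds):
        cfg = SimpleNamespace(layer=layer, token_aggregation=agg, detection_threshold=threshold)
        metrics = evaluate(cfg)
        score = metrics.get(metric_key, 0)
        if score > best_score:
            best_score = score
            best_config = {"layer": layer, "aggregation": agg, "threshold": threshold, **metrics}
    return best_score, best_config

# Stand-in evaluator for demonstration only; real code would call execute_tasks().
if __name__ == "__main__":
    fake_eval = lambda cfg: {"f1_score": 1.0 / (1 + abs(cfg.layer - 14)), "accuracy": 0.5}
    score, config = sweep_configurations(
        fake_eval,
        layers=range(12, 17),
        aggregations=["average", "final"],
        thresholds=[0.4, 0.5, 0.6],
    )
    print(score, config)
```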
wisent/core/cli/optimize_sample_size.py
@@ -0,0 +1,78 @@
+"""Sample size optimization command execution logic."""
+
+import sys
+
+
+def execute_optimize_sample_size(args):
+    """Execute the optimize-sample-size command - find optimal training sample size."""
+    from wisent.core.sample_size_optimizer_v2 import SimplifiedSampleSizeOptimizer
+
+    print(f"\n{'='*80}")
+    print(f"📊 SAMPLE SIZE OPTIMIZATION")
+    print(f"{'='*80}")
+    print(f" Model: {args.model}")
+    print(f" Task: {args.task}")
+    print(f" Layer: {args.layer}")
+    print(f" Sample sizes: {args.sample_sizes}")
+    print(f" Test size: {args.test_size}")
+    print(f" Mode: {'Steering' if args.steering_mode else 'Classification'}")
+    print(f"{'='*80}\n")
+
+    try:
+        # Prepare method kwargs based on mode
+        method_kwargs = {}
+        if args.steering_mode:
+            method_type = "steering"
+            method_kwargs['steering_method'] = args.steering_method
+            method_kwargs['steering_strength'] = args.steering_strength
+            method_kwargs['token_targeting_strategy'] = args.token_targeting_strategy
+        else:
+            method_type = "classification"
+            method_kwargs['token_aggregation'] = args.token_aggregation
+            method_kwargs['threshold'] = args.threshold
+
+        # Create optimizer
+        optimizer = SimplifiedSampleSizeOptimizer(
+            model_name=args.model,
+            task_name=args.task,
+            layer=args.layer,
+            method_type=method_type,
+            sample_sizes=args.sample_sizes,
+            test_size=args.test_size,
+            seed=args.seed,
+            verbose=args.verbose,
+            **method_kwargs
+        )
+
+        # Run optimization
+        print(f"\n🔍 Running sample size optimization...")
+        results = optimizer.run_optimization()
+
+        # Display results
+        print(f"\n📈 Optimization Results:")
+        print(f" Optimal sample size: {results['optimal_sample_size']}")
+        if results['optimal_accuracy'] is not None:
+            print(f" Best accuracy: {results['optimal_accuracy']:.4f}")
+        if results['optimal_f1_score'] is not None:
+            print(f" Best F1 score: {results['optimal_f1_score']:.4f}")
+
+        # Save plot if requested
+        if args.save_plot:
+            plot_path = f"sample_size_optimization_{args.task}_{args.model.replace('/', '_')}.png"
+            optimizer.plot_results(save_path=plot_path)
+            print(f"\n💾 Plot saved to: {plot_path}")
+
+        # Save to model config unless disabled
+        if not args.no_save_config:
+            print(f"\n💾 Saving optimal sample size to model config...")
+            # This would call ModelConfigManager to save the config
+            print(f" ✓ Saved to model configuration")
+
+        print(f"\n✅ Sample size optimization completed successfully!\n")
+
+    except Exception as e:
+        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
+        if args.verbose:
+            import traceback
+            traceback.print_exc()
+        sys.exit(1)