wisent-0.5.14-py3-none-any.whl → wisent-0.5.15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent has been flagged as potentially problematic; see the registry's advisory for details.

Files changed (60):
  1. wisent/__init__.py +1 -1
  2. wisent/cli.py +114 -0
  3. wisent/core/activations/activations_collector.py +19 -11
  4. wisent/core/cli/__init__.py +3 -1
  5. wisent/core/cli/create_steering_vector.py +60 -18
  6. wisent/core/cli/evaluate_responses.py +14 -8
  7. wisent/core/cli/generate_pairs_from_task.py +18 -5
  8. wisent/core/cli/get_activations.py +1 -1
  9. wisent/core/cli/multi_steer.py +108 -0
  10. wisent/core/cli/optimize_classification.py +187 -285
  11. wisent/core/cli/optimize_sample_size.py +78 -0
  12. wisent/core/cli/optimize_steering.py +354 -53
  13. wisent/core/cli/tasks.py +274 -9
  14. wisent/core/errors/__init__.py +0 -0
  15. wisent/core/errors/error_handler.py +134 -0
  16. wisent/core/evaluators/benchmark_specific/log_likelihoods_evaluator.py +152 -295
  17. wisent/core/evaluators/rotator.py +22 -8
  18. wisent/core/main.py +5 -1
  19. wisent/core/model_persistence.py +4 -19
  20. wisent/core/models/wisent_model.py +11 -3
  21. wisent/core/parser.py +4 -3
  22. wisent/core/parser_arguments/main_parser.py +1 -1
  23. wisent/core/parser_arguments/multi_steer_parser.py +4 -3
  24. wisent/core/parser_arguments/optimize_steering_parser.py +4 -0
  25. wisent/core/sample_size_optimizer_v2.py +1 -1
  26. wisent/core/steering_optimizer.py +2 -2
  27. wisent/tests/__init__.py +0 -0
  28. wisent/tests/examples/__init__.py +0 -0
  29. wisent/tests/examples/cli/__init__.py +0 -0
  30. wisent/tests/examples/cli/activations/__init__.py +0 -0
  31. wisent/tests/examples/cli/activations/test_get_activations.py +127 -0
  32. wisent/tests/examples/cli/classifier/__init__.py +0 -0
  33. wisent/tests/examples/cli/classifier/test_classifier_examples.py +141 -0
  34. wisent/tests/examples/cli/contrastive_pairs/__init__.py +0 -0
  35. wisent/tests/examples/cli/contrastive_pairs/test_generate_pairs.py +89 -0
  36. wisent/tests/examples/cli/evaluation/__init__.py +0 -0
  37. wisent/tests/examples/cli/evaluation/test_evaluation_examples.py +117 -0
  38. wisent/tests/examples/cli/generate/__init__.py +0 -0
  39. wisent/tests/examples/cli/generate/test_generate_with_classifier.py +146 -0
  40. wisent/tests/examples/cli/generate/test_generate_with_steering.py +149 -0
  41. wisent/tests/examples/cli/generate/test_only_generate.py +110 -0
  42. wisent/tests/examples/cli/multi_steering/__init__.py +0 -0
  43. wisent/tests/examples/cli/multi_steering/test_multi_steer_from_trained_vectors.py +210 -0
  44. wisent/tests/examples/cli/multi_steering/test_multi_steer_with_different_parameters.py +205 -0
  45. wisent/tests/examples/cli/multi_steering/test_train_and_multi_steer.py +174 -0
  46. wisent/tests/examples/cli/optimizer/__init__.py +0 -0
  47. wisent/tests/examples/cli/optimizer/test_optimize_sample_size.py +102 -0
  48. wisent/tests/examples/cli/optimizer/test_optimizer_examples.py +59 -0
  49. wisent/tests/examples/cli/steering/__init__.py +0 -0
  50. wisent/tests/examples/cli/steering/test_create_steering_vectors.py +135 -0
  51. wisent/tests/examples/cli/synthetic/__init__.py +0 -0
  52. wisent/tests/examples/cli/synthetic/test_synthetic_pairs.py +45 -0
  53. {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/METADATA +3 -1
  54. {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/RECORD +59 -29
  55. wisent/core/agent/diagnose/test_synthetic_classifier.py +0 -71
  56. wisent/core/parser_arguments/{test_nonsense_parser.py → nonsense_parser.py} +0 -0 (renamed)
  57. {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/WHEEL +0 -0
  58. {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/entry_points.txt +0 -0
  59. {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/licenses/LICENSE +0 -0
  60. {wisent-0.5.14.dist-info → wisent-0.5.15.dist-info}/top_level.txt +0 -0
@@ -1,339 +1,241 @@
"""Classification optimization command - uses native Wisent methods.

This optimizer tests different configurations (layer, aggregation, threshold)
by calling the native execute_tasks() function for each configuration,
then evaluates using Wisent's native evaluators.
"""

import sys
import json
import time
from typing import List, Dict, Any
import os


# Maps each supported ``args.optimization_metric`` value to the key under which
# execute_tasks() reports that metric in its result dict.
_METRIC_KEYS = {
    'f1': 'f1_score',
    'accuracy': 'accuracy',
    'precision': 'precision',
    'recall': 'recall',
}


def _resolve_layers(layer_range, total_layers):
    """Return the list of layer indices to sweep.

    Args:
        layer_range: Optional ``"START-END"`` string, inclusive on both ends.
        total_layers: Layer count reported by the model; used for the default
            middle-third range when ``layer_range`` is empty.

    Raises:
        ValueError: If ``layer_range`` is malformed or describes an empty range
            (which would otherwise crash later with an IndexError).
    """
    if layer_range:
        try:
            start, end = map(int, layer_range.split('-'))
        except ValueError as err:
            raise ValueError(
                f"Invalid layer range {layer_range!r}: expected 'START-END', e.g. '8-16'"
            ) from err
        if start > end:
            raise ValueError(
                f"Invalid layer range {layer_range!r}: start ({start}) is greater than end ({end})"
            )
        return list(range(start, end + 1))

    # Test middle layers by default
    return list(range(total_layers // 3, (2 * total_layers) // 3 + 1))


def _build_task_args(args, task_name, layer, aggregation, threshold,
                     save_classifier=None, output=None):
    """Build the argument namespace execute_tasks() expects for one configuration.

    Shared by the search runs and the final training run so both use identical
    settings apart from the swept parameters and output paths. With the
    defaults (``save_classifier=None``, ``output=None``) intermediate
    classifiers are not saved.
    """
    from types import SimpleNamespace

    return SimpleNamespace(
        task_names=[task_name],
        model=args.model,
        layer=layer,
        classifier_type=getattr(args, 'classifier_type', None) or 'logistic',
        token_aggregation=aggregation,
        detection_threshold=threshold,
        split_ratio=0.8,
        seed=42,
        limit=args.limit,
        training_limit=None,
        testing_limit=None,
        device=args.device,
        save_classifier=save_classifier,
        output=output,
        inference_only=False,
        load_steering_vector=None,
        save_steering_vector=None,
        train_only=False,
        steering_method='caa',
    )


def execute_optimize_classification(args):
    """
    Execute classification optimization using native Wisent methods.

    Tests different configurations by calling execute_tasks() for each combination:
    - Different layers
    - Different aggregation methods
    - Different detection thresholds

    Uses native Wisent evaluation (not sklearn metrics).

    For every task, each (layer, aggregation, threshold) combination is scored
    with ``args.optimization_metric`` and the best one is kept. When
    ``args.save_dir`` is set, a final classifier is trained with the best
    configuration and saved there. A JSON summary is written to
    ``args.results_file`` when given, otherwise to
    ``<save_dir or ./optimization_results>/classification_optimization_<model>.json``.

    Raises:
        ValueError: Unknown optimization metric, malformed layer range, or
            empty aggregation/threshold lists.
        RuntimeError: No configuration produced a comparable score for a task.
        Exception: Any failure from execute_tasks() is re-raised (no fallback).
    """
    from wisent.core.cli.tasks import execute_tasks
    from wisent.core.models.wisent_model import WisentModel

    # Validate cheap arguments before any model loading.
    metric_key = _METRIC_KEYS.get(args.optimization_metric)
    if metric_key is None:
        raise ValueError(
            f"Unknown optimization metric {args.optimization_metric!r}; "
            f"expected one of {sorted(_METRIC_KEYS)}"
        )
    if not args.aggregation_methods or not args.threshold_range:
        raise ValueError("At least one aggregation method and one threshold are required")
    save_dir = getattr(args, 'save_dir', None)

    print(f"\n{'='*80}")
    print(f"🔍 CLASSIFICATION PARAMETER OPTIMIZATION")
    print(f"{'='*80}")
    print(f" Model: {args.model}")
    print(f" Limit per task: {args.limit}")
    print(f" Device: {args.device or 'auto'}")
    print(f"{'='*80}\n")

    # 1. Determine layer range
    # First need to load model to get num_layers
    print(f"📦 Loading model to determine layer range...")
    model = WisentModel(args.model, device=args.device)
    total_layers = model.num_layers
    # execute_tasks() is only given the model name, so this instance is not
    # reused; drop our reference so it need not stay resident during the sweep.
    del model
    print(f" ✓ Model has {total_layers} layers\n")

    layers_to_test = _resolve_layers(args.layer_range, total_layers)

    print(f"🎯 Testing layers: {layers_to_test[0]} to {layers_to_test[-1]} ({len(layers_to_test)} layers)")
    print(f"🔄 Aggregation methods: {', '.join(args.aggregation_methods)}")
    print(f"📊 Thresholds: {args.threshold_range}\n")

    # 2. Get list of tasks
    task_list = [
        "arc_easy", "arc_challenge", "hellaswag",
        "winogrande", "gsm8k"
    ]

    print(f"📋 Optimizing {len(task_list)} tasks\n")

    # 3. Results storage
    all_results = {}
    total_combinations = (
        len(layers_to_test) * len(args.aggregation_methods) * len(args.threshold_range)
    )

    # 4. Process each task
    for task_idx, task_name in enumerate(task_list, 1):
        print(f"\n{'='*80}")
        print(f"Task {task_idx}/{len(task_list)}: {task_name}")
        print(f"{'='*80}")

        task_start_time = time.time()

        try:
            best_score = -1
            best_config = None
            combinations_tested = 0

            print(f" 🔍 Testing {total_combinations} configurations...")

            for layer in layers_to_test:
                for agg_method in args.aggregation_methods:
                    for threshold in args.threshold_range:
                        combinations_tested += 1
                        task_args = _build_task_args(args, task_name, layer, agg_method, threshold)

                        try:
                            # Call native Wisent execute_tasks
                            result = execute_tasks(task_args)
                            metric_value = result.get(metric_key, 0)

                            # NaN never compares greater, so such runs are never selected.
                            if metric_value > best_score:
                                best_score = metric_value
                                best_config = {
                                    'layer': layer,
                                    'aggregation': agg_method,
                                    'threshold': threshold,
                                    'accuracy': result.get('accuracy', 0),
                                    'f1_score': result.get('f1_score', 0),
                                    'precision': result.get('precision', 0),
                                    'recall': result.get('recall', 0),
                                    'generation_count': result.get('generation_count', 0),
                                }

                            if combinations_tested % 5 == 0:
                                print(f" Progress: {combinations_tested}/{total_combinations} tested, best {args.optimization_metric}: {best_score:.4f}", end='\r')

                        except Exception as e:
                            # NO FALLBACK - report the failing configuration, then raise
                            print(f"\n❌ Configuration failed:")
                            print(f" Layer: {layer}")
                            print(f" Aggregation: {agg_method}")
                            print(f" Threshold: {threshold}")
                            print(f" Error: {e}")
                            raise

            if best_config is None:
                raise RuntimeError(
                    f"No configuration produced a comparable "
                    f"{args.optimization_metric} score for {task_name}"
                )

            print(f"\n\n ✅ Best config for {task_name}:")
            print(f" Layer: {best_config['layer']}")
            print(f" Aggregation: {best_config['aggregation']}")
            print(f" Threshold: {best_config['threshold']:.2f}")
            print(f" Performance metrics:")
            print(f" • Accuracy: {best_config['accuracy']:.4f}")
            print(f" • F1 Score: {best_config['f1_score']:.4f}")
            print(f" • Precision: {best_config['precision']:.4f}")
            print(f" • Recall: {best_config['recall']:.4f}")
            print(f" • Generations evaluated: {best_config['generation_count']}")

            # Train final classifier with best config and save
            if save_dir:
                print(f"\n 💾 Training final classifier with best config...")
                # The classifier path points inside save_dir; make sure it exists.
                os.makedirs(save_dir, exist_ok=True)

                final_args = _build_task_args(
                    args,
                    task_name,
                    best_config['layer'],
                    best_config['aggregation'],
                    best_config['threshold'],
                    save_classifier=os.path.join(save_dir, f"{task_name}_classifier.pt"),
                    output=os.path.join(save_dir, task_name),
                )

                execute_tasks(final_args)
                print(f" ✓ Classifier saved to: {final_args.save_classifier}")

            # Store results
            all_results[task_name] = {
                'best_config': best_config,
                'optimization_metric': args.optimization_metric,
                'best_score': best_score,
                'combinations_tested': combinations_tested
            }

            task_time = time.time() - task_start_time
            print(f"\n ⏱️ Task completed in {task_time:.1f}s")

        except Exception as e:
            # NO FALLBACK - raise error
            print(f"\n❌ Task '{task_name}' optimization failed:")
            print(f" Error: {e}")
            import traceback
            traceback.print_exc()
            raise

    # 5. Save optimization results
    results_file = getattr(args, 'results_file', None)
    if results_file:
        # Explicit --results-file wins (as in earlier releases).
        results_dir = os.path.dirname(results_file) or '.'
    else:
        results_dir = save_dir or './optimization_results'
        model_name_safe = args.model.replace('/', '_')
        results_file = os.path.join(results_dir, f'classification_optimization_{model_name_safe}.json')
    os.makedirs(results_dir, exist_ok=True)

    with open(results_file, 'w') as f:
        json.dump({
            'model': args.model,
            'optimization_metric': args.optimization_metric,
            'results': all_results
        }, f, indent=2)

    print(f"\n{'='*80}")
    print(f"📊 OPTIMIZATION COMPLETE")
    print(f"{'='*80}")
    print(f"✅ Results saved to: {results_file}\n")

    # Print summary
    print(f"📋 SUMMARY BY TASK:")
    print(f"-" * 120)
    for task_name, task_result in all_results.items():
        config = task_result['best_config']
        print(f"{task_name:20} | Layer: {config['layer']:2} | Agg: {config['aggregation']:8} | "
              f"Thresh: {config['threshold']:.2f} | F1: {config['f1_score']:.4f} | "
              f"Acc: {config['accuracy']:.4f} | Gens: {config['generation_count']:3}")
    print(f"-" * 120)
    print()
@@ -0,0 +1,78 @@
"""Sample size optimization command execution logic."""

import sys


def execute_optimize_sample_size(args):
    """Execute the optimize-sample-size command - find optimal training sample size.

    Builds a SimplifiedSampleSizeOptimizer for ``args.task`` at ``args.layer``.
    It runs in steering mode when ``args.steering_mode`` is set and in
    classification mode otherwise. The method prints the optimal sample size
    and whichever scores the optimizer reported. It also writes a plot to the
    current directory when ``args.save_plot`` is set.

    Any exception is reported on stderr, with a traceback when
    ``args.verbose`` is set, and the process exits with status 1.
    """
    from wisent.core.sample_size_optimizer_v2 import SimplifiedSampleSizeOptimizer

    print(f"\n{'='*80}")
    print(f"📊 SAMPLE SIZE OPTIMIZATION")
    print(f"{'='*80}")
    print(f" Model: {args.model}")
    print(f" Task: {args.task}")
    print(f" Layer: {args.layer}")
    print(f" Sample sizes: {args.sample_sizes}")
    print(f" Test size: {args.test_size}")
    print(f" Mode: {'Steering' if args.steering_mode else 'Classification'}")
    print(f"{'='*80}\n")

    try:
        # Prepare method kwargs based on mode
        if args.steering_mode:
            method_type = "steering"
            method_kwargs = {
                'steering_method': args.steering_method,
                'steering_strength': args.steering_strength,
                'token_targeting_strategy': args.token_targeting_strategy,
            }
        else:
            method_type = "classification"
            method_kwargs = {
                'token_aggregation': args.token_aggregation,
                'threshold': args.threshold,
            }

        # Create optimizer
        optimizer = SimplifiedSampleSizeOptimizer(
            model_name=args.model,
            task_name=args.task,
            layer=args.layer,
            method_type=method_type,
            sample_sizes=args.sample_sizes,
            test_size=args.test_size,
            seed=args.seed,
            verbose=args.verbose,
            **method_kwargs
        )

        # Run optimization
        print(f"\n🔍 Running sample size optimization...")
        results = optimizer.run_optimization()

        # Display results. The optimal size is required; the scores are
        # optional and may be absent or None, so only print those present.
        print(f"\n📈 Optimization Results:")
        print(f" Optimal sample size: {results['optimal_sample_size']}")
        optimal_accuracy = results.get('optimal_accuracy')
        if optimal_accuracy is not None:
            print(f" Best accuracy: {optimal_accuracy:.4f}")
        optimal_f1 = results.get('optimal_f1_score')
        if optimal_f1 is not None:
            print(f" Best F1 score: {optimal_f1:.4f}")

        # Save plot if requested
        if args.save_plot:
            plot_path = f"sample_size_optimization_{args.task}_{args.model.replace('/', '_')}.png"
            optimizer.plot_results(save_path=plot_path)
            print(f"\n💾 Plot saved to: {plot_path}")

        # Persisting to the model config (e.g. via ModelConfigManager) is not
        # wired up yet. Say so, rather than reporting a save that never happened.
        if not args.no_save_config:
            print(f"\n⚠️  Saving the optimal sample size to the model config is not implemented yet; nothing was saved.")

        print(f"\n✅ Sample size optimization completed successfully!\n")

    except Exception as e:
        print(f"\n❌ Error: {str(e)}", file=sys.stderr)
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(1)