wisent 0.5.13-py3-none-any.whl → 0.5.15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic.

Files changed (62)
  1. wisent/__init__.py +1 -1
  2. wisent/cli.py +114 -0
  3. wisent/core/activations/activations_collector.py +19 -11
  4. wisent/core/agent/__init__.py +1 -18
  5. wisent/core/agent/diagnose/__init__.py +1 -55
  6. wisent/core/cli/__init__.py +3 -1
  7. wisent/core/cli/create_steering_vector.py +60 -18
  8. wisent/core/cli/evaluate_responses.py +14 -8
  9. wisent/core/cli/generate_pairs_from_task.py +18 -5
  10. wisent/core/cli/get_activations.py +1 -1
  11. wisent/core/cli/multi_steer.py +108 -0
  12. wisent/core/cli/optimize_classification.py +187 -285
  13. wisent/core/cli/optimize_sample_size.py +78 -0
  14. wisent/core/cli/optimize_steering.py +354 -53
  15. wisent/core/cli/tasks.py +274 -9
  16. wisent/core/errors/__init__.py +0 -0
  17. wisent/core/errors/error_handler.py +134 -0
  18. wisent/core/evaluators/benchmark_specific/log_likelihoods_evaluator.py +152 -295
  19. wisent/core/evaluators/rotator.py +22 -8
  20. wisent/core/main.py +5 -1
  21. wisent/core/model_persistence.py +4 -19
  22. wisent/core/models/wisent_model.py +11 -3
  23. wisent/core/parser.py +4 -3
  24. wisent/core/parser_arguments/main_parser.py +1 -1
  25. wisent/core/parser_arguments/multi_steer_parser.py +4 -3
  26. wisent/core/parser_arguments/optimize_steering_parser.py +4 -0
  27. wisent/core/sample_size_optimizer_v2.py +1 -1
  28. wisent/core/steering_optimizer.py +2 -2
  29. wisent/tests/__init__.py +0 -0
  30. wisent/tests/examples/__init__.py +0 -0
  31. wisent/tests/examples/cli/__init__.py +0 -0
  32. wisent/tests/examples/cli/activations/__init__.py +0 -0
  33. wisent/tests/examples/cli/activations/test_get_activations.py +127 -0
  34. wisent/tests/examples/cli/classifier/__init__.py +0 -0
  35. wisent/tests/examples/cli/classifier/test_classifier_examples.py +141 -0
  36. wisent/tests/examples/cli/contrastive_pairs/__init__.py +0 -0
  37. wisent/tests/examples/cli/contrastive_pairs/test_generate_pairs.py +89 -0
  38. wisent/tests/examples/cli/evaluation/__init__.py +0 -0
  39. wisent/tests/examples/cli/evaluation/test_evaluation_examples.py +117 -0
  40. wisent/tests/examples/cli/generate/__init__.py +0 -0
  41. wisent/tests/examples/cli/generate/test_generate_with_classifier.py +146 -0
  42. wisent/tests/examples/cli/generate/test_generate_with_steering.py +149 -0
  43. wisent/tests/examples/cli/generate/test_only_generate.py +110 -0
  44. wisent/tests/examples/cli/multi_steering/__init__.py +0 -0
  45. wisent/tests/examples/cli/multi_steering/test_multi_steer_from_trained_vectors.py +210 -0
  46. wisent/tests/examples/cli/multi_steering/test_multi_steer_with_different_parameters.py +205 -0
  47. wisent/tests/examples/cli/multi_steering/test_train_and_multi_steer.py +174 -0
  48. wisent/tests/examples/cli/optimizer/__init__.py +0 -0
  49. wisent/tests/examples/cli/optimizer/test_optimize_sample_size.py +102 -0
  50. wisent/tests/examples/cli/optimizer/test_optimizer_examples.py +59 -0
  51. wisent/tests/examples/cli/steering/__init__.py +0 -0
  52. wisent/tests/examples/cli/steering/test_create_steering_vectors.py +135 -0
  53. wisent/tests/examples/cli/synthetic/__init__.py +0 -0
  54. wisent/tests/examples/cli/synthetic/test_synthetic_pairs.py +45 -0
  55. {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/METADATA +3 -1
  56. {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/RECORD +61 -31
  57. wisent/core/agent/diagnose/test_synthetic_classifier.py +0 -71
  58. /wisent/core/parser_arguments/{test_nonsense_parser.py → nonsense_parser.py} +0 -0
  59. {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/WHEEL +0 -0
  60. {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/entry_points.txt +0 -0
  61. {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/licenses/LICENSE +0 -0
  62. {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/top_level.txt +0 -0
wisent/core/evaluators/benchmark_specific/log_likelihoods_evaluator.py CHANGED
@@ -1,329 +1,186 @@
- """
- Log-Likelihoods Ground Truth Evaluator
+ """Log Likelihoods Evaluator for multiple choice tasks.

- This module handles ground truth evaluation for log-likelihoods based tasks,
- typically used for multiple choice questions. Instead of generating text,
- it loads the multiple choice options from lm-eval tasks and runs the classifier
- directly on each choice to evaluate performance against known ground truth.
+ This evaluator handles tasks like BoolQ, MMLU, ARC where evaluation is done
+ by comparing log likelihoods of different answer choices rather than generating text.
+ Works with steering by computing log probabilities with steering applied.
  """

  import logging
- from typing import Any, Dict, Optional
- from dataclasses import dataclass
+ import torch
+ from typing import Any, List

- from wisent.core.activations.core.atoms import ActivationAggregationStrategy
- from wisent.core.activations.activations import Activations
+ from wisent.core.evaluators.core.atoms import BaseEvaluator, EvalResult
+ from wisent.core.errors.error_handler import (
+     ModelNotProvidedError,
+     validate_choices,
+     require_all_parameters
+ )

  logger = logging.getLogger(__name__)


- @dataclass
- class Layer:
-     """Simple layer metadata class."""
-     index: int
-     type: str = "transformer"
+ class LogLikelihoodsEvaluator(BaseEvaluator):
+     """Evaluator for multiple choice tasks using log likelihood comparison.

+     Compatible with:
+     - BoolQ: Boolean questions with yes/no choices
+     - MMLU: Multiple choice questions
+     - ARC: Science questions with multiple choices
+     - Any task requiring log likelihood comparison

- class LogLikelihoodsEvaluator:
+     This evaluator computes the log likelihood of each choice and selects
+     the one with the highest probability. Can apply steering before computing
+     log likelihoods.
      """
-     Evaluator for log-likelihoods based ground truth assessment.

-     This evaluator loads multiple choice options from lm-eval tasks and runs
-     the classifier on each choice to evaluate performance against known ground truth.
-     No text generation is performed - only direct classification evaluation.
-     """
+     name = "log_likelihoods"
+     description = "Log likelihood evaluator for multiple choice tasks"
+     task_names = ("boolq", "mmlu", "arc_easy", "arc_challenge", "truthfulqa_mc1", "truthfulqa_mc2")

-     def __init__(self, task_name: Optional[str] = None, model=None):
-         """
-         Initialize the log-likelihoods evaluator.
+     def __init__(self, model=None):
+         """Initialize with optional model for log likelihood computation.

          Args:
-             task_name: Name of the task (e.g., "truthfulqa_mc1", "mmlu", etc.)
-             model: The model instance used to extract activations
+             model: WisentModel instance that can compute log likelihoods
          """
-         self.task_name = task_name
          self.model = model

-     def evaluate_classifier_on_task(
-         self,
-         classifier,
-         task_name: str,
-         num_samples: int = 100,
-         model=None,
-         layer: int = 15,
-         token_aggregation: str = "average",
-     ) -> Dict[str, Any]:
-         """
-         Evaluate a classifier on a log-likelihoods task by running it on multiple choice options.
+     def evaluate(self, response: str, expected: Any, **kwargs) -> EvalResult:
+         """Evaluate using log likelihood comparison of choices.

          Args:
-             classifier: The classifier to evaluate
-             task_name: Name of the lm-eval task
-             num_samples: Number of samples to evaluate (default: 100)
-             model: The model instance (overrides self.model if provided)
-             layer: Layer to extract activations from (default: 15)
-             token_aggregation: Token aggregation method ("average", "final", "first", "max", "min")
+             response: Not used for log likelihood evaluation
+             expected: Expected answer
+             **kwargs:
+                 model: WisentModel instance (REQUIRED)
+                 question: The question/context (REQUIRED)
+                 choices: List of answer choices (REQUIRED)
+                 steering_plan: Optional steering plan to apply

          Returns:
-             Dict containing evaluation results
+             EvalResult with TRUTHFUL/UNTRUTHFUL
+
+         Raises:
+             ModelNotProvidedError: If model is not provided
+             MissingParameterError: If question is not provided
+             InvalidChoicesError: If choices are invalid or missing
          """
+         model = kwargs.get('model') or self.model
+         question = kwargs.get('question')
+         choices = kwargs.get('choices')
+         steering_plan = kwargs.get('steering_plan')
+         task_name = kwargs.get('task_name', 'unknown')
+
+         # NO FALLBACKS - require all parameters
+         if not model:
+             raise ModelNotProvidedError(evaluator_name=self.name, task_name=task_name)
+
+         require_all_parameters(
+             {'question': question},
+             context=f"{self.name} evaluator",
+             task_name=task_name
+         )
+
+         validate_choices(choices, task_name=task_name, min_choices=2)
+
+         return self._evaluate_log_likelihood(
+             model, question, choices, expected, steering_plan
+         )
+
+     def _evaluate_log_likelihood(
+         self, model, question: str, choices: List[str], expected: Any, steering_plan=None
+     ) -> EvalResult:
+         """Evaluate by comparing log likelihoods of choices."""
          try:
-             # Use provided model or fall back to self.model
-             evaluation_model = model or self.model
-             if evaluation_model is None:
-                 return self._error_result("No model provided for activation extraction")
-
-             logger.info(f"Loading task data for {task_name}...")
-
-             # Use existing task loading infrastructure
-             task_data = evaluation_model.load_lm_eval_task(task_name, shots=0, limit=num_samples)
-             docs, _ = evaluation_model.split_task_data(task_data, split_ratio=1.0) # Use all for evaluation
-
-             if not docs:
-                 return self._error_result(f"No documents retrieved from task: {task_name}")
-
-             logger.info(f"Retrieved {len(docs)} documents from {task_name}")
-
-             # Use existing QA extraction infrastructure (task-agnostic)
-             from .contrastive_pairs.contrastive_pair_set import ContrastivePairSet
-
-             qa_pairs = ContrastivePairSet.extract_qa_pairs_from_task_docs(task_name, task_data, docs)
-
-             if not qa_pairs:
-                 return self._error_result(f"No QA pairs could be extracted from task: {task_name}")
-
-             logger.info(f"Extracted {len(qa_pairs)} QA pairs from {task_name}")
-
-             # Use existing contrastive pair creation infrastructure
-             from wisent.core.activations.activation_collection_method import (
-                 ActivationCollectionLogic,
-             )
-             from wisent.core.activations.prompts import PromptConstructionStrategy
-
-             collector = ActivationCollectionLogic(model=evaluation_model)
-
-             # For evaluation, use DIRECT_COMPLETION instead of MULTIPLE_CHOICE
-             # This creates prompts like "Q" -> "good_resp"/"bad_resp" instead of "Which is better: Q A. bad B. good"
-             logger.info("🔍 EVALUATION MODE: Using DIRECT_COMPLETION prompt strategy instead of MULTIPLE_CHOICE")
-             contrastive_pairs = collector.create_batch_contrastive_pairs(
-                 qa_pairs, prompt_strategy=PromptConstructionStrategy.DIRECT_COMPLETION
-             )
-
-             if not contrastive_pairs:
-                 return self._error_result("No contrastive pairs could be created from QA pairs")
-
-             logger.info(f"Created {len(contrastive_pairs)} contrastive pairs")
-
-             # Map token aggregation to token targeting strategy for evaluation
-             targeting_strategy_mapping = { # TODO Refactor - we should stay with one standard
-                 "average": ActivationAggregationStrategy.MEAN_POOLING,
-                 "final": ActivationAggregationStrategy.LAST_TOKEN,
-                 "first": ActivationAggregationStrategy.FIRST_TOKEN,
-                 "max": ActivationAggregationStrategy.MAX_POOLING,
-                 "min": ActivationAggregationStrategy.MEAN_POOLING, # Fallback to mean
-             }
-
-             targeting_strategy = targeting_strategy_mapping.get(
-                 token_aggregation, ActivationAggregationStrategy.MEAN_POOLING
-             )
-
-             logger.info(
-                 f"🔍 EVALUATION MODE: Using {targeting_strategy.value} targeting strategy (from token_aggregation: {token_aggregation})"
-             )
-             logger.info("🎯 ACTIVATION COLLECTION PARAMS:")
-             logger.info(f" • Layer: {layer}")
-             logger.info(f" • Device: {evaluation_model.device}")
-             logger.info(f" • Token targeting: {targeting_strategy.value}")
-             logger.info(f" • Pairs count: {len(contrastive_pairs)}")
-
-             processed_pairs = collector.collect_activations_batch(
-                 pairs=contrastive_pairs,
-                 layer_index=layer,
-                 device=evaluation_model.device,
-                 token_targeting_strategy=targeting_strategy,
-             )
-
-             if not processed_pairs:
-                 return self._error_result("No activations could be extracted from contrastive pairs")
-
-             logger.info(f"Extracted activations from {len(processed_pairs)} pairs")
-
-             # Debug: Show where activations are collected from
-             if processed_pairs:
-                 sample_pair = processed_pairs[0]
-                 logger.info("📍 DETAILED ACTIVATION COLLECTION ANALYSIS:")
-                 logger.info(f" 🔧 Sample pair type: {type(sample_pair).__name__}")
-                 logger.info(
-                     f" 🔧 Pair attributes: {[attr for attr in dir(sample_pair) if not attr.startswith('_')][:8]}..."
-                 )
-
-                 if hasattr(sample_pair, "positive_activations") and sample_pair.positive_activations is not None:
-                     logger.info(f" ✅ Positive activations shape: {sample_pair.positive_activations.shape}")
-                 if hasattr(sample_pair, "negative_activations") and sample_pair.negative_activations is not None:
-                     logger.info(f" ✅ Negative activations shape: {sample_pair.negative_activations.shape}")
-
-                 if hasattr(sample_pair, "_prompt_pair") and sample_pair._prompt_pair:
-                     logger.debug(f" 🔸 Positive prompt: {sample_pair._prompt_pair.positive_prompt[:100]}...")
-                     logger.debug(f" 🔸 Negative prompt: {sample_pair._prompt_pair.negative_prompt[:100]}...")
-                     logger.debug(f" 🎯 Target token: {sample_pair._prompt_pair.target_token}")
-                     logger.debug(f" 📊 Prompt strategy: {sample_pair._prompt_strategy.value}")
-                     logger.info(f" 🔍 Token targeting: {targeting_strategy.value} (evaluation mode)")
-                 elif hasattr(sample_pair, "prompt") and hasattr(sample_pair, "positive_response"):
-                     logger.debug(f" 🔸 Question prompt: {sample_pair.prompt[:100]}...")
-                     logger.debug(f" ✅ Positive response: {sample_pair.positive_response[:50]}...")
-                     logger.debug(f" ❌ Negative response: {sample_pair.negative_response[:50]}...")
-                     logger.debug(
-                         f" 🔍 Token targeting used: {targeting_strategy.value} (from CLI token_aggregation: {token_aggregation})"
-                     )
-                 else:
-                     logger.info(" 📍 ACTIVATION COLLECTION: Unknown format - investigating...")
-                     logger.info(
-                         f" 🔧 All attributes: {[attr for attr in dir(sample_pair) if not attr.startswith('__')]}"
-                     )
-
-             # Map token aggregation to activation method
-             activation_method = token_aggregation
-             # Handle both string and enum types
-             method_name = activation_method.value if hasattr(activation_method, 'value') else str(activation_method)
-             logger.info(
-                 f"🎯 Using activation aggregation method: {method_name} (from token_aggregation: {token_aggregation})"
+             # Apply steering if provided
+             if steering_plan:
+                 model.attach(steering_plan)
+
+             # Compute log likelihood for each choice
+             log_probs = []
+             for choice in choices:
+                 log_prob = self._compute_choice_log_likelihood(model, question, choice)
+                 log_probs.append(log_prob)
+
+             # Detach steering
+             if steering_plan:
+                 model.detach()
+
+             # Select choice with highest log likelihood
+             predicted_idx = log_probs.index(max(log_probs))
+             predicted_choice = choices[predicted_idx]
+
+             # Normalize expected answer for comparison
+             expected_normalized = str(expected).strip().lower()
+             predicted_normalized = predicted_choice.strip().lower()
+
+             is_correct = predicted_normalized == expected_normalized
+
+             return EvalResult(
+                 ground_truth="TRUTHFUL" if is_correct else "UNTRUTHFUL",
+                 method_used=self.name,
+                 confidence=1.0 if is_correct else 0.0,
+                 details=f"Predicted: '{predicted_choice}' (log_prob={log_probs[predicted_idx]:.3f}), Expected: '{expected}'",
+                 meta={
+                     "predicted": predicted_choice,
+                     "expected": expected,
+                     "log_probs": {choice: lp for choice, lp in zip(choices, log_probs)},
+                 }
              )

-             # Evaluate classifier on each sample
-             results = []
-             total_correct = 0
-             total_samples = 0
-
-             for i, pair in enumerate(processed_pairs):
-                 try:
-                     sample_result = self._evaluate_classifier_on_sample(
-                         classifier, pair, qa_pairs[i], activation_method
-                     )
-                     results.append(sample_result)
-
-                     if sample_result.get("classifier_correct", False):
-                         total_correct += 1
-                     total_samples += 1
-
-                 except Exception as e:
-                     logger.error(f"Error evaluating sample {i}: {e}")
-                     continue
-
-             # Calculate overall metrics
-             accuracy = total_correct / total_samples if total_samples > 0 else 0.0
-
-             return {
-                 "ground_truth": "EVALUATED",
-                 "method_used": "log-likelihoods-classifier",
-                 "confidence": accuracy,
-                 "details": f"Evaluated {total_samples} samples with {total_correct} correct predictions",
-                 "task_name": task_name,
-                 "evaluation_method": "log-likelihoods",
-                 "lm_eval_metrics": {
-                     "accuracy": accuracy,
-                     "correct_predictions": total_correct,
-                     "total_samples": total_samples,
-                 },
-                 "sample_results": results[:10], # First 10 for debugging
-             }
-
          except Exception as e:
+             logger.error(f"Error in log likelihood evaluation: {e}")
              import traceback
+             logger.error(traceback.format_exc())
+             # NO FALLBACK - raise the error
+             raise

-             logger.error(f"Error evaluating classifier on task {task_name}: {e}")
-             logger.error(f"Traceback: {traceback.format_exc()}")
-             return self._error_result(f"Evaluation error: {e!s}")
-
-     def _evaluate_classifier_on_sample(
-         self, classifier, processed_pair, qa_pair: Dict[str, Any], activation_method
-     ) -> Dict[str, Any]:
-         """
-         Evaluate the classifier on a single processed contrastive pair.
+     def _compute_choice_log_likelihood(self, model, question: str, choice: str) -> float:
+         """Compute log likelihood of a choice given a question.

          Args:
-             classifier: The classifier to evaluate
-             processed_pair: ContrastivePair with activations already extracted
-             qa_pair: Original QA pair data for reference
-             activation_method:
+             model: WisentModel instance
+             question: The question/context
+             choice: The answer choice
+
          Returns:
-             Dict containing evaluation results for this sample
+             Log likelihood (higher = more likely)
          """
-         try:
-             # Extract activations from the processed pair
-             positive_activations = processed_pair.positive_activations # B choice (correct)
-             negative_activations = processed_pair.negative_activations # A choice (incorrect)
-
-             if positive_activations is None or negative_activations is None:
-                 return {
-                     "question": qa_pair["question"],
-                     "correct_answer": qa_pair["correct_answer"],
-                     "incorrect_answer": qa_pair["incorrect_answer"],
-                     "classifier_correct": False,
-                     "error": "Missing activations",
-                 }
-
-             layer_obj = Layer(index=15, type="transformer")
-
-             # Process positive (correct) choice using CLI token aggregation method
-             positive_act = Activations(
-                 tensor=positive_activations, layer=layer_obj, aggregation_strategy=activation_method
-             )
-             positive_features = positive_act.extract_features_for_classifier()
-             positive_prediction = classifier.predict_proba([positive_features.cpu().numpy()])
-             if isinstance(positive_prediction, (list, tuple)) and len(positive_prediction) > 0:
-                 positive_prediction = positive_prediction[0]
-
-             # Process negative (incorrect) choice using CLI token aggregation method
-             negative_act = Activations(
-                 tensor=negative_activations, layer=layer_obj, aggregation_strategy=activation_method
-             )
-             negative_features = negative_act.extract_features_for_classifier()
-             negative_prediction = classifier.predict_proba([negative_features.cpu().numpy()])
-             if isinstance(negative_prediction, (list, tuple)) and len(negative_prediction) > 0:
-                 negative_prediction = negative_prediction[0]
-
-             # Evaluate classifier performance
-             # For truthfulness: lower scores = more truthful/correct
-             # Classifier should give lower score to correct answer, higher to incorrect
-             classifier_correct = positive_prediction < negative_prediction
-
-             # Enhanced logging to show classifier vs normal predictions
-             logger.info("📋 Sample Evaluation:")
-             logger.debug(f" 🔸 Question: {qa_pair['question'][:100]}...")
-             logger.info(" 🧠 CLASSIFIER PREDICTION:")
-             logger.debug(f" ✅ Correct answer score: {float(positive_prediction):.3f} (lower = more truthful)")
-             logger.debug(f" ❌ Incorrect answer score: {float(negative_prediction):.3f} (higher = less truthful)")
-             logger.debug(
-                 f" 🎯 Classifier judgment: {'CORRECT' if classifier_correct else 'INCORRECT'} (positive {'<' if classifier_correct else '>='} negative)"
-             )
-
-             return {
-                 "question": qa_pair["question"],
-                 "correct_answer": qa_pair["correct_answer"],
-                 "incorrect_answer": qa_pair["incorrect_answer"],
-                 "positive_prediction": float(positive_prediction),
-                 "negative_prediction": float(negative_prediction),
-                 "classifier_correct": classifier_correct,
-             }
-
-         except Exception as e:
-             logger.debug(f"Error evaluating sample: {e}")
-             return {
-                 "question": qa_pair.get("question", "Unknown"),
-                 "correct_answer": qa_pair.get("correct_answer", "Unknown"),
-                 "incorrect_answer": qa_pair.get("incorrect_answer", "Unknown"),
-                 "classifier_correct": False,
-                 "error": str(e),
-             }
-
-     def _error_result(self, error_msg: str) -> Dict[str, Any]:
-         """Return an error result."""
-         return {
-             "ground_truth": "UNKNOWN",
-             "method_used": "log-likelihoods-error",
-             "confidence": 0.0,
-             "details": error_msg,
-             "task_name": self.task_name or "unknown",
-             "evaluation_method": "log-likelihoods",
-             "lm_eval_metrics": {"accuracy": 0.0, "correct_predictions": 0, "total_samples": 0},
-         }
+         # Format as: question + choice
+         full_text = f"{question}\n{choice}"
+
+         # Tokenize question and choice separately
+         question_inputs = model.tokenizer(question, return_tensors="pt", add_special_tokens=True).to(model.device)
+         choice_tokens = model.tokenizer(choice, return_tensors="pt", add_special_tokens=False).to(model.device)
+
+         # Get model logits for the full sequence
+         with torch.no_grad():
+             # Tokenize full sequence
+             full_inputs = model.tokenizer(full_text, return_tensors="pt", add_special_tokens=True).to(model.device)
+             outputs = model.hf_model(**full_inputs)
+             logits = outputs.logits
+
+         # Compute log probability of the choice tokens
+         # logits shape: [batch, seq_len, vocab_size]
+         # We want log prob of choice tokens given question
+
+         question_len = question_inputs.input_ids.shape[1]
+         choice_len = choice_tokens.input_ids.shape[1]
+
+         # Get logits at positions where we're predicting choice tokens
+         log_prob = 0.0
+         for i in range(choice_len):
+             # Position in full sequence where we predict token i of choice
+             # Subtract 1 because we predict the next token
+             pos = question_len + i - 1
+             if pos >= 0 and pos < logits.shape[1]:
+                 token_logits = logits[0, pos, :] # Logits at this position
+                 token_log_probs = torch.nn.functional.log_softmax(token_logits, dim=-1)
+                 # Get log prob of the actual choice token at this position
+                 actual_token_id = choice_tokens.input_ids[0, i]
+                 log_prob += token_log_probs[actual_token_id].item()
+
+         # Normalize by length to avoid bias toward shorter choices
+         normalized_log_prob = log_prob / max(choice_len, 1)
+
+         return normalized_log_prob
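
Note on the new scoring path: _compute_choice_log_likelihood sums the per-token log-probabilities of a choice given the question and divides by the choice length. Below is a minimal standalone sketch of that length-normalized scoring idea, written against plain Hugging Face transformers rather than the package's WisentModel wrapper; the model name "gpt2" and the helper score_choice are illustrative assumptions, not wisent APIs.

# Illustrative sketch only; not part of the wisent package.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")            # assumption: any causal LM works here
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

def score_choice(question: str, choice: str) -> float:
    """Length-normalized log likelihood of `choice` continuing `question`."""
    q_ids = tokenizer(question, return_tensors="pt").input_ids
    full_ids = tokenizer(f"{question}\n{choice}", return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits                      # [1, seq_len, vocab]
    log_probs = torch.log_softmax(logits[0, :-1], dim=-1)    # position t predicts token t+1
    targets = full_ids[0, 1:]
    # Only score the positions that belong to the choice continuation.
    choice_start = q_ids.shape[1] - 1
    picked = log_probs[choice_start:, :].gather(1, targets[choice_start:].unsqueeze(1))
    return picked.mean().item()                              # mean = length normalization

question = "Is the sky blue on a clear day? Answer:"
choices = ["yes", "no"]
scores = {c: score_choice(question, c) for c in choices}
print(max(scores, key=scores.get), scores)

Like the evaluator above, the sketch assumes the question's tokenization is a prefix of the full sequence's tokenization, which is only approximately true for BPE tokenizers.
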
wisent/core/evaluators/rotator.py CHANGED
@@ -25,8 +25,8 @@ class EvaluatorRotator:
      ) -> None:
          if autoload:
              self.discover_evaluators(evaluators_location)
+         self._task_name = task_name  # Set before resolving
          self._evaluator = self._resolve_evaluator(evaluator)
-         self._task_name = task_name

      @staticmethod
      def discover_evaluators(location: Union[str, Path] = "wisent.core.evaluators.oracles") -> None:
@@ -93,17 +93,31 @@ class EvaluatorRotator:
              )
          return sorted(out, key=lambda x: x["name"])

-     @staticmethod
      def _resolve_evaluator(
+         self,
          evaluator: Union[str, BaseEvaluator, Type[BaseEvaluator], None]
      ) -> BaseEvaluator:
          if evaluator is None:
-             registry = BaseEvaluator.list_registered()
-             if "lm_eval" in registry:
-                 return registry["lm_eval"]()
-             if registry:
-                 return next(iter(registry.values()))()
-             raise EvaluatorError("No evaluators registered.")
+             # Auto-select based on task_name if provided
+             if self._task_name:
+                 registry = BaseEvaluator.list_registered()
+                 for name, cls in registry.items():
+                     task_names = getattr(cls, 'task_names', ())
+                     if self._task_name in task_names:
+                         logger.info(f"Auto-selected evaluator '{name}' for task '{self._task_name}'")
+                         return cls()
+                 # NO FALLBACK - raise error if no evaluator found for task
+                 raise EvaluatorError(
+                     f"No evaluator found for task '{self._task_name}'. "
+                     f"Available evaluators: {list(registry.keys())}. "
+                     f"Please specify an evaluator explicitly or add task_names to an evaluator."
+                 )
+
+             # NO FALLBACK - if no task_name and no evaluator, require explicit selection
+             raise EvaluatorError(
+                 "No evaluator specified and no task_name provided. "
+                 "Either provide an evaluator name or a task_name for auto-selection."
+             )
          if isinstance(evaluator, BaseEvaluator):
              return evaluator
          if inspect.isclass(evaluator) and issubclass(evaluator, BaseEvaluator):
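
The rotator change turns _resolve_evaluator into an instance method so that, when no evaluator is named, it can pick one whose task_names tuple contains the requested task and otherwise raise instead of silently falling back. A toy sketch of that registry-lookup pattern follows; the classes and registry here are invented for illustration and are not the package's actual BaseEvaluator.

# Toy illustration of task-based auto-selection; class names are invented.
class EvaluatorError(Exception):
    pass

class BaseEvaluator:
    registry = {}  # name -> evaluator class

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        BaseEvaluator.registry[cls.name] = cls  # register every subclass by its `name`

class LogLikelihoodsEvaluator(BaseEvaluator):
    name = "log_likelihoods"
    task_names = ("boolq", "mmlu", "arc_easy")

class ExactMatchEvaluator(BaseEvaluator):
    name = "exact_match"
    task_names = ("gsm8k",)

def resolve(task_name: str) -> BaseEvaluator:
    """Return the first registered evaluator that declares this task; no silent fallback."""
    for name, cls in BaseEvaluator.registry.items():
        if task_name in getattr(cls, "task_names", ()):
            return cls()
    raise EvaluatorError(f"No evaluator found for task '{task_name}'")

print(type(resolve("boolq")).__name__)   # LogLikelihoodsEvaluator
print(type(resolve("gsm8k")).__name__)   # ExactMatchEvaluator
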
wisent/core/main.py CHANGED
@@ -8,7 +8,7 @@ and provides the main() function that serves as the CLI entry point.
  import sys
  from wisent.core.parser_arguments import setup_parser
  from wisent.core.branding import print_banner
- from wisent.core.cli import execute_tasks, execute_generate_pairs_from_task, execute_generate_pairs, execute_get_activations, execute_create_steering_vector, execute_generate_vector_from_task, execute_generate_vector_from_synthetic, execute_optimize_classification, execute_optimize_steering, execute_generate_responses, execute_evaluate_responses
+ from wisent.core.cli import execute_tasks, execute_generate_pairs_from_task, execute_generate_pairs, execute_get_activations, execute_create_steering_vector, execute_generate_vector_from_task, execute_generate_vector_from_synthetic, execute_optimize_classification, execute_optimize_steering, execute_optimize_sample_size, execute_generate_responses, execute_evaluate_responses, execute_multi_steer


  def main():
@@ -44,10 +44,14 @@ def main():
          execute_optimize_classification(args)
      elif args.command == 'optimize-steering':
          execute_optimize_steering(args)
+     elif args.command == 'optimize-sample-size':
+         execute_optimize_sample_size(args)
      elif args.command == 'generate-responses':
          execute_generate_responses(args)
      elif args.command == 'evaluate-responses':
          execute_evaluate_responses(args)
+     elif args.command == 'multi-steer':
+         execute_multi_steer(args)
      else:
          print(f"\n✗ Command '{args.command}' is not yet implemented")
          sys.exit(1)
wisent/core/model_persistence.py CHANGED
@@ -33,16 +33,8 @@ class ModelPersistence:
          if save_dir:
              os.makedirs(save_dir, exist_ok=True)

-         # Split path and sanitize only the filename part
-         directory = os.path.dirname(save_path)
-         filename = os.path.basename(save_path)
-         # Sanitize filename to handle periods in model names
-         safe_filename = filename.replace('.', '_')
-         safe_path = os.path.join(directory, safe_filename)
-
-         # Add layer suffix to filename
-         base, ext = os.path.splitext(safe_path)
-         classifier_path = f"{base}_layer_{layer}{ext or '.pkl'}"
+         # Use the exact path provided by the user
+         classifier_path = save_path

          # Prepare data to save
          save_data = {
@@ -69,15 +61,8 @@ class ModelPersistence:
          Returns:
              Tuple of (classifier, metadata)
          """
-         # Split path and sanitize only the filename part to match save format
-         directory = os.path.dirname(load_path)
-         filename = os.path.basename(load_path)
-         safe_filename = filename.replace('.', '_')
-         safe_path = os.path.join(directory, safe_filename)
-
-         # Add layer suffix to filename
-         base, ext = os.path.splitext(safe_path)
-         classifier_path = f"{base}_layer_{layer}{ext or '.pkl'}"
+         # Use the exact path provided by the user
+         classifier_path = load_path

          if not os.path.exists(classifier_path):
              raise FileNotFoundError(f"Classifier file not found: {classifier_path}")
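
The persistence change removes the old filename mangling, so classifiers are now saved to and loaded from exactly the path the caller supplies. The sketch below paraphrases the removed logic next to the new behavior to show the difference; the path and layer values are hypothetical.

# Illustration of the path-handling difference; not wisent code.
import os

def old_classifier_path(save_path: str, layer: int) -> str:
    """Pre-0.5.15 behavior: sanitize the filename and append a layer suffix."""
    directory, filename = os.path.dirname(save_path), os.path.basename(save_path)
    base, ext = os.path.splitext(os.path.join(directory, filename.replace(".", "_")))
    return f"{base}_layer_{layer}{ext or '.pkl'}"

def new_classifier_path(save_path: str, layer: int) -> str:
    """0.5.15 behavior: use the caller's path verbatim."""
    return save_path

path = "models/my.classifier.pkl"      # hypothetical path
print(old_classifier_path(path, 15))   # models/my_classifier_pkl_layer_15.pkl
print(new_classifier_path(path, 15))   # models/my.classifier.pkl

One practical consequence: classifiers saved with earlier versions under the suffixed name will not be found unless the suffixed path is passed explicitly.
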
wisent/core/models/wisent_model.py CHANGED
@@ -284,9 +284,17 @@ class WisentModel:
              {"input_ids": tensor([[...]]), "attention_mask": tensor([[...]])}
          """

-         ids = self.tokenizer.apply_chat_template(
-             message, tokenize=True, add_generation_prompt=add_generation_prompt, enable_thinking=enable_thinking, return_tensors="pt"
-         )[0]
+         try:
+             ids = self.tokenizer.apply_chat_template(
+                 message, tokenize=True, add_generation_prompt=add_generation_prompt, enable_thinking=enable_thinking, return_tensors="pt"
+             )[0]
+         except ValueError as e:
+             if "chat_template is not set" in str(e):
+                 # Fallback for models without chat templates: concatenate messages
+                 text = " ".join([msg.get("content", "") for msg in message if isinstance(msg, dict)])
+                 ids = self.tokenizer.encode(text, return_tensors="pt")[0]
+             else:
+                 raise
          return {
              "input_ids": ids,
              "attention_mask": torch.ones_like(ids),
wisent/core/parser.py CHANGED
@@ -1590,14 +1590,15 @@ def setup_multi_steer_parser(parser):
          "--vector",
          type=str,
          action="append",
-         required=True,
+         required=False,
+         default=None,
          metavar="PATH:WEIGHT",
-         help="Path to steering vector and its weight (format: path/to/vector.pt:0.5). Can be specified multiple times.",
+         help="Path to steering vector and its weight (format: path/to/vector.pt:0.5). Can be specified multiple times. If omitted, generates unsteered baseline.",
      )

      # Model configuration
      parser.add_argument("--model", type=str, required=True, help="Model name or path")
-     parser.add_argument("--layer", type=int, required=True, help="Layer index to apply combined steering")
+     parser.add_argument("--layer", type=int, required=False, default=None, help="Layer index to apply combined steering (required when using vectors)")
      parser.add_argument("--device", type=str, default=None, help="Device to run on (default: auto-detect)")

      # Steering method configuration
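
With this change --vector (repeated PATH:WEIGHT values) and --layer become optional, and omitting vectors yields an unsteered baseline. Below is a small argparse sketch of how repeated PATH:WEIGHT arguments could be collected and split; the option handling and file names are hypothetical and this is not wisent's actual parser code.

# Hypothetical sketch of collecting repeated PATH:WEIGHT arguments; not wisent's parser.
import argparse

def parse_vector_spec(spec: str) -> tuple:
    """Split 'path/to/vector.pt:0.5' into (path, weight); splitting on the last ':' keeps drive letters like 'C:' intact."""
    path, _, weight = spec.rpartition(":")
    return path, float(weight)

parser = argparse.ArgumentParser(prog="multi-steer-example")
parser.add_argument("--vector", action="append", default=None, metavar="PATH:WEIGHT")
parser.add_argument("--layer", type=int, default=None)

args = parser.parse_args(
    ["--vector", "vectors/helpful.pt:0.5", "--vector", "vectors/formal.pt:0.3", "--layer", "12"]
)
specs = [parse_vector_spec(v) for v in (args.vector or [])]
print(specs)       # [('vectors/helpful.pt', 0.5), ('vectors/formal.pt', 0.3)]
print(args.layer)  # 12; an empty spec list would mean "generate the unsteered baseline"
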
wisent/core/parser_arguments/main_parser.py CHANGED
@@ -15,7 +15,7 @@ from wisent.core.parser_arguments.create_steering_vector_parser import setup_cre
  from wisent.core.parser_arguments.generate_vector_from_task_parser import setup_generate_vector_from_task_parser
  from wisent.core.parser_arguments.generate_vector_from_synthetic_parser import setup_generate_vector_from_synthetic_parser
  from wisent.core.parser_arguments.synthetic_parser import setup_synthetic_parser
- from wisent.core.parser_arguments.test_nonsense_parser import setup_test_nonsense_parser
+ from wisent.core.parser_arguments.nonsense_parser import setup_test_nonsense_parser
  from wisent.core.parser_arguments.monitor_parser import setup_monitor_parser
  from wisent.core.parser_arguments.agent_parser import setup_agent_parser
  from wisent.core.parser_arguments.model_config_parser import setup_model_config_parser