wisent 0.5.13__py3-none-any.whl → 0.5.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic.
- wisent/__init__.py +1 -1
- wisent/cli.py +114 -0
- wisent/core/activations/activations_collector.py +19 -11
- wisent/core/agent/__init__.py +1 -18
- wisent/core/agent/diagnose/__init__.py +1 -55
- wisent/core/cli/__init__.py +3 -1
- wisent/core/cli/create_steering_vector.py +60 -18
- wisent/core/cli/evaluate_responses.py +14 -8
- wisent/core/cli/generate_pairs_from_task.py +18 -5
- wisent/core/cli/get_activations.py +1 -1
- wisent/core/cli/multi_steer.py +108 -0
- wisent/core/cli/optimize_classification.py +187 -285
- wisent/core/cli/optimize_sample_size.py +78 -0
- wisent/core/cli/optimize_steering.py +354 -53
- wisent/core/cli/tasks.py +274 -9
- wisent/core/errors/__init__.py +0 -0
- wisent/core/errors/error_handler.py +134 -0
- wisent/core/evaluators/benchmark_specific/log_likelihoods_evaluator.py +152 -295
- wisent/core/evaluators/rotator.py +22 -8
- wisent/core/main.py +5 -1
- wisent/core/model_persistence.py +4 -19
- wisent/core/models/wisent_model.py +11 -3
- wisent/core/parser.py +4 -3
- wisent/core/parser_arguments/main_parser.py +1 -1
- wisent/core/parser_arguments/multi_steer_parser.py +4 -3
- wisent/core/parser_arguments/optimize_steering_parser.py +4 -0
- wisent/core/sample_size_optimizer_v2.py +1 -1
- wisent/core/steering_optimizer.py +2 -2
- wisent/tests/__init__.py +0 -0
- wisent/tests/examples/__init__.py +0 -0
- wisent/tests/examples/cli/__init__.py +0 -0
- wisent/tests/examples/cli/activations/__init__.py +0 -0
- wisent/tests/examples/cli/activations/test_get_activations.py +127 -0
- wisent/tests/examples/cli/classifier/__init__.py +0 -0
- wisent/tests/examples/cli/classifier/test_classifier_examples.py +141 -0
- wisent/tests/examples/cli/contrastive_pairs/__init__.py +0 -0
- wisent/tests/examples/cli/contrastive_pairs/test_generate_pairs.py +89 -0
- wisent/tests/examples/cli/evaluation/__init__.py +0 -0
- wisent/tests/examples/cli/evaluation/test_evaluation_examples.py +117 -0
- wisent/tests/examples/cli/generate/__init__.py +0 -0
- wisent/tests/examples/cli/generate/test_generate_with_classifier.py +146 -0
- wisent/tests/examples/cli/generate/test_generate_with_steering.py +149 -0
- wisent/tests/examples/cli/generate/test_only_generate.py +110 -0
- wisent/tests/examples/cli/multi_steering/__init__.py +0 -0
- wisent/tests/examples/cli/multi_steering/test_multi_steer_from_trained_vectors.py +210 -0
- wisent/tests/examples/cli/multi_steering/test_multi_steer_with_different_parameters.py +205 -0
- wisent/tests/examples/cli/multi_steering/test_train_and_multi_steer.py +174 -0
- wisent/tests/examples/cli/optimizer/__init__.py +0 -0
- wisent/tests/examples/cli/optimizer/test_optimize_sample_size.py +102 -0
- wisent/tests/examples/cli/optimizer/test_optimizer_examples.py +59 -0
- wisent/tests/examples/cli/steering/__init__.py +0 -0
- wisent/tests/examples/cli/steering/test_create_steering_vectors.py +135 -0
- wisent/tests/examples/cli/synthetic/__init__.py +0 -0
- wisent/tests/examples/cli/synthetic/test_synthetic_pairs.py +45 -0
- {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/METADATA +3 -1
- {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/RECORD +61 -31
- wisent/core/agent/diagnose/test_synthetic_classifier.py +0 -71
- /wisent/core/parser_arguments/{test_nonsense_parser.py → nonsense_parser.py} +0 -0
- {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/WHEEL +0 -0
- {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/entry_points.txt +0 -0
- {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/licenses/LICENSE +0 -0
- {wisent-0.5.13.dist-info → wisent-0.5.15.dist-info}/top_level.txt +0 -0
wisent/core/evaluators/benchmark_specific/log_likelihoods_evaluator.py
CHANGED
@@ -1,329 +1,186 @@
-"""
-Log-Likelihoods Ground Truth Evaluator
+"""Log Likelihoods Evaluator for multiple choice tasks.
 
-This
-
-
-directly on each choice to evaluate performance against known ground truth.
+This evaluator handles tasks like BoolQ, MMLU, ARC where evaluation is done
+by comparing log likelihoods of different answer choices rather than generating text.
+Works with steering by computing log probabilities with steering applied.
 """
 
 import logging
-
-from
+import torch
+from typing import Any, List
 
-from wisent.core.
-from wisent.core.
+from wisent.core.evaluators.core.atoms import BaseEvaluator, EvalResult
+from wisent.core.errors.error_handler import (
+    ModelNotProvidedError,
+    validate_choices,
+    require_all_parameters
+)
 
 logger = logging.getLogger(__name__)
 
 
-
-
-    """Simple layer metadata class."""
-    index: int
-    type: str = "transformer"
+class LogLikelihoodsEvaluator(BaseEvaluator):
+    """Evaluator for multiple choice tasks using log likelihood comparison.
 
+    Compatible with:
+    - BoolQ: Boolean questions with yes/no choices
+    - MMLU: Multiple choice questions
+    - ARC: Science questions with multiple choices
+    - Any task requiring log likelihood comparison
 
-
+    This evaluator computes the log likelihood of each choice and selects
+    the one with the highest probability. Can apply steering before computing
+    log likelihoods.
     """
-    Evaluator for log-likelihoods based ground truth assessment.
 
-
-
-
-    """
+    name = "log_likelihoods"
+    description = "Log likelihood evaluator for multiple choice tasks"
+    task_names = ("boolq", "mmlu", "arc_easy", "arc_challenge", "truthfulqa_mc1", "truthfulqa_mc2")
 
-    def __init__(self,
-        """
-        Initialize the log-likelihoods evaluator.
+    def __init__(self, model=None):
+        """Initialize with optional model for log likelihood computation.
 
         Args:
-
-            model: The model instance used to extract activations
+            model: WisentModel instance that can compute log likelihoods
         """
-        self.task_name = task_name
         self.model = model
 
-    def
-
-        classifier,
-        task_name: str,
-        num_samples: int = 100,
-        model=None,
-        layer: int = 15,
-        token_aggregation: str = "average",
-    ) -> Dict[str, Any]:
-        """
-        Evaluate a classifier on a log-likelihoods task by running it on multiple choice options.
+    def evaluate(self, response: str, expected: Any, **kwargs) -> EvalResult:
+        """Evaluate using log likelihood comparison of choices.
 
         Args:
-
-
-
-
-
-
+            response: Not used for log likelihood evaluation
+            expected: Expected answer
+            **kwargs:
+                model: WisentModel instance (REQUIRED)
+                question: The question/context (REQUIRED)
+                choices: List of answer choices (REQUIRED)
+                steering_plan: Optional steering plan to apply
 
         Returns:
-
+            EvalResult with TRUTHFUL/UNTRUTHFUL
+
+        Raises:
+            ModelNotProvidedError: If model is not provided
+            MissingParameterError: If question is not provided
+            InvalidChoicesError: If choices are invalid or missing
         """
+        model = kwargs.get('model') or self.model
+        question = kwargs.get('question')
+        choices = kwargs.get('choices')
+        steering_plan = kwargs.get('steering_plan')
+        task_name = kwargs.get('task_name', 'unknown')
+
+        # NO FALLBACKS - require all parameters
+        if not model:
+            raise ModelNotProvidedError(evaluator_name=self.name, task_name=task_name)
+
+        require_all_parameters(
+            {'question': question},
+            context=f"{self.name} evaluator",
+            task_name=task_name
+        )
+
+        validate_choices(choices, task_name=task_name, min_choices=2)
+
+        return self._evaluate_log_likelihood(
+            model, question, choices, expected, steering_plan
+        )
+
+    def _evaluate_log_likelihood(
+        self, model, question: str, choices: List[str], expected: Any, steering_plan=None
+    ) -> EvalResult:
+        """Evaluate by comparing log likelihoods of choices."""
         try:
-            #
-
-
-
-
-
-
-
-
-
-
-            if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # For evaluation, use DIRECT_COMPLETION instead of MULTIPLE_CHOICE
-            # This creates prompts like "Q" -> "good_resp"/"bad_resp" instead of "Which is better: Q A. bad B. good"
-            logger.info("🔍 EVALUATION MODE: Using DIRECT_COMPLETION prompt strategy instead of MULTIPLE_CHOICE")
-            contrastive_pairs = collector.create_batch_contrastive_pairs(
-                qa_pairs, prompt_strategy=PromptConstructionStrategy.DIRECT_COMPLETION
-            )
-
-            if not contrastive_pairs:
-                return self._error_result("No contrastive pairs could be created from QA pairs")
-
-            logger.info(f"Created {len(contrastive_pairs)} contrastive pairs")
-
-            # Map token aggregation to token targeting strategy for evaluation
-            targeting_strategy_mapping = {  # TODO Refactor - we should stay with one standard
-                "average": ActivationAggregationStrategy.MEAN_POOLING,
-                "final": ActivationAggregationStrategy.LAST_TOKEN,
-                "first": ActivationAggregationStrategy.FIRST_TOKEN,
-                "max": ActivationAggregationStrategy.MAX_POOLING,
-                "min": ActivationAggregationStrategy.MEAN_POOLING,  # Fallback to mean
-            }
-
-            targeting_strategy = targeting_strategy_mapping.get(
-                token_aggregation, ActivationAggregationStrategy.MEAN_POOLING
-            )
-
-            logger.info(
-                f"🔍 EVALUATION MODE: Using {targeting_strategy.value} targeting strategy (from token_aggregation: {token_aggregation})"
-            )
-            logger.info("🎯 ACTIVATION COLLECTION PARAMS:")
-            logger.info(f" • Layer: {layer}")
-            logger.info(f" • Device: {evaluation_model.device}")
-            logger.info(f" • Token targeting: {targeting_strategy.value}")
-            logger.info(f" • Pairs count: {len(contrastive_pairs)}")
-
-            processed_pairs = collector.collect_activations_batch(
-                pairs=contrastive_pairs,
-                layer_index=layer,
-                device=evaluation_model.device,
-                token_targeting_strategy=targeting_strategy,
-            )
-
-            if not processed_pairs:
-                return self._error_result("No activations could be extracted from contrastive pairs")
-
-            logger.info(f"Extracted activations from {len(processed_pairs)} pairs")
-
-            # Debug: Show where activations are collected from
-            if processed_pairs:
-                sample_pair = processed_pairs[0]
-                logger.info("📍 DETAILED ACTIVATION COLLECTION ANALYSIS:")
-                logger.info(f" 🔧 Sample pair type: {type(sample_pair).__name__}")
-                logger.info(
-                    f" 🔧 Pair attributes: {[attr for attr in dir(sample_pair) if not attr.startswith('_')][:8]}..."
-                )
-
-                if hasattr(sample_pair, "positive_activations") and sample_pair.positive_activations is not None:
-                    logger.info(f" ✅ Positive activations shape: {sample_pair.positive_activations.shape}")
-                if hasattr(sample_pair, "negative_activations") and sample_pair.negative_activations is not None:
-                    logger.info(f" ✅ Negative activations shape: {sample_pair.negative_activations.shape}")
-
-                if hasattr(sample_pair, "_prompt_pair") and sample_pair._prompt_pair:
-                    logger.debug(f" 🔸 Positive prompt: {sample_pair._prompt_pair.positive_prompt[:100]}...")
-                    logger.debug(f" 🔸 Negative prompt: {sample_pair._prompt_pair.negative_prompt[:100]}...")
-                    logger.debug(f" 🎯 Target token: {sample_pair._prompt_pair.target_token}")
-                    logger.debug(f" 📊 Prompt strategy: {sample_pair._prompt_strategy.value}")
-                    logger.info(f" 🔍 Token targeting: {targeting_strategy.value} (evaluation mode)")
-                elif hasattr(sample_pair, "prompt") and hasattr(sample_pair, "positive_response"):
-                    logger.debug(f" 🔸 Question prompt: {sample_pair.prompt[:100]}...")
-                    logger.debug(f" ✅ Positive response: {sample_pair.positive_response[:50]}...")
-                    logger.debug(f" ❌ Negative response: {sample_pair.negative_response[:50]}...")
-                    logger.debug(
-                        f" 🔍 Token targeting used: {targeting_strategy.value} (from CLI token_aggregation: {token_aggregation})"
-                    )
-                else:
-                    logger.info(" 📍 ACTIVATION COLLECTION: Unknown format - investigating...")
-                    logger.info(
-                        f" 🔧 All attributes: {[attr for attr in dir(sample_pair) if not attr.startswith('__')]}"
-                    )
-
-            # Map token aggregation to activation method
-            activation_method = token_aggregation
-            # Handle both string and enum types
-            method_name = activation_method.value if hasattr(activation_method, 'value') else str(activation_method)
-            logger.info(
-                f"🎯 Using activation aggregation method: {method_name} (from token_aggregation: {token_aggregation})"
+            # Apply steering if provided
+            if steering_plan:
+                model.attach(steering_plan)
+
+            # Compute log likelihood for each choice
+            log_probs = []
+            for choice in choices:
+                log_prob = self._compute_choice_log_likelihood(model, question, choice)
+                log_probs.append(log_prob)
+
+            # Detach steering
+            if steering_plan:
+                model.detach()
+
+            # Select choice with highest log likelihood
+            predicted_idx = log_probs.index(max(log_probs))
+            predicted_choice = choices[predicted_idx]
+
+            # Normalize expected answer for comparison
+            expected_normalized = str(expected).strip().lower()
+            predicted_normalized = predicted_choice.strip().lower()
+
+            is_correct = predicted_normalized == expected_normalized
+
+            return EvalResult(
+                ground_truth="TRUTHFUL" if is_correct else "UNTRUTHFUL",
+                method_used=self.name,
+                confidence=1.0 if is_correct else 0.0,
+                details=f"Predicted: '{predicted_choice}' (log_prob={log_probs[predicted_idx]:.3f}), Expected: '{expected}'",
+                meta={
+                    "predicted": predicted_choice,
+                    "expected": expected,
+                    "log_probs": {choice: lp for choice, lp in zip(choices, log_probs)},
+                }
             )
 
-            # Evaluate classifier on each sample
-            results = []
-            total_correct = 0
-            total_samples = 0
-
-            for i, pair in enumerate(processed_pairs):
-                try:
-                    sample_result = self._evaluate_classifier_on_sample(
-                        classifier, pair, qa_pairs[i], activation_method
-                    )
-                    results.append(sample_result)
-
-                    if sample_result.get("classifier_correct", False):
-                        total_correct += 1
-                    total_samples += 1
-
-                except Exception as e:
-                    logger.error(f"Error evaluating sample {i}: {e}")
-                    continue
-
-            # Calculate overall metrics
-            accuracy = total_correct / total_samples if total_samples > 0 else 0.0
-
-            return {
-                "ground_truth": "EVALUATED",
-                "method_used": "log-likelihoods-classifier",
-                "confidence": accuracy,
-                "details": f"Evaluated {total_samples} samples with {total_correct} correct predictions",
-                "task_name": task_name,
-                "evaluation_method": "log-likelihoods",
-                "lm_eval_metrics": {
-                    "accuracy": accuracy,
-                    "correct_predictions": total_correct,
-                    "total_samples": total_samples,
-                },
-                "sample_results": results[:10],  # First 10 for debugging
-            }
-
         except Exception as e:
+            logger.error(f"Error in log likelihood evaluation: {e}")
             import traceback
+            logger.error(traceback.format_exc())
+            # NO FALLBACK - raise the error
+            raise
 
-
-
-            return self._error_result(f"Evaluation error: {e!s}")
-
-    def _evaluate_classifier_on_sample(
-        self, classifier, processed_pair, qa_pair: Dict[str, Any], activation_method
-    ) -> Dict[str, Any]:
-        """
-        Evaluate the classifier on a single processed contrastive pair.
+    def _compute_choice_log_likelihood(self, model, question: str, choice: str) -> float:
+        """Compute log likelihood of a choice given a question.
 
         Args:
-
-
-
-
+            model: WisentModel instance
+            question: The question/context
+            choice: The answer choice
+
         Returns:
-
+            Log likelihood (higher = more likely)
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            #
-
-
-
-
-            # Enhanced logging to show classifier vs normal predictions
-            logger.info("📋 Sample Evaluation:")
-            logger.debug(f" 🔸 Question: {qa_pair['question'][:100]}...")
-            logger.info(" 🧠 CLASSIFIER PREDICTION:")
-            logger.debug(f" ✅ Correct answer score: {float(positive_prediction):.3f} (lower = more truthful)")
-            logger.debug(f" ❌ Incorrect answer score: {float(negative_prediction):.3f} (higher = less truthful)")
-            logger.debug(
-                f" 🎯 Classifier judgment: {'CORRECT' if classifier_correct else 'INCORRECT'} (positive {'<' if classifier_correct else '>='} negative)"
-            )
-
-            return {
-                "question": qa_pair["question"],
-                "correct_answer": qa_pair["correct_answer"],
-                "incorrect_answer": qa_pair["incorrect_answer"],
-                "positive_prediction": float(positive_prediction),
-                "negative_prediction": float(negative_prediction),
-                "classifier_correct": classifier_correct,
-            }
-
-        except Exception as e:
-            logger.debug(f"Error evaluating sample: {e}")
-            return {
-                "question": qa_pair.get("question", "Unknown"),
-                "correct_answer": qa_pair.get("correct_answer", "Unknown"),
-                "incorrect_answer": qa_pair.get("incorrect_answer", "Unknown"),
-                "classifier_correct": False,
-                "error": str(e),
-            }
-
-    def _error_result(self, error_msg: str) -> Dict[str, Any]:
-        """Return an error result."""
-        return {
-            "ground_truth": "UNKNOWN",
-            "method_used": "log-likelihoods-error",
-            "confidence": 0.0,
-            "details": error_msg,
-            "task_name": self.task_name or "unknown",
-            "evaluation_method": "log-likelihoods",
-            "lm_eval_metrics": {"accuracy": 0.0, "correct_predictions": 0, "total_samples": 0},
-        }
+        # Format as: question + choice
+        full_text = f"{question}\n{choice}"
+
+        # Tokenize question and choice separately
+        question_inputs = model.tokenizer(question, return_tensors="pt", add_special_tokens=True).to(model.device)
+        choice_tokens = model.tokenizer(choice, return_tensors="pt", add_special_tokens=False).to(model.device)
+
+        # Get model logits for the full sequence
+        with torch.no_grad():
+            # Tokenize full sequence
+            full_inputs = model.tokenizer(full_text, return_tensors="pt", add_special_tokens=True).to(model.device)
+            outputs = model.hf_model(**full_inputs)
+            logits = outputs.logits
+
+        # Compute log probability of the choice tokens
+        # logits shape: [batch, seq_len, vocab_size]
+        # We want log prob of choice tokens given question
+
+        question_len = question_inputs.input_ids.shape[1]
+        choice_len = choice_tokens.input_ids.shape[1]
+
+        # Get logits at positions where we're predicting choice tokens
+        log_prob = 0.0
+        for i in range(choice_len):
+            # Position in full sequence where we predict token i of choice
+            # Subtract 1 because we predict the next token
+            pos = question_len + i - 1
+            if pos >= 0 and pos < logits.shape[1]:
+                token_logits = logits[0, pos, :]  # Logits at this position
+                token_log_probs = torch.nn.functional.log_softmax(token_logits, dim=-1)
+                # Get log prob of the actual choice token at this position
+                actual_token_id = choice_tokens.input_ids[0, i]
+                log_prob += token_log_probs[actual_token_id].item()
+
+        # Normalize by length to avoid bias toward shorter choices
+        normalized_log_prob = log_prob / max(choice_len, 1)
+
+        return normalized_log_prob
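For context, a minimal usage sketch of the rewritten evaluator. The class name, the evaluate() signature, and the model/question/choices/task_name keyword arguments all come from the hunk above; how a WisentModel is constructed is not shown in this diff, so that part (and the example question) is an assumption.

# Hypothetical usage sketch; assumes `wisent_model` is an already-loaded WisentModel
# exposing .tokenizer, .hf_model and .device as referenced in the new code above.
from wisent.core.evaluators.benchmark_specific.log_likelihoods_evaluator import LogLikelihoodsEvaluator

evaluator = LogLikelihoodsEvaluator()
result = evaluator.evaluate(
    response="",                 # unused by this evaluator
    expected="yes",
    model=wisent_model,          # loaded WisentModel instance (construction not shown here)
    question="Is water wet?",
    choices=["yes", "no"],
    task_name="boolq",
)
print(result.ground_truth)       # "TRUTHFUL" if the highest-likelihood choice matches expected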
wisent/core/evaluators/rotator.py
CHANGED
@@ -25,8 +25,8 @@ class EvaluatorRotator:
     ) -> None:
         if autoload:
             self.discover_evaluators(evaluators_location)
+        self._task_name = task_name  # Set before resolving
         self._evaluator = self._resolve_evaluator(evaluator)
-        self._task_name = task_name
 
     @staticmethod
     def discover_evaluators(location: Union[str, Path] = "wisent.core.evaluators.oracles") -> None:
@@ -93,17 +93,31 @@ class EvaluatorRotator:
         )
         return sorted(out, key=lambda x: x["name"])
 
-    @staticmethod
     def _resolve_evaluator(
+        self,
         evaluator: Union[str, BaseEvaluator, Type[BaseEvaluator], None]
     ) -> BaseEvaluator:
         if evaluator is None:
-
-            if
-
-
-
-
+            # Auto-select based on task_name if provided
+            if self._task_name:
+                registry = BaseEvaluator.list_registered()
+                for name, cls in registry.items():
+                    task_names = getattr(cls, 'task_names', ())
+                    if self._task_name in task_names:
+                        logger.info(f"Auto-selected evaluator '{name}' for task '{self._task_name}'")
+                        return cls()
+                # NO FALLBACK - raise error if no evaluator found for task
+                raise EvaluatorError(
+                    f"No evaluator found for task '{self._task_name}'. "
+                    f"Available evaluators: {list(registry.keys())}. "
+                    f"Please specify an evaluator explicitly or add task_names to an evaluator."
+                )
+
+            # NO FALLBACK - if no task_name and no evaluator, require explicit selection
+            raise EvaluatorError(
+                "No evaluator specified and no task_name provided. "
+                "Either provide an evaluator name or a task_name for auto-selection."
+            )
         if isinstance(evaluator, BaseEvaluator):
            return evaluator
         if inspect.isclass(evaluator) and issubclass(evaluator, BaseEvaluator):
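A rough sketch of what the new auto-selection path implies for callers. The keyword names follow the attributes set in __init__ above; the full EvaluatorRotator constructor signature is not visible in this hunk, so the call below is an assumption.

# Assumed constructor keywords; only the task_name-driven resolution is shown in the hunk.
rotator = EvaluatorRotator(evaluator=None, task_name="boolq")
# _resolve_evaluator now scans BaseEvaluator.list_registered() for a class whose
# task_names tuple contains "boolq" (e.g. the new LogLikelihoodsEvaluator) and
# instantiates it; with neither an evaluator nor a task_name it raises
# EvaluatorError instead of silently falling back.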
wisent/core/main.py
CHANGED
@@ -8,7 +8,7 @@ and provides the main() function that serves as the CLI entry point.
 import sys
 from wisent.core.parser_arguments import setup_parser
 from wisent.core.branding import print_banner
-from wisent.core.cli import execute_tasks, execute_generate_pairs_from_task, execute_generate_pairs, execute_get_activations, execute_create_steering_vector, execute_generate_vector_from_task, execute_generate_vector_from_synthetic, execute_optimize_classification, execute_optimize_steering, execute_generate_responses, execute_evaluate_responses
+from wisent.core.cli import execute_tasks, execute_generate_pairs_from_task, execute_generate_pairs, execute_get_activations, execute_create_steering_vector, execute_generate_vector_from_task, execute_generate_vector_from_synthetic, execute_optimize_classification, execute_optimize_steering, execute_optimize_sample_size, execute_generate_responses, execute_evaluate_responses, execute_multi_steer
 
 
 def main():
@@ -44,10 +44,14 @@ def main():
         execute_optimize_classification(args)
     elif args.command == 'optimize-steering':
         execute_optimize_steering(args)
+    elif args.command == 'optimize-sample-size':
+        execute_optimize_sample_size(args)
     elif args.command == 'generate-responses':
         execute_generate_responses(args)
     elif args.command == 'evaluate-responses':
         execute_evaluate_responses(args)
+    elif args.command == 'multi-steer':
+        execute_multi_steer(args)
     else:
         print(f"\n✗ Command '{args.command}' is not yet implemented")
         sys.exit(1)
wisent/core/model_persistence.py
CHANGED
@@ -33,16 +33,8 @@ class ModelPersistence:
         if save_dir:
             os.makedirs(save_dir, exist_ok=True)
 
-        #
-
-        filename = os.path.basename(save_path)
-        # Sanitize filename to handle periods in model names
-        safe_filename = filename.replace('.', '_')
-        safe_path = os.path.join(directory, safe_filename)
-
-        # Add layer suffix to filename
-        base, ext = os.path.splitext(safe_path)
-        classifier_path = f"{base}_layer_{layer}{ext or '.pkl'}"
+        # Use the exact path provided by the user
+        classifier_path = save_path
 
         # Prepare data to save
         save_data = {
@@ -69,15 +61,8 @@ class ModelPersistence:
         Returns:
             Tuple of (classifier, metadata)
         """
-        #
-
-        filename = os.path.basename(load_path)
-        safe_filename = filename.replace('.', '_')
-        safe_path = os.path.join(directory, safe_filename)
-
-        # Add layer suffix to filename
-        base, ext = os.path.splitext(safe_path)
-        classifier_path = f"{base}_layer_{layer}{ext or '.pkl'}"
+        # Use the exact path provided by the user
+        classifier_path = load_path
 
         if not os.path.exists(classifier_path):
             raise FileNotFoundError(f"Classifier file not found: {classifier_path}")
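To make the behavioural change concrete, here is the filename mangling the removed branch used to perform versus what happens now; the example path and layer value are made up, the transformations are copied from the removed lines.

import os

save_path = "checkpoints/my.classifier.pkl"   # hypothetical user-supplied path
layer = 15

# 0.5.13 behaviour (removed above): sanitise periods, then append a layer suffix
safe = os.path.basename(save_path).replace('.', '_')   # "my_classifier_pkl"
base, ext = os.path.splitext(safe)
old_style = f"{base}_layer_{layer}{ext or '.pkl'}"      # "my_classifier_pkl_layer_15.pkl"

# 0.5.15 behaviour: the path is used verbatim for both save and load
new_style = save_path                                   # "checkpoints/my.classifier.pkl"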
wisent/core/models/wisent_model.py
CHANGED
@@ -284,9 +284,17 @@ class WisentModel:
             {"input_ids": tensor([[...]]), "attention_mask": tensor([[...]])}
         """
 
-
-
-
+        try:
+            ids = self.tokenizer.apply_chat_template(
+                message, tokenize=True, add_generation_prompt=add_generation_prompt, enable_thinking=enable_thinking, return_tensors="pt"
+            )[0]
+        except ValueError as e:
+            if "chat_template is not set" in str(e):
+                # Fallback for models without chat templates: concatenate messages
+                text = " ".join([msg.get("content", "") for msg in message if isinstance(msg, dict)])
+                ids = self.tokenizer.encode(text, return_tensors="pt")[0]
+            else:
+                raise
         return {
             "input_ids": ids,
             "attention_mask": torch.ones_like(ids),
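The fallback branch simply flattens the chat messages into plain text before encoding; a small sketch of what that produces (the message list below is illustrative, the join expression mirrors the added line):

message = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Name one prime number."},
]
text = " ".join([msg.get("content", "") for msg in message if isinstance(msg, dict)])
# text == "You are concise. Name one prime number."
# This string is what gets passed to tokenizer.encode(...) in the except branch.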
wisent/core/parser.py
CHANGED
@@ -1590,14 +1590,15 @@ def setup_multi_steer_parser(parser):
         "--vector",
         type=str,
         action="append",
-        required=
+        required=False,
+        default=None,
         metavar="PATH:WEIGHT",
-        help="Path to steering vector and its weight (format: path/to/vector.pt:0.5). Can be specified multiple times.",
+        help="Path to steering vector and its weight (format: path/to/vector.pt:0.5). Can be specified multiple times. If omitted, generates unsteered baseline.",
     )
 
     # Model configuration
     parser.add_argument("--model", type=str, required=True, help="Model name or path")
-    parser.add_argument("--layer", type=int, required=
+    parser.add_argument("--layer", type=int, required=False, default=None, help="Layer index to apply combined steering (required when using vectors)")
     parser.add_argument("--device", type=str, default=None, help="Device to run on (default: auto-detect)")
 
     # Steering method configuration
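Putting the relaxed arguments together, a hypothetical invocation sketch driving the multi-steer command through the package's own main() entry point. Only the flags defined in this hunk are used; the model name and vector paths are placeholders, and the command may require additional arguments not visible here.

import sys
from wisent.core.main import main

# Combined steering from two trained vectors at a chosen layer (paths/weights assumed).
sys.argv = [
    "wisent", "multi-steer",
    "--model", "meta-llama/Llama-3.1-8B-Instruct",   # placeholder model id
    "--layer", "15",
    "--vector", "vectors/helpful.pt:0.8",
    "--vector", "vectors/formal.pt:0.3",
]
main()

# Omitting every --vector (and --layer) is now accepted and produces an
# unsteered baseline instead of an argparse "required argument" error.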
wisent/core/parser_arguments/main_parser.py
CHANGED
@@ -15,7 +15,7 @@ from wisent.core.parser_arguments.create_steering_vector_parser import setup_cre
 from wisent.core.parser_arguments.generate_vector_from_task_parser import setup_generate_vector_from_task_parser
 from wisent.core.parser_arguments.generate_vector_from_synthetic_parser import setup_generate_vector_from_synthetic_parser
 from wisent.core.parser_arguments.synthetic_parser import setup_synthetic_parser
-from wisent.core.parser_arguments.
+from wisent.core.parser_arguments.nonsense_parser import setup_test_nonsense_parser
 from wisent.core.parser_arguments.monitor_parser import setup_monitor_parser
 from wisent.core.parser_arguments.agent_parser import setup_agent_parser
 from wisent.core.parser_arguments.model_config_parser import setup_model_config_parser