wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the contents of publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/optuna/steering/metrics.py
@@ -0,0 +1,474 @@
+ """
+ Evaluation metrics for comprehensive evaluation pipeline.
+ """
+
+ import logging
+ from typing import Any, Callable, Dict, List, Optional
+
+ import numpy as np
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
+
+ from wisent_guard.core.bigcode_extractors import MBPPExtractor
+
+ # Import LMEvalHarnessGroundTruth for intelligent evaluation (newer approach used by CLI)
+ from wisent_guard.core.lm_eval_harness_ground_truth import LMEvalHarnessGroundTruth
+ from wisent_guard.core.task_interface import get_task
+ from wisent_guard.core.tasks.file_task import FileTask
+ from wisent_guard.parameters.task_config import CODING_TASKS
+
+ from .bigcode_evaluator_wrapper import OptunaBigCodeEvaluator
+
+ logger = logging.getLogger(__name__)
+
+
+ def evaluate_response_correctness(response: str, expected_answer: str, task_name: str) -> bool:
+     """
+     Evaluate if a response is correct using LMEvalHarnessGroundTruth (same approach as CLI).
+     Note: For coding tasks, response should already be extracted code before calling this function.
+
+     Args:
+         response: Model's response (pre-extracted code for coding tasks)
+         expected_answer: Expected correct answer
+         task_name: Name of the task for proper evaluation
+
+     Returns:
+         True if response is correct, False otherwise
+     """
+     # Check if this is a file-based task (custom dataset loaded from JSON)
+     # For file-based tasks, use exact string matching to avoid false positives
+     try:
+         task = get_task(task_name, limit=1)
+         if isinstance(task, FileTask):
+             logger.debug(f"Using exact match for file-based task '{task_name}'")
+             return response.strip().lower() == expected_answer.strip().lower()
+     except:
+         pass # Continue with normal evaluation if task lookup fails
+
+     try:
+         # Use the same evaluation approach as the CLI
+         evaluator = LMEvalHarnessGroundTruth(task_name)
+
+         # Create response data format expected by _evaluate_with_lm_eval_metrics
+         response_data = [
+             {
+                 "generated_response": response,
+                 "ground_truth": expected_answer,
+                 "question": "evaluation_question", # Required field for evaluation
+             }
+         ]
+
+         # Use the same evaluation logic as CLI
+         eval_results = evaluator._evaluate_with_lm_eval_metrics(task_name, response_data, None)
+
+         # Extract the result - accuracy > 0 means at least one correct
+         return eval_results.get("accuracy", 0.0) > 0.0
+
+     except Exception as e:
+         logger.warning(f"LMEvalHarnessGroundTruth failed, using exact match fallback: {e}")
+         # Fallback to simple string matching
+         return response.strip().lower() == expected_answer.strip().lower()
+
+
+ def evaluate_benchmark_performance(
+     predictions: List[str],
+     ground_truths: List[str],
+     task_name: str = None,
+     task_docs: List[Dict] = None,
+     classifier_scorer: Optional[Callable[[List[str], str], List[float]]] = None,
+ ) -> Dict[str, float]:
+     """
+     Evaluate benchmark performance using LMEvalHarnessGroundTruth (same approach as CLI).
+     For coding tasks, uses BigCode execution-based evaluation instead of string comparison.
+
+     Args:
+         predictions: List of model predictions
+         ground_truths: List of correct answers
+         task_name: Name of the task for intelligent evaluation
+         task_docs: List of original task documents (required for coding tasks)
+         classifier_scorer: Optional function to score predictions with classifier for confidence scores
+
+     Returns:
+         Dictionary containing benchmark performance metrics
+     """
+     if task_name:
+         # Check if this is a coding task that requires code execution evaluation
+         is_coding_task = task_name.lower() in CODING_TASKS
+
+         # Calculate classifier confidence scores if classifier_scorer provided
+         classifier_confidences = None
+         if classifier_scorer is not None:
+             try:
+                 logger.debug(f"Calculating classifier confidence scores for {len(predictions)} predictions")
+                 classifier_confidences = classifier_scorer(predictions, f"metrics_evaluation_{task_name}")
+                 logger.debug(f"Calculated {len(classifier_confidences)} confidence scores")
+             except Exception as e:
+                 logger.warning(f"Failed to calculate classifier confidence scores: {e}")
+                 classifier_confidences = None
+
+         if is_coding_task:
+             # Use BigCode execution-based evaluation for coding tasks
+             logger.info(f"Using BigCode execution-based evaluation for coding task: {task_name}")
+
+             try:
+                 bigcode_evaluator = OptunaBigCodeEvaluator()
+
+                 # Validate task docs are provided for coding tasks
+                 if task_docs is None or len(task_docs) == 0:
+                     logger.error(
+                         f"No task docs provided for coding task {task_name}. BigCode evaluation requires original task documents with test cases."
+                     )
+                     raise ValueError(f"Task documents required for coding task evaluation: {task_name}")
+
+                 # Ensure we have the right number of task docs
+                 if len(task_docs) != len(predictions):
+                     logger.error(f"Task docs length mismatch: {len(task_docs)} docs vs {len(predictions)} predictions")
+                     raise ValueError(
+                         f"Number of task documents ({len(task_docs)}) must match number of predictions ({len(predictions)})"
+                     )
+
+                 # Evaluate using BigCode execution
+                 evaluation_results, accuracy_metrics = bigcode_evaluator.evaluate_and_calculate_accuracy(
+                     predictions, task_docs, task_name
+                 )
+
+                 # Create evaluation details in the expected format
+                 evaluation_details = []
+                 for i, (pred, result) in enumerate(zip(predictions, evaluation_results)):
+                     eval_detail = {
+                         "prediction": result.get("extracted_code", pred),
+                         "ground_truth": ground_truths[i] if i < len(ground_truths) else "unknown",
+                         "is_correct": result.get("passed", False),
+                         "classifier_confidence": classifier_confidences[i]
+                         if classifier_confidences and i < len(classifier_confidences)
+                         else 1.0,
+                         "method": "bigcode_execution",
+                         "original_prediction": pred,
+                         "code_extracted": result.get("extracted_code", "") != pred,
+                         "execution_error": result.get("error"),
+                     }
+                     evaluation_details.append(eval_detail)
+
+                 return {
+                     "accuracy": accuracy_metrics["accuracy"],
+                     "total_samples": accuracy_metrics["total_samples"],
+                     "correct": accuracy_metrics["correct"],
+                     "incorrect": accuracy_metrics["incorrect"],
+                     "evaluation_method": "bigcode_execution",
+                     "task_name": task_name,
+                     "evaluation_details": evaluation_details,
+                     "pass_count": accuracy_metrics.get("pass_count", 0),
+                     "fail_count": accuracy_metrics.get("fail_count", 0),
+                     "error_count": accuracy_metrics.get("error_count", 0),
+                 }
+
+             except Exception as e:
+                 logger.error(f"BigCode evaluation failed for {task_name}, falling back to string-based evaluation: {e}")
+                 # Fall through to string-based evaluation
+
+         # String-based evaluation for non-coding tasks or BigCode fallback
+         extracted_predictions = predictions
+
+         if is_coding_task:
+             # Extract code from predictions for coding tasks (fallback mode)
+             extractor = MBPPExtractor() # Works for all coding tasks, not just MBPP
+             extracted_predictions = []
+
+             for pred in predictions:
+                 extracted_code = extractor.extract_code_from_answer(pred)
+                 extracted_predictions.append(extracted_code)
+
+             logger.debug(f"Code extraction applied for {task_name}: {len(predictions)} predictions processed")
+
+         # Use intelligent evaluation with LMEvalHarnessGroundTruth (same as CLI)
+         correct_predictions = []
+         evaluation_details = []
+
+         for i, (orig_pred, extracted_pred, gt) in enumerate(zip(predictions, extracted_predictions, ground_truths)):
+             try:
+                 # Use the extracted prediction for evaluation
+                 is_correct = evaluate_response_correctness(extracted_pred, gt, task_name)
+                 correct_predictions.append(is_correct)
+
+                 # Include both original and extracted predictions in details for debugging
+                 eval_detail = {
+                     "prediction": extracted_pred,
+                     "ground_truth": gt,
+                     "is_correct": is_correct,
+                     "classifier_confidence": classifier_confidences[i]
+                     if classifier_confidences and i < len(classifier_confidences)
+                     else 1.0,
+                     "method": "lm_eval_harness_ground_truth",
+                 }
+
+                 # Add original prediction for coding tasks to help with debugging
+                 if is_coding_task and orig_pred != extracted_pred:
+                     eval_detail["original_prediction"] = orig_pred
+                     eval_detail["code_extracted"] = True
+
+                 evaluation_details.append(eval_detail)
+
+             except Exception as e:
+                 logger.warning(f"LMEvalHarnessGroundTruth failed for prediction '{extracted_pred}' vs '{gt}': {e}")
+                 # Fallback to simple string matching
+                 is_correct = extracted_pred.strip().lower() == gt.strip().lower()
+                 correct_predictions.append(is_correct)
+
+                 eval_detail = {
+                     "prediction": extracted_pred,
+                     "ground_truth": gt,
+                     "is_correct": is_correct,
+                     "classifier_confidence": classifier_confidences[i]
+                     if classifier_confidences and i < len(classifier_confidences)
+                     else 1.0,
+                     "method": "fallback_exact_match",
+                 }
+
+                 if is_coding_task and orig_pred != extracted_pred:
+                     eval_detail["original_prediction"] = orig_pred
+                     eval_detail["code_extracted"] = True
+
+                 evaluation_details.append(eval_detail)
+
+         accuracy = np.mean(correct_predictions)
+         total_correct = sum(correct_predictions)
+
+         return {
+             "accuracy": accuracy,
+             "total_samples": len(predictions),
+             "correct": total_correct,
+             "incorrect": len(predictions) - total_correct,
+             "evaluation_method": "lm_eval_harness_ground_truth",
+             "task_name": task_name,
+             "evaluation_details": evaluation_details[:5], # Include first 5 for debugging
+         }
+     # Fallback to simple exact match
+     logger.info("No task_name provided, using simple exact match evaluation")
+     exact_matches = [pred.strip().lower() == gt.strip().lower() for pred, gt in zip(predictions, ground_truths)]
+     accuracy = np.mean(exact_matches)
+
+     return {
+         "accuracy": accuracy,
+         "total_samples": len(predictions),
+         "correct": sum(exact_matches),
+         "incorrect": len(predictions) - sum(exact_matches),
+         "evaluation_method": "exact_match",
+     }
+
+
+ def evaluate_probe_performance(y_true: np.ndarray, y_pred: np.ndarray, y_pred_proba: np.ndarray) -> Dict[str, float]:
+     """
+     Evaluate probe performance with comprehensive metrics.
+
+     Args:
+         y_true: True labels
+         y_pred: Predicted labels
+         y_pred_proba: Predicted probabilities (for positive class)
+
+     Returns:
+         Dictionary containing probe performance metrics
+     """
+     if len(y_true) == 0:
+         # Return default metrics if no data
+         return {"accuracy": 0.5, "precision": 0.5, "recall": 0.5, "f1": 0.5, "auc": 0.5, "total_samples": 0}
+
+     accuracy = accuracy_score(y_true, y_pred)
+     precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
+
+     try:
+         auc = roc_auc_score(y_true, y_pred_proba)
+     except:
+         auc = 0.5 # Default for cases where AUC can't be computed
+
+     return {
+         "accuracy": accuracy,
+         "precision": precision,
+         "recall": recall,
+         "f1": f1,
+         "auc": auc,
+         "total_samples": len(y_true),
+     }
+
+
+ def calculate_combined_score(
+     benchmark_metrics: Dict[str, float],
+     probe_metrics: Dict[str, float],
+     benchmark_weight: float = 0.7,
+     probe_weight: float = 0.3,
+ ) -> float:
+     """
+     Calculate combined score from benchmark and probe performance.
+
+     Args:
+         benchmark_metrics: Benchmark performance metrics
+         probe_metrics: Probe performance metrics
+         benchmark_weight: Weight for benchmark performance
+         probe_weight: Weight for probe performance
+
+     Returns:
+         Combined score (0-1)
+     """
+     benchmark_score = benchmark_metrics.get("accuracy", 0.0)
+     probe_score = probe_metrics.get("auc", 0.5) # Use AUC as primary probe metric
+
+     combined_score = benchmark_weight * benchmark_score + probe_weight * probe_score
+     return combined_score
+
+
+ def calculate_comprehensive_metrics(results: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Calculate comprehensive metrics from evaluation results.
+
+     Args:
+         results: Complete evaluation results
+
+     Returns:
+         Dictionary with comprehensive metrics and analysis
+     """
+     comprehensive_metrics = {}
+
+     if "test_results" in results:
+         test_results = results["test_results"]
+
+         # Extract key metrics
+         base_benchmark_acc = test_results.get("base_benchmark_metrics", {}).get("accuracy", 0.0)
+         steered_benchmark_acc = test_results.get("steered_benchmark_metrics", {}).get("accuracy", 0.0)
+         base_probe_auc = test_results.get("base_probe_metrics", {}).get("auc", 0.5)
+         steered_probe_auc = test_results.get("steered_probe_metrics", {}).get("auc", 0.5)
+
+         # Calculate improvements
+         benchmark_improvement = steered_benchmark_acc - base_benchmark_acc
+         probe_improvement = steered_probe_auc - base_probe_auc
+
+         comprehensive_metrics.update(
+             {
+                 "base_benchmark_accuracy": base_benchmark_acc,
+                 "steered_benchmark_accuracy": steered_benchmark_acc,
+                 "benchmark_improvement": benchmark_improvement,
+                 "benchmark_improvement_percent": (benchmark_improvement / max(base_benchmark_acc, 0.001)) * 100,
+                 "base_probe_auc": base_probe_auc,
+                 "steered_probe_auc": steered_probe_auc,
+                 "probe_improvement": probe_improvement,
+                 "probe_improvement_percent": (probe_improvement / max(base_probe_auc, 0.001)) * 100,
+                 "overall_effectiveness": (benchmark_improvement + probe_improvement) / 2,
+                 "validation_score": test_results.get("validation_combined_score", 0.0),
+             }
+         )
+
+     # Add training statistics
+     if "probe_training_results" in results:
+         training_results = results["probe_training_results"]
+
+         # Calculate training performance statistics
+         all_training_aucs = []
+         for layer_key, layer_results in training_results.items():
+             for c_key, metrics in layer_results.items():
+                 if isinstance(metrics, dict) and "auc" in metrics:
+                     all_training_aucs.append(metrics["auc"])
+
+         if all_training_aucs:
+             comprehensive_metrics.update(
+                 {
+                     "training_probe_auc_mean": np.mean(all_training_aucs),
+                     "training_probe_auc_std": np.std(all_training_aucs),
+                     "training_probe_auc_max": np.max(all_training_aucs),
+                     "training_probe_auc_min": np.min(all_training_aucs),
+                 }
+             )
+
+     # Add optimization statistics
+     if "steering_optimization_results" in results:
+         optimization_results = results["steering_optimization_results"]
+
+         all_configs = optimization_results.get("all_configs", [])
+         if all_configs:
+             combined_scores = [config.get("combined_score", 0.0) for config in all_configs]
+             benchmark_scores = [config.get("benchmark_metrics", {}).get("accuracy", 0.0) for config in all_configs]
+
+             comprehensive_metrics.update(
+                 {
+                     "optimization_configs_tested": len(all_configs),
+                     "optimization_score_mean": np.mean(combined_scores),
+                     "optimization_score_std": np.std(combined_scores),
+                     "optimization_benchmark_mean": np.mean(benchmark_scores),
+                     "optimization_benchmark_std": np.std(benchmark_scores),
+                 }
+             )
+
+     return comprehensive_metrics
+
+
+ def generate_performance_summary(comprehensive_metrics: Dict[str, Any]) -> str:
+     """
+     Generate a human-readable performance summary.
+
+     Args:
+         comprehensive_metrics: Comprehensive metrics dictionary
+
+     Returns:
+         String summary of performance
+     """
+     summary = []
+     summary.append("=" * 60)
+     summary.append("COMPREHENSIVE EVALUATION PERFORMANCE SUMMARY")
+     summary.append("=" * 60)
+
+     # Benchmark Performance
+     if "base_benchmark_accuracy" in comprehensive_metrics:
+         base_acc = comprehensive_metrics["base_benchmark_accuracy"]
+         steered_acc = comprehensive_metrics["steered_benchmark_accuracy"]
+         improvement = comprehensive_metrics["benchmark_improvement"]
+
+         summary.append("\nšŸ“Š BENCHMARK PERFORMANCE:")
+         summary.append(f" Base Model Accuracy: {base_acc:.3f} ({base_acc * 100:.1f}%)")
+         summary.append(f" Steered Model Accuracy: {steered_acc:.3f} ({steered_acc * 100:.1f}%)")
+         summary.append(f" Improvement: {improvement:+.3f} ({improvement * 100:+.1f}%)")
+
+     # Probe Performance
+     if "base_probe_auc" in comprehensive_metrics:
+         base_auc = comprehensive_metrics["base_probe_auc"]
+         steered_auc = comprehensive_metrics["steered_probe_auc"]
+         improvement = comprehensive_metrics["probe_improvement"]
+
+         summary.append("\nšŸ” PROBE PERFORMANCE:")
+         summary.append(f" Base Model Probe AUC: {base_auc:.3f}")
+         summary.append(f" Steered Model Probe AUC: {steered_auc:.3f}")
+         summary.append(f" Improvement: {improvement:+.3f}")
+
+     # Training Statistics
+     if "training_probe_auc_mean" in comprehensive_metrics:
+         mean_auc = comprehensive_metrics["training_probe_auc_mean"]
+         std_auc = comprehensive_metrics["training_probe_auc_std"]
+         max_auc = comprehensive_metrics["training_probe_auc_max"]
+
+         summary.append("\nšŸŽÆ TRAINING STATISTICS:")
+         summary.append(f" Probe Training AUC: {mean_auc:.3f} ± {std_auc:.3f}")
+         summary.append(f" Best Training AUC: {max_auc:.3f}")
+
+     # Optimization Statistics
+     if "optimization_configs_tested" in comprehensive_metrics:
+         num_configs = comprehensive_metrics["optimization_configs_tested"]
+         best_score = comprehensive_metrics.get("validation_score", 0.0)
+
+         summary.append("\nāš™ļø OPTIMIZATION STATISTICS:")
+         summary.append(f" Configurations Tested: {num_configs}")
+         summary.append(f" Best Validation Score: {best_score:.3f}")
+
+     # Overall Assessment
+     if "overall_effectiveness" in comprehensive_metrics:
+         effectiveness = comprehensive_metrics["overall_effectiveness"]
+
+         summary.append("\nšŸ† OVERALL ASSESSMENT:")
+         if effectiveness > 0.1:
+             assessment = "Highly Effective"
+         elif effectiveness > 0.05:
+             assessment = "Moderately Effective"
+         elif effectiveness > 0.01:
+             assessment = "Slightly Effective"
+         else:
+             assessment = "Minimal Effect"
+
+         summary.append(f" Steering Effectiveness: {assessment} ({effectiveness:+.3f})")
+
+     summary.append("=" * 60)
+
+     return "\n".join(summary)
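For orientation, below is a minimal usage sketch of the metrics helpers added in this file. It is not part of the diff: the import path simply mirrors the file's location in the wheel, the sample data is invented, and since the module itself imports from wisent_guard (which does not appear in this wheel's file list), whether it imports as released is an assumption.

# Illustrative sketch only (not part of the diff). Assumes the module can be imported
# from its location in the wheel; the data below is made up. With task_name omitted,
# evaluate_benchmark_performance takes the simple exact-match path.
import numpy as np

from wisent.core.optuna.steering.metrics import (
    calculate_combined_score,
    evaluate_benchmark_performance,
    evaluate_probe_performance,
)

# Benchmark side: case-insensitive exact match of predictions against ground truths.
benchmark = evaluate_benchmark_performance(
    predictions=["Paris", "4", "blue"],
    ground_truths=["paris", "4", "red"],
)

# Probe side: binary labels, hard predictions, and positive-class probabilities.
probe = evaluate_probe_performance(
    y_true=np.array([1, 0, 1, 0]),
    y_pred=np.array([1, 0, 0, 0]),
    y_pred_proba=np.array([0.9, 0.2, 0.4, 0.1]),
)

# Weighted combination: 0.7 * benchmark accuracy + 0.3 * probe AUC by default.
print(f"combined score: {calculate_combined_score(benchmark, probe):.3f}")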