wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/core/optuna/steering/bigcode_evaluator_wrapper.py
@@ -0,0 +1,188 @@
+ """
+ BigCode evaluator wrapper for optuna pipeline integration.
+
+ This module provides a clean interface for integrating BigCode code execution
+ evaluation with the optuna optimization pipeline.
+ """
+
+ import logging
+ from typing import Any, Dict, List, Tuple
+
+ from wisent.core.bigcode_extractors import get_bigcode_extractor
+ from wisent.core.bigcode_integration import BigCodeEvaluator, is_bigcode_task
+ from wisent.parameters.task_config import CODING_TASKS
+
+ logger = logging.getLogger(__name__)
+
+
+ class OptunaBigCodeEvaluator:
+     """
+     Wrapper for BigCode evaluation in optuna pipeline.
+
+     This class provides a clean interface for evaluating coding tasks
+     using actual code execution instead of string comparison.
+     """
+
+     def __init__(self, docker_executor=None):
+         """
+         Initialize the evaluator.
+
+         Args:
+             docker_executor: Optional Docker executor for secure code execution
+         """
+         self.bigcode_evaluator = BigCodeEvaluator(docker_executor)
+         self.code_extractor = None # Will be set per task
+
+     def is_coding_task(self, task_name: str) -> bool:
+         """
+         Check if a task requires code execution evaluation.
+
+         Args:
+             task_name: Name of the task
+
+         Returns:
+             True if task requires code execution, False otherwise
+         """
+         if not task_name:
+             return False
+         return task_name.lower() in CODING_TASKS or is_bigcode_task(task_name)
+
+     def evaluate_predictions(
+         self, predictions: List[str], task_docs: List[Dict[str, Any]], task_name: str
+     ) -> List[Dict[str, Any]]:
+         """
+         Evaluate model predictions using code execution.
+
+         Args:
+             predictions: List of model-generated code predictions
+             task_docs: List of original task documents with test cases
+             task_name: Name of the task
+
+         Returns:
+             List of evaluation results for each prediction
+         """
+         if not self.is_coding_task(task_name):
+             raise ValueError(f"Task {task_name} is not a coding task")
+
+         if len(predictions) != len(task_docs):
+             raise ValueError(f"Mismatch: {len(predictions)} predictions vs {len(task_docs)} task docs")
+
+         results = []
+
+         for i, (prediction, task_doc) in enumerate(zip(predictions, task_docs)):
+             try:
+                 # Get the appropriate extractor for this task
+                 code_extractor = get_bigcode_extractor(task_name)
+
+                 # Extract code from the prediction
+                 extracted_code = code_extractor.extract_code_from_answer(prediction)
+
+                 if not extracted_code.strip():
+                     logger.warning(f"No code extracted from prediction {i}: {prediction[:100]}...")
+                     result = {
+                         "passed": False,
+                         "error": "No code extracted from prediction",
+                         "extracted_code": extracted_code,
+                         "original_prediction": prediction,
+                     }
+                 else:
+                     # Execute the code against test cases
+                     result = self.bigcode_evaluator._execute_and_test(task_doc, extracted_code, task_name)
+                     result["extracted_code"] = extracted_code
+                     result["original_prediction"] = prediction
+
+                 results.append(result)
+
+             except Exception as e:
+                 logger.warning(f"Error evaluating prediction {i}: {e}")
+                 results.append(
+                     {"passed": False, "error": str(e), "extracted_code": "", "original_prediction": prediction}
+                 )
+
+         return results
+
+     def calculate_accuracy(self, evaluation_results: List[Dict[str, Any]]) -> Dict[str, Any]:
+         """
+         Calculate accuracy metrics from evaluation results.
+
+         Args:
+             evaluation_results: List of evaluation results from evaluate_predictions
+
+         Returns:
+             Dictionary with accuracy metrics
+         """
+         total_samples = len(evaluation_results)
+         if total_samples == 0:
+             return {
+                 "accuracy": 0.0,
+                 "total_samples": 0,
+                 "correct": 0,
+                 "incorrect": 0,
+                 "pass_count": 0,
+                 "fail_count": 0,
+                 "error_count": 0,
+             }
+
+         pass_count = sum(1 for result in evaluation_results if result.get("passed", False))
+         fail_count = sum(
+             1 for result in evaluation_results if result.get("passed", False) == False and not result.get("error")
+         )
+         error_count = sum(1 for result in evaluation_results if result.get("error"))
+
+         accuracy = pass_count / total_samples
+
+         return {
+             "accuracy": accuracy,
+             "total_samples": total_samples,
+             "correct": pass_count,
+             "incorrect": fail_count + error_count,
+             "pass_count": pass_count,
+             "fail_count": fail_count,
+             "error_count": error_count,
+             "evaluation_method": "bigcode_execution",
+         }
+
+     def evaluate_and_calculate_accuracy(
+         self, predictions: List[str], task_docs: List[Dict[str, Any]], task_name: str
+     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+         """
+         Convenience method to evaluate predictions and calculate accuracy.
+
+         Args:
+             predictions: List of model-generated code predictions
+             task_docs: List of original task documents with test cases
+             task_name: Name of the task
+
+         Returns:
+             Tuple of (evaluation_results, accuracy_metrics)
+         """
+         evaluation_results = self.evaluate_predictions(predictions, task_docs, task_name)
+         accuracy_metrics = self.calculate_accuracy(evaluation_results)
+
+         logger.info(
+             f"BigCode evaluation for {task_name}: "
+             f"{accuracy_metrics['pass_count']}/{accuracy_metrics['total_samples']} passed "
+             f"({accuracy_metrics['accuracy']:.3f} accuracy)"
+         )
+
+         return evaluation_results, accuracy_metrics
+
+
+ # Global instance for easy access
+ _optuna_bigcode_evaluator = None
+
+
+ def get_optuna_bigcode_evaluator(docker_executor=None) -> OptunaBigCodeEvaluator:
+     """
+     Get global OptunaBigCodeEvaluator instance.
+
+     Args:
+         docker_executor: Optional Docker executor
+
+     Returns:
+         OptunaBigCodeEvaluator instance
+     """
+     global _optuna_bigcode_evaluator
+     if _optuna_bigcode_evaluator is None:
+         _optuna_bigcode_evaluator = OptunaBigCodeEvaluator(docker_executor)
+     return _optuna_bigcode_evaluator
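
Usage note (not part of the diff): the sketch below shows how the wrapper added above might be driven. The import path follows file 133 in the list; the task name "livecodebench" and the shape of the task document are illustrative assumptions and would need to match whatever CODING_TASKS and is_bigcode_task actually accept in this release.

# Illustrative only: drive the new OptunaBigCodeEvaluator directly.
from wisent.core.optuna.steering.bigcode_evaluator_wrapper import get_optuna_bigcode_evaluator

evaluator = get_optuna_bigcode_evaluator()  # optionally pass a Docker executor

# Hypothetical inputs: one model completion and one task document with a test case.
predictions = ["def add(a, b):\n    return a + b"]
task_docs = [{"test": "assert add(1, 2) == 3"}]  # assumed document shape

results, metrics = evaluator.evaluate_and_calculate_accuracy(
    predictions, task_docs, task_name="livecodebench"  # assumed coding-task name
)
print(metrics["accuracy"], metrics["pass_count"], metrics["error_count"])
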
wisent/core/optuna/steering/data_utils.py
@@ -0,0 +1,342 @@
+ """
+ Data loading and processing utilities for comprehensive evaluation.
+ """
+
+ import logging
+ from typing import Dict, List, Tuple
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Import LMEvalHarnessGroundTruth for intelligent evaluation (same approach as CLI)
+ from wisent.core.lm_eval_harness_ground_truth import LMEvalHarnessGroundTruth
+
+ # Import task interface for dynamic task loading
+ from wisent.core.task_interface import get_task
+ from wisent.core.utils.device import empty_device_cache, preferred_dtype, resolve_default_device
+
+ logger = logging.getLogger(__name__)
+
+
+ def load_dataset_samples(dataset_name: str, limit: int) -> List[Dict]:
+     """Load samples from a dataset using the unified task interface."""
+     logger.info(f"Loading {limit} samples from {dataset_name}...")
+
+     try:
+         # Use the unified task interface to get any registered task
+         task = get_task(dataset_name, limit=limit)
+         samples = task.load_data(limit=limit)
+
+         logger.info(f"Loaded {len(samples)} samples from {dataset_name} via {task.__class__.__name__}")
+         return samples
+
+     except Exception as e:
+         logger.error(f"Failed to load {dataset_name}: {e}")
+         # Provide helpful error message with available tasks
+         try:
+             from ...task_interface import list_tasks
+
+             available_tasks = list_tasks()
+             logger.error(f"Available tasks: {available_tasks}")
+         except:
+             pass
+         raise
+
+
+ def extract_activations_with_hook(
+     model, tokenizer, texts: List[str], layer: int, batch_size: int, max_length: int, device: torch.device
+ ) -> np.ndarray:
+     """Extract activations from a specific layer using hooks."""
+     activations = []
+
+     def hook_fn(module, input, output):
+         # Handle different output formats (some layers return tuples)
+         if isinstance(output, tuple):
+             hidden_states = output[0] # First element is usually hidden states
+         else:
+             hidden_states = output
+
+         # Extract last token activations (typical for causal LM)
+         if len(hidden_states.shape) == 3: # [batch, seq, hidden]
+             last_token_acts = hidden_states[:, -1, :].detach().cpu().numpy()
+             activations.extend(last_token_acts)
+
+     # Register hook
+     if hasattr(model, "transformer"): # GPT-style models
+         target_layer = model.transformer.h[layer]
+     elif hasattr(model, "model"): # Some other architectures
+         target_layer = model.model.layers[layer]
+     else:
+         raise ValueError("Unknown model architecture")
+
+     handle = target_layer.register_forward_hook(hook_fn)
+
+     try:
+         # Process texts in batches
+         for i in tqdm(range(0, len(texts), batch_size), desc=f"Extracting activations (layer {layer})"):
+             batch_texts = texts[i : i + batch_size]
+
+             inputs = tokenizer(
+                 batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length
+             ).to(device)
+
+             with torch.no_grad():
+                 _ = model(**inputs)
+
+     finally:
+         handle.remove()
+
+     return np.array(activations)
+
+
+ def generate_benchmark_predictions(
+     model,
+     tokenizer,
+     samples: List[Dict],
+     batch_size: int,
+     max_length: int,
+     device: torch.device,
+     task_name: str,
+     max_new_tokens: int,
+     preserve_task_docs: bool = False,
+ ) -> Tuple[List[str], List[str], List[Dict]]:
+     """Generate model predictions for benchmark evaluation using task extractor with batching.
+
+     Args:
+         preserve_task_docs: If True, returns original task documents alongside predictions
+
+     Returns:
+         Tuple of (predictions, ground_truths, task_docs) if preserve_task_docs=True
+         Tuple of (predictions, ground_truths, []) if preserve_task_docs=False
+     """
+     predictions = []
+     ground_truths = []
+     task_docs = [] if preserve_task_docs else []
+
+     # Get the task and its extractor
+     task = get_task(task_name)
+     extractor = task.get_extractor()
+
+     # First, extract all questions and answers
+     questions = []
+     answers = []
+
+     valid_samples = [] # Keep track of samples that produce valid QA pairs
+
+     for sample in samples:
+         qa_pair = extractor.extract_qa_pair(sample, task)
+         if not qa_pair:
+             logger.warning(f"Skipping sample - extractor couldn't extract QA pair: {sample.keys()}")
+             continue
+         questions.append(qa_pair["formatted_question"])
+         answers.append(qa_pair["correct_answer"])
+
+         if preserve_task_docs:
+             valid_samples.append(sample)
+
+     # Process in batches
+     for i in tqdm(range(0, len(questions), batch_size), desc="Generating benchmark predictions"):
+         batch_questions = questions[i : i + batch_size]
+         batch_answers = answers[i : i + batch_size]
+
+         # Tokenize batch
+         inputs = tokenizer(
+             batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=max_length
+         ).to(device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs, max_new_tokens=max_new_tokens, do_sample=True, pad_token_id=tokenizer.eos_token_id
+             )
+
+         # Extract generated text for each item in batch
+         for j, output in enumerate(outputs):
+             input_length = inputs["input_ids"][j].shape[0]
+             generated = tokenizer.decode(output[input_length:], skip_special_tokens=True)
+             generated = generated.strip()
+             predictions.append(generated)
+
+         # Add ground truths
+         ground_truths.extend(batch_answers)
+
+     # Add task docs if requested
+     if preserve_task_docs:
+         task_docs = valid_samples[: len(predictions)] # Ensure same length as predictions
+
+     return predictions, ground_truths, task_docs
+
+
+ def create_probe_training_data(
+     model,
+     tokenizer,
+     samples: List[Dict],
+     layer: int,
+     batch_size: int,
+     max_length: int,
+     device: torch.device,
+     task_name: str,
+     max_new_tokens: int = 200,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+     """Create training data for probes: activations -> correctness labels using task extractor with batched generation."""
+     texts = []
+     labels = []
+
+     # Get the task and its extractor
+     task = get_task(task_name)
+     extractor = task.get_extractor()
+
+     # Pre-extract all questions and answers for batched generation
+     questions = []
+     correct_answers = []
+
+     for sample in samples:
+         qa_pair = extractor.extract_qa_pair(sample, task)
+         if not qa_pair:
+             continue
+         questions.append(qa_pair["formatted_question"])
+         correct_answers.append(qa_pair["correct_answer"])
+
+     # Generate predictions in batches
+     generated_answers = []
+
+     for i in tqdm(range(0, len(questions), batch_size), desc=f"Generating probe data (layer {layer})"):
+         batch_questions = questions[i : i + batch_size]
+
+         # Tokenize batch
+         inputs = tokenizer(
+             batch_questions, return_tensors="pt", padding=True, truncation=True, max_length=max_length
+         ).to(device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs, max_new_tokens=max_new_tokens, do_sample=True, pad_token_id=tokenizer.eos_token_id
+             )
+
+         # Extract generated text for each item in batch
+         for j, output in enumerate(outputs):
+             input_length = inputs["input_ids"][j].shape[0]
+             generated = tokenizer.decode(output[input_length:], skip_special_tokens=True)
+             generated = generated.strip()
+             generated_answers.append(generated)
+
+     # Now process each question-answer pair for probe training data
+     evaluator = LMEvalHarnessGroundTruth(task_name)
+
+     for question, correct_answer, generated in zip(questions, correct_answers, generated_answers):
+         # Create examples with model's actual prediction
+         correct_text = f"{question} {correct_answer}"
+         incorrect_text = f"{question} {generated}"
+
+         texts.extend([correct_text, incorrect_text])
+
+         # Evaluate if prediction is correct using LMEvalHarnessGroundTruth
+         try:
+             # Create response data format expected by _evaluate_with_lm_eval_metrics
+             response_data = [
+                 {
+                     "generated_response": generated,
+                     "ground_truth": correct_answer,
+                     "question": "evaluation_question", # Required field for evaluation
+                 }
+             ]
+
+             # Use the same evaluation logic as CLI
+             eval_results = evaluator._evaluate_with_lm_eval_metrics(task_name, response_data, None)
+
+             # Extract the result - accuracy > 0 means at least one correct
+             is_correct = eval_results.get("accuracy", 0.0) > 0.0
+
+         except Exception as e:
+             logger.warning(f"LMEvalHarnessGroundTruth failed, using exact match fallback: {e}")
+             is_correct = generated.strip().lower() == correct_answer.strip().lower()
+
+         labels.extend([1, 1 if is_correct else 0])
+
+     # Extract activations
+     activations = extract_activations_with_hook(model, tokenizer, texts, layer, batch_size, max_length, device)
+
+     return activations, np.array(labels)
+
+
+ def load_model_and_tokenizer(model_name: str, device: torch.device):
+     """Load model and tokenizer with proper configuration."""
+     logger.info(f"Loading model {model_name} (ONCE)...")
+     device_kind = device.type
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # Configure tokenizer for decoder-only models
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Set left padding for decoder-only models (required for correct generation)
+     tokenizer.padding_side = "left"
+
+     torch_dtype = preferred_dtype(device_kind)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch_dtype,
+         low_cpu_mem_usage=True,
+     )
+     model.to(device)
+     model.eval()
+
+     # Log memory usage
+     if device_kind == "cuda" and torch.cuda.is_available():
+         memory_gb = torch.cuda.memory_allocated() / 1024**3
+         logger.info(f"✓ Model loaded on {device}, GPU memory: {memory_gb:.2f} GB")
+     elif device_kind == "mps" and hasattr(torch, "mps"):
+         try:
+             memory_gb = torch.mps.current_allocated_memory() / 1024**3
+             logger.info(f"✓ Model loaded on {device}, MPS memory: {memory_gb:.2f} GB")
+         except AttributeError:
+             logger.info(f"✓ Model loaded on {device}")
+
+     return model, tokenizer
+
+
+ def free_model_memory(model, tokenizer):
+     """Free model memory after activation extraction."""
+     logger.info("🧹 Freeing model memory...")
+     device_kind = None
+     if hasattr(model, "parameters"):
+         try:
+             device_kind = next(model.parameters()).device.type
+         except StopIteration:
+             pass
+     del model
+     del tokenizer
+     import gc
+
+     gc.collect()
+     kind_for_cleanup = device_kind or resolve_default_device()
+     empty_device_cache(kind_for_cleanup)
+     if kind_for_cleanup == "cuda" and torch.cuda.is_available():
+         memory_gb = torch.cuda.memory_allocated() / 1024**3
+         logger.info(f"GPU memory after cleanup: {memory_gb:.2f} GB")
+     elif kind_for_cleanup == "mps" and hasattr(torch, "mps"):
+         try:
+             memory_gb = torch.mps.current_allocated_memory() / 1024**3
+             logger.info(f"MPS memory after cleanup: {memory_gb:.2f} GB")
+         except AttributeError:
+             pass
+
+
+ def get_task_contrastive_pairs(samples: List[Dict], task_name: str) -> List[Dict]:
+     """Extract contrastive pairs from samples using the task's extractor."""
+     contrastive_pairs = []
+
+     # Get the task and its extractor
+     task = get_task(task_name)
+     extractor = task.get_extractor()
+
+     for sample in samples:
+         # Use the task's extractor to get contrastive pair
+         pair = extractor.extract_contrastive_pair(sample, task)
+         if pair:
+             contrastive_pairs.append(pair)
+
+     logger.info(f"Extracted {len(contrastive_pairs)} contrastive pairs from {len(samples)} samples")
+     return contrastive_pairs
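
Usage note (not part of the diff): a minimal sketch of how these data utilities appear intended to compose for a small benchmark run. The model name ("gpt2") and task name ("winogrande") are placeholders, not values taken from this release; the task must actually be registered with get_task, and the model must expose a transformer or model attribute if extract_activations_with_hook is used.

# Illustrative only: compose the helpers added above.
import torch

from wisent.core.optuna.steering.data_utils import (
    free_model_memory,
    generate_benchmark_predictions,
    load_dataset_samples,
    load_model_and_tokenizer,
)

# Placeholder device/model/task choices for the sketch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, tokenizer = load_model_and_tokenizer("gpt2", device)

samples = load_dataset_samples("winogrande", limit=8)
predictions, ground_truths, task_docs = generate_benchmark_predictions(
    model,
    tokenizer,
    samples,
    batch_size=4,
    max_length=512,
    device=device,
    task_name="winogrande",
    max_new_tokens=32,
    preserve_task_docs=True,
)

free_model_memory(model, tokenizer)
print(len(predictions), "predictions;", len(task_docs), "task docs preserved")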