wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent has been flagged as potentially problematic; see the registry's release page for details.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,648 @@
1
+ """
2
+ Sample Size Optimizer for finding the optimal training sample size for classifiers.
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ import time
9
+ from datetime import datetime
10
+ from typing import Any, Dict, List, Optional, Tuple
11
+
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
15
+
16
+ from wisent_guard.core.classifier.classifier import Classifier
17
+
18
+ from .activations import ActivationAggregationStrategy
19
+ from .contrastive_pairs import ContrastivePairSet
20
+ from .model import Model
21
+ from .model_config_manager import ModelConfigManager
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class SampleSizeOptimizer:
    """Find the smallest classifier training-set size that is "good enough".

    Workflow (see ``run_optimization``):

    1. Load QA pairs for ``task_name`` (managed benchmark cache first, lm-eval
       harness as a fallback) and turn them into contrastive pairs whose
       activations are collected at ``layer``.
    2. Shuffle once with a fixed seed and hold out ``test_split`` of the pairs.
    3. For each candidate size, train a logistic classifier on that many training
       pairs and score it on the held-out pairs.
    4. Pick the size where accuracy gains flatten out
       (``find_optimal_sample_size``).
    """

    # Samples to load when no (or a zero) dataset limit is given.
    DEFAULT_DATASET_LIMIT = 1000
    # Candidate training-set sizes used when the caller supplies none.
    DEFAULT_SAMPLE_SIZES = (1, 2, 5, 10, 20, 50, 100, 200, 500)
    # Seed for the train/test shuffle so splits are reproducible across runs.
    SPLIT_SEED = 42
    # Absolute accuracy gain below which a larger size counts as diminishing returns.
    MIN_ACCURACY_GAIN = 0.02
    # Accuracy a size must exceed before it can be chosen as the plateau point.
    MIN_ACCEPTABLE_ACCURACY = 0.7
    # Gain from the next size up that is large enough to take that size anyway.
    SIGNIFICANT_ACCURACY_GAIN = 0.05

    def __init__(
        self,
        model_name: str,
        task_name: str = "truthfulqa_mc1",
        layer: int = 0,
        token_aggregation: str = "average",
        threshold: float = 0.5,
        test_split: float = 0.2,
        sample_sizes: Optional[List[int]] = None,
        device: Optional[str] = None,
        verbose: bool = False,
    ):
        """
        Initialize the sample size optimizer and load the model.

        Args:
            model_name: Name of the model to optimize
            task_name: Task to optimize for
            layer: Layer index at which activations are collected
            token_aggregation: Token aggregation method (average, final, first, max, min).
                Stored for reference; activations are collected with CHOICE_TOKEN
                targeting, so no aggregation is applied here.
            threshold: Decision threshold applied to ``classifier.predict`` output
                during evaluation
            test_split: Fraction of pairs held out for testing
            sample_sizes: Training sizes to test (sorted ascending); defaults to
                ``DEFAULT_SAMPLE_SIZES``
            device: Device to use for computation (resolved by ``Model``)
            verbose: Enable verbose output (currently stored only)
        """
        self.model_name = model_name
        self.task_name = task_name
        self.layer = layer
        self.token_aggregation = token_aggregation
        self.threshold = threshold
        self.test_split = test_split
        self.verbose = verbose

        # Candidate sizes are always tested in ascending order.
        if sample_sizes is None:
            self.sample_sizes = list(self.DEFAULT_SAMPLE_SIZES)
        else:
            self.sample_sizes = sorted(sample_sizes)

        # Optional cap on how many samples run_optimization loads; callers such as
        # run_sample_size_optimization assign it after construction.
        self.dataset_limit: Optional[int] = None

        self.model = Model(name=model_name, device=device)
        self.device = self.model.device

        # Filled by run_optimization: one dict per evaluated sample size.
        self.results: List[Dict[str, Any]] = []
        self.optimal_sample_size: Optional[int] = None

        logger.info(f"Initialized SampleSizeOptimizer for {model_name}")
        logger.info(f"Task: {task_name}, Layer: {layer}")
        logger.info(f"Sample sizes to test: {self.sample_sizes}")

    def load_and_split_data(self, limit: Optional[int] = None) -> "Tuple[ContrastivePairSet, ContrastivePairSet]":
        """
        Load task data, collect activations, and split into train/test sets.

        Args:
            limit: Maximum number of samples to load. ``None`` (or 0) falls back to
                ``DEFAULT_DATASET_LIMIT`` (1000); it does not mean "load everything".

        Returns:
            Tuple of (train_set, test_set). Both come from one shuffle seeded with
            ``SPLIT_SEED``, so the training set is already in random order.

        Raises:
            ValueError: If no documents, QA pairs, or pairs with activations could
                be produced for the task.
        """
        logger.info(f"Loading data for task: {self.task_name}")
        max_samples = limit or self.DEFAULT_DATASET_LIMIT

        qa_pairs = self._load_qa_pairs_from_cache(max_samples)
        if not qa_pairs:
            qa_pairs = self._load_qa_pairs_from_lm_eval(max_samples)

        all_pairs = self._build_pairs_with_activations(qa_pairs)
        logger.info(f"Loaded {len(all_pairs)} contrastive pairs")

        train_pairs, test_pairs = self._split_pairs(all_pairs)
        train_set = ContrastivePairSet(name=f"{self.task_name}_train", pairs=train_pairs)
        test_set = ContrastivePairSet(name=f"{self.task_name}_test", pairs=test_pairs)

        logger.info(f"Split data: {len(train_pairs)} train, {len(test_pairs)} test")
        return train_set, test_set

    def _load_qa_pairs_from_cache(self, max_samples: int) -> Optional[List[Dict[str, Any]]]:
        """Best-effort load of QA pairs from the managed benchmark cache.

        Returns ``None`` when the cache yields no samples or raises for any reason
        (the error is logged), so the caller can fall back to the lm-eval harness.
        """
        try:
            from .managed_cached_benchmarks import get_managed_cache

            cache = get_managed_cache()
            logger.info(f"Attempting to load from cache with limit={max_samples}")

            # Load samples from cache (it will download if needed)
            samples = cache.get_task_samples(self.task_name, limit=max_samples)
            if not samples:
                return None

            logger.info(f"Loaded {len(samples)} samples from cache")
            qa_pairs = [self._qa_pair_from_cached_sample(sample) for sample in samples]
            logger.info(f"Converted {len(qa_pairs)} cached samples to QA pairs")
            return qa_pairs
        except Exception as e:
            # Deliberately broad: any cache problem just triggers the lm-eval fallback.
            logger.warning(f"Failed to load from cache: {e}")
            return None

    @staticmethod
    def _qa_pair_from_cached_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one cached sample into a question/correct/incorrect answer dict.

        Samples whose ``normalized`` field carries ``good_response``/``bad_response``
        are mapped directly. Otherwise (e.g. truthfulqa_mc1), the incorrect answer is
        the first choice labelled 0 in ``raw_data["mc1_targets"]``, with a fixed
        placeholder when none is found or it is empty.
        """
        normalized = sample.get("normalized", {})
        if "good_response" in normalized and "bad_response" in normalized:
            return {
                "question": normalized.get("context", normalized.get("question", "")),
                "correct_answer": normalized.get("good_response", ""),
                "incorrect_answer": normalized.get("bad_response", ""),
                "metadata": normalized.get("metadata", {}),
            }

        mc1_targets = sample.get("raw_data", {}).get("mc1_targets", {})
        choices = mc1_targets.get("choices", [])
        labels = mc1_targets.get("labels", [])
        incorrect_answer = next(
            (choices[i] for i, label in enumerate(labels) if label == 0 and i < len(choices)),
            None,
        )
        if not incorrect_answer:
            incorrect_answer = "This is incorrect"

        return {
            "question": normalized.get("question", ""),
            "correct_answer": normalized.get("correct_answer", ""),
            "incorrect_answer": incorrect_answer,
            "metadata": normalized.get("metadata", {}),
        }

    def _load_qa_pairs_from_lm_eval(self, max_samples: int) -> List[Dict[str, Any]]:
        """Load QA pairs via the lm-eval harness (used when the cache yields nothing).

        Raises:
            ValueError: If the task yields no documents or no QA pairs.
        """
        logger.info("Loading from lm-eval harness...")
        task_data = self.model.load_lm_eval_task(self.task_name, shots=0, limit=max_samples)

        # split_ratio=1.0 keeps every document; the train/test split happens later.
        docs, _ = self.model.split_task_data(task_data, split_ratio=1.0)
        if not docs:
            raise ValueError(f"No documents loaded for task {self.task_name}")
        logger.info(f"Loaded {len(docs)} documents from {self.task_name}")

        qa_pairs = ContrastivePairSet.extract_qa_pairs_from_task_docs(self.task_name, task_data, docs)
        if not qa_pairs:
            raise ValueError(f"No QA pairs could be extracted from task {self.task_name}")
        logger.info(f"Extracted {len(qa_pairs)} QA pairs")
        return qa_pairs

    def _build_pairs_with_activations(self, qa_pairs: List[Dict[str, Any]]) -> List[Any]:
        """Build contrastive pairs from QA dicts and attach activations at ``self.layer``.

        Pairs for which the collector produced no positive or negative activations
        are dropped.

        Raises:
            ValueError: If no pairs could be built, or none received activations.
        """
        # NOTE(review): this import path targets the ``wisent_guard`` package rather
        # than ``wisent``; confirm it resolves in the installed distribution.
        from wisent_guard.core.activations.activation_collection_method import ActivationCollectionLogic

        from .contrastive_pairs import ContrastivePair
        from .response import NegativeResponse, PositiveResponse

        all_pairs = [
            ContrastivePair(
                prompt=qa_pair["question"],
                positive_response=PositiveResponse(text=qa_pair["correct_answer"]),
                negative_response=NegativeResponse(text=qa_pair["incorrect_answer"]),
            )
            for qa_pair in qa_pairs
        ]
        if not all_pairs:
            raise ValueError(f"No contrastive pairs created for task {self.task_name}")

        logger.info(f"Extracting activations at layer {self.layer}")
        collector = ActivationCollectionLogic(model=self.model)
        # Multiple-choice style data: target the answer/choice token.
        all_pairs = collector.collect_activations_batch(
            all_pairs,
            layer_index=self.layer,
            device=self.device,
            token_targeting_strategy=ActivationAggregationStrategy.CHOICE_TOKEN,
        )

        with_activations = [
            p for p in all_pairs if p.positive_activations is not None and p.negative_activations is not None
        ]
        if not with_activations:
            raise ValueError(f"No contrastive pairs with activations for task {self.task_name}")
        return with_activations

    def _split_pairs(self, pairs: List[Any]) -> Tuple[List[Any], List[Any]]:
        """Shuffle ``pairs`` with ``SPLIT_SEED`` and split off ``test_split`` of them.

        Returns:
            (train_pairs, test_pairs); the test share is rounded down.
        """
        n_test = int(len(pairs) * self.test_split)
        n_train = len(pairs) - n_test

        # A local RandomState(seed) produces the same permutation as seeding the
        # global NumPy RNG, without clobbering global random state for other code.
        indices = np.random.RandomState(self.SPLIT_SEED).permutation(len(pairs))

        train_pairs = [pairs[i] for i in indices[:n_train]]
        test_pairs = [pairs[i] for i in indices[n_train:]]
        return train_pairs, test_pairs

    def _aggregate_activations(self, activations):
        """
        Apply token aggregation to activations based on configured method.

        Since we're using CHOICE_TOKEN strategy, activations should be a single vector.
        This method is here for consistency with the main CLI approach.

        Args:
            activations: Activation vector or tensor

        Returns:
            The input unchanged (no aggregation is needed for CHOICE_TOKEN).
        """
        return activations

    @staticmethod
    def _pairs_to_xy(pairs: List[Any]) -> Tuple[List[Any], List[int]]:
        """Flatten pairs into (features, labels).

        Each pair contributes its positive activations labelled 0 (correct/truthful)
        followed by its negative activations labelled 1 (incorrect/untruthful).
        """
        features: List[Any] = []
        labels: List[int] = []
        for pair in pairs:
            features.append(pair.positive_activations)
            labels.append(0)
            features.append(pair.negative_activations)
            labels.append(1)
        return features, labels

    def train_classifier_with_sample_size(
        self, train_set: "ContrastivePairSet", sample_size: int
    ) -> "Tuple[Optional[Classifier], float]":
        """
        Train a logistic classifier on the first ``sample_size`` training pairs.

        Args:
            train_set: Full (already shuffled) training set
            sample_size: Number of pairs to use for training

        Returns:
            Tuple of (trained_classifier, training_time_seconds), or ``(None, 0.0)``
            when fewer than 2 pairs are available.
        """
        # The training set was shuffled during the split, so a prefix is a random
        # subsample; slicing past the end simply yields every pair.
        train_pairs = train_set.pairs[:sample_size]

        logger.info(f"Training classifier with {len(train_pairs)} samples")

        if len(train_pairs) < 2:
            logger.warning(f"Not enough training samples ({len(train_pairs)}). Skipping.")
            return None, 0.0

        X_train, y_train = self._pairs_to_xy(train_pairs)
        classifier = Classifier(model_type="logistic", device=self.device)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        classifier.fit(X_train, y_train)
        training_time = time.perf_counter() - start_time

        return classifier, training_time

    def evaluate_classifier(self, classifier: "Classifier", test_set: "ContrastivePairSet") -> Dict[str, float]:
        """
        Score a trained classifier on held-out pairs.

        Each pair contributes two examples (positive -> label 0, negative -> label 1).
        An example is predicted as class 1 when ``classifier.predict(x)`` exceeds
        ``self.threshold``.

        Args:
            classifier: Trained classifier
            test_set: Test set to evaluate on

        Returns:
            Dict with ``accuracy``, ``precision``, ``recall`` and ``f1``.

        Raises:
            ValueError: If the test set is empty (metrics would be undefined).
        """
        X_test, y_test = self._pairs_to_xy(test_set.pairs)
        if not X_test:
            raise ValueError(
                f"Test set for task {self.task_name} is empty; increase test_split or the dataset size"
            )

        y_pred = [1 if classifier.predict(x) > self.threshold else 0 for x in X_test]

        return {
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred, zero_division=0),
            "recall": recall_score(y_test, y_pred, zero_division=0),
            "f1": f1_score(y_test, y_pred, zero_division=0),
        }

    def find_optimal_sample_size(self) -> int:
        """
        Pick the sample size where accuracy gains level off.

        Walking sizes in ascending order, the first size whose gain over the
        previous size is below ``MIN_ACCURACY_GAIN`` while its accuracy exceeds
        ``MIN_ACCEPTABLE_ACCURACY`` is chosen (default: the largest evaluated
        size). If the next size up would still gain more than
        ``SIGNIFICANT_ACCURACY_GAIN``, that next size is taken instead.

        Returns:
            The chosen sample size. With a single result, that result's size; with
            no results, the largest configured size.
        """
        if not self.results:
            return self.sample_sizes[-1]
        if len(self.results) == 1:
            # Only one size was actually evaluated, so it is the only defensible answer.
            return self.results[0]["sample_size"]

        accuracies = [r["metrics"]["accuracy"] for r in self.results]
        sizes = [r["sample_size"] for r in self.results]

        # gains[i] is the accuracy change going from sizes[i] to sizes[i + 1].
        gains = [later - earlier for earlier, later in zip(accuracies, accuracies[1:])]

        optimal_idx = len(sizes) - 1  # Default to largest
        for i, gain in enumerate(gains):
            if gain < self.MIN_ACCURACY_GAIN and accuracies[i + 1] > self.MIN_ACCEPTABLE_ACCURACY:
                optimal_idx = i + 1
                break

        # One step further is worth it only for a significant accuracy jump. (A
        # training-time check is unnecessary: any gain above the threshold already
        # rules out the "time doubled for <1% gain" case.)
        if optimal_idx < len(sizes) - 1 and gains[optimal_idx] > self.SIGNIFICANT_ACCURACY_GAIN:
            optimal_idx += 1

        return sizes[optimal_idx]

    def _build_summary(self) -> Dict[str, Any]:
        """Assemble the run summary returned by run_optimization and saved by save_results."""
        return {
            "model": self.model_name,
            "task": self.task_name,
            "layer": self.layer,
            "test_split": self.test_split,
            "results": self.results,
            "optimal_sample_size": self.optimal_sample_size,
            "timestamp": datetime.now().isoformat(),
        }

    def run_optimization(self) -> Dict[str, Any]:
        """
        Run the complete sample size optimization process.

        Results from any previous call are discarded first. Sizes larger than the
        training set are skipped. If no size could be evaluated,
        ``optimal_sample_size`` stays ``None``.

        Returns:
            Dictionary containing per-size results and the optimal sample size

        Raises:
            ValueError: If data loading fails or no sample size fits the training set.
        """
        logger.info("Starting sample size optimization...")

        # Reset state so repeated calls do not mix results from different runs.
        self.results = []
        self.optimal_sample_size = None

        dataset_limit = getattr(self, "dataset_limit", None)
        train_set, test_set = self.load_and_split_data(limit=dataset_limit)

        max_train_size = len(train_set.pairs)
        valid_sample_sizes = [s for s in self.sample_sizes if s <= max_train_size]
        if not valid_sample_sizes:
            raise ValueError(f"No valid sample sizes. Training set has only {max_train_size} samples.")

        logger.info(f"Testing sample sizes: {valid_sample_sizes}")

        for sample_size in valid_sample_sizes:
            logger.info(f"\n{'=' * 50}")
            logger.info(f"Testing sample size: {sample_size}")

            classifier, training_time = self.train_classifier_with_sample_size(train_set, sample_size)
            if classifier is None:
                logger.warning(f"Skipping sample size {sample_size} - not enough samples for training")
                continue

            metrics = self.evaluate_classifier(classifier, test_set)
            self.results.append({"sample_size": sample_size, "training_time": training_time, "metrics": metrics})

            logger.info(f"Accuracy: {metrics['accuracy']:.3f}")
            logger.info(f"F1 Score: {metrics['f1']:.3f}")
            logger.info(f"Training time: {training_time:.3f}s")

        if self.results:
            self.optimal_sample_size = self.find_optimal_sample_size()
        else:
            # Don't report a size that was never measured as "optimal".
            logger.warning("No sample size could be evaluated; optimal sample size is undetermined")

        logger.info(f"\n{'=' * 50}")
        logger.info(f"Optimal sample size: {self.optimal_sample_size}")

        return self._build_summary()

    def save_results(self, output_dir: Optional[str] = None) -> str:
        """
        Save the optimization summary to a timestamped JSON file.

        Args:
            output_dir: Directory to save results; defaults to
                ``./sample_size_optimization_results`` and is created if missing.

        Returns:
            Path to saved results file
        """
        if output_dir is None:
            output_dir = "./sample_size_optimization_results"

        os.makedirs(output_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Names like "org/model" must stay within a single path component.
        model_safe = self.model_name.replace("/", "_")
        task_safe = self.task_name.replace("/", "_")
        filename = f"sample_size_{model_safe}_{task_safe}_layer{self.layer}_{timestamp}.json"
        filepath = os.path.join(output_dir, filename)

        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(self._build_summary(), f, indent=2)

        logger.info(f"Results saved to: {filepath}")
        return filepath

    def plot_results(self, save_path: Optional[str] = None, show: bool = True) -> None:
        """
        Plot score and training time against sample size.

        Args:
            save_path: Path to save plot (optional)
            show: Whether to display the plot
        """
        if not self.results:
            logger.warning("No results to plot")
            return

        sizes = [r["sample_size"] for r in self.results]
        accuracies = [r["metrics"]["accuracy"] for r in self.results]
        f1_scores = [r["metrics"]["f1"] for r in self.results]
        times = [r["training_time"] for r in self.results]

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))

        # Plot 1: Accuracy and F1 vs Sample Size
        ax1.plot(sizes, accuracies, "b-o", label="Accuracy", linewidth=2, markersize=8)
        ax1.plot(sizes, f1_scores, "g--s", label="F1 Score", linewidth=2, markersize=8)
        if self.optimal_sample_size is not None:
            ax1.axvline(
                self.optimal_sample_size, color="r", linestyle=":", label=f"Optimal: {self.optimal_sample_size}"
            )
        ax1.set_xlabel("Sample Size")
        ax1.set_ylabel("Score")
        ax1.set_title(
            f"Classifier Performance vs Sample Size\n{self.model_name} - {self.task_name} - Layer {self.layer}"
        )
        ax1.legend()

        # Plot 2: Training Time vs Sample Size
        ax2.plot(sizes, times, "r-^", linewidth=2, markersize=8)
        ax2.set_xlabel("Sample Size")
        ax2.set_ylabel("Training Time (seconds)")
        ax2.set_title("Training Time vs Sample Size")

        # Both axes: light grid and one tick per tested size on a linear x-axis.
        for ax in (ax1, ax2):
            ax.grid(True, alpha=0.3)
            ax.set_xticks(sizes)
            ax.set_xticklabels([str(s) for s in sizes])

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info(f"Plot saved to: {save_path}")

        if show:
            plt.show()

        # Close this specific figure to release its memory.
        plt.close(fig)
564
+
565
+
566
def run_sample_size_optimization(
    model_name: str,
    task_name: str = "truthfulqa_mc1",
    layer: int = 0,
    token_aggregation: str = "average",
    threshold: float = 0.5,
    test_split: float = 0.2,
    sample_sizes: Optional[List[int]] = None,
    dataset_limit: Optional[int] = None,
    device: Optional[str] = None,
    verbose: bool = False,
    save_plot: bool = True,
    save_to_config: bool = True,
) -> Dict[str, Any]:
    """
    Run sample size optimization, save results, and optionally update the model config.

    The JSON results are always written via ``SampleSizeOptimizer.save_results``;
    the plot (if requested) is saved next to it with a ``.png`` extension. The
    model config is only updated when one already exists for ``model_name``.

    Args:
        model_name: Name of the model
        task_name: Task to optimize for
        layer: Layer index
        token_aggregation: Token aggregation method
        threshold: Detection threshold
        test_split: Test split ratio
        sample_sizes: Sample sizes to test
        dataset_limit: Maximum number of samples to load from dataset
        device: Computation device
        verbose: Verbose output
        save_plot: Whether to save the plot
        save_to_config: Whether to save the optimal size to the model config under
            ``optimal_sample_sizes[task_name][str(layer)]``

    Returns:
        Optimization results dictionary (as returned by ``run_optimization``)
    """
    optimizer = SampleSizeOptimizer(
        model_name=model_name,
        task_name=task_name,
        layer=layer,
        token_aggregation=token_aggregation,
        threshold=threshold,
        test_split=test_split,
        sample_sizes=sample_sizes,
        device=device,
        verbose=verbose,
    )

    # run_optimization reads the limit from this attribute.
    optimizer.dataset_limit = dataset_limit
    results = optimizer.run_optimization()

    results_path = optimizer.save_results()

    if save_plot:
        # Swap only the extension; str.replace would also rewrite ".json" appearing
        # elsewhere in the path (e.g. in a directory name).
        plot_path = os.path.splitext(results_path)[0] + ".png"
        optimizer.plot_results(save_path=plot_path, show=False)

    if save_to_config and optimizer.optimal_sample_size:
        config_manager = ModelConfigManager()
        existing_config = config_manager.load_model_config(model_name)

        if existing_config:
            # Create the nested {task: {layer: size}} mapping on demand.
            per_task = existing_config.setdefault("optimal_sample_sizes", {}).setdefault(task_name, {})
            per_task[str(layer)] = optimizer.optimal_sample_size

            config_manager.update_model_config(model_name, existing_config)
            logger.info(f"Updated model config with optimal sample size: {optimizer.optimal_sample_size}")
        else:
            logger.warning("No existing model config found. Run optimize-classification first.")

    return results