wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/core/sample_size_optimizer_v2.py
@@ -0,0 +1,355 @@
+"""
+Simplified Sample Size Optimizer using training-limit and testing-limit flags.
+Supports both classification and steering methods.
+"""
+
+import os
+import json
+import time
+import logging
+from typing import Dict, List, Optional, Any, Tuple
+from datetime import datetime
+import numpy as np
+import matplotlib.pyplot as plt
+
+from ..cli import run_task_pipeline
+from .model_config_manager import ModelConfigManager
+
+logger = logging.getLogger(__name__)
+
+
+class SimplifiedSampleSizeOptimizer:
+    """Simplified sample size optimizer that leverages CLI training/testing limits."""
+
+    def __init__(
+        self,
+        model_name: str,
+        task_name: str,
+        layer: int,
+        method_type: str = "classification",  # "classification" or "steering"
+        sample_sizes: Optional[List[int]] = None,
+        test_size: int = 200,
+        seed: int = 42,
+        verbose: bool = False,
+        **method_kwargs
+    ):
+        """
+        Initialize the optimizer.
+
+        Args:
+            model_name: Model to optimize
+            task_name: Task to optimize for
+            layer: Layer to use
+            method_type: "classification" or "steering"
+            sample_sizes: List of training sample sizes to test
+            test_size: Fixed test set size
+            seed: Random seed for reproducibility
+            verbose: Verbose output
+            **method_kwargs: Additional arguments for the method
+                For classification: token_aggregation, threshold, classifier_type
+                For steering: steering_method, steering_strength, token_targeting_strategy
+        """
+        self.model_name = model_name
+        self.task_name = task_name
+        self.layer = layer
+        self.method_type = method_type
+        self.sample_sizes = sample_sizes or [5, 10, 20, 50, 100, 200, 500]
+        self.test_size = test_size
+        self.seed = seed
+        self.verbose = verbose
+        self.method_kwargs = method_kwargs
+
+        # Results storage
+        self.results = {
+            "sample_sizes": [],
+            "accuracies": [],
+            "f1_scores": [],
+            "training_times": [],
+            "evaluation_times": []
+        }
+
+    def run_single_experiment(self, training_size: int) -> Dict[str, Any]:
+        """
+        Run a single experiment with a specific training size.
+
+        Args:
+            training_size: Number of training samples
+
+        Returns:
+            Dictionary with results
+        """
+        if self.verbose:
+            print(f"\n{'='*60}")
+            print(f"Testing {self.method_type} with {training_size} training samples")
+            print(f"{'='*60}")
+
+        start_time = time.time()
+
+        # Build arguments for run_task_pipeline
+        pipeline_args = {
+            "task_name": self.task_name,
+            "model_name": self.model_name,
+            "layer": str(self.layer),
+            "training_limit": training_size,
+            "testing_limit": self.test_size,
+            "seed": self.seed,
+            "verbose": self.verbose,
+            "split_ratio": 0.8,  # Standard split
+            "limit": training_size + self.test_size + 100,  # Ensure enough data
+        }
+
+        # Add method-specific arguments
+        if self.method_type == "classification":
+            pipeline_args.update({
+                "token_aggregation": self.method_kwargs.get("token_aggregation", "average"),
+                "detection_threshold": self.method_kwargs.get("threshold", 0.5),
+                "classifier_type": self.method_kwargs.get("classifier_type", "logistic"),
+                "steering_mode": False
+            })
+        else:  # steering
+            pipeline_args.update({
+                "steering_mode": True,
+                "steering_method": self.method_kwargs.get("steering_method", "CAA"),
+                "steering_strength": self.method_kwargs.get("steering_strength", 1.0),
+                "token_targeting_strategy": self.method_kwargs.get("token_targeting_strategy", "LAST_TOKEN"),
+                "token_aggregation": self.method_kwargs.get("token_aggregation", "average"),
+            })
+
+        try:
+            # Run the pipeline
+            result = run_task_pipeline(**pipeline_args)
+
+            end_time = time.time()
+            total_time = end_time - start_time
+
+            # Extract metrics based on method type
+            if self.method_type == "classification":
+                accuracy = result.get("test_accuracy", 0.0)
+                f1_score = result.get("test_f1_score", 0.0)
+            else:  # steering
+                # For steering, we look at the evaluation results
+                eval_results = result.get("evaluation_results", {})
+                accuracy = eval_results.get("accuracy", 0.0)
+                # Convert to float if it's a string percentage
+                if isinstance(accuracy, str) and accuracy.endswith('%'):
+                    accuracy = float(accuracy.rstrip('%')) / 100.0
+                f1_score = accuracy  # Use accuracy as proxy for F1 in steering
+
+            return {
+                "accuracy": accuracy,
+                "f1_score": f1_score,
+                "training_time": result.get("training_time", total_time * 0.8),
+                "evaluation_time": total_time * 0.2,
+                "total_time": total_time,
+                "success": True
+            }
+
+        except Exception as e:
+            logger.error(f"Failed to run experiment with {training_size} samples: {e}")
+            return {
+                "accuracy": 0.0,
+                "f1_score": 0.0,
+                "training_time": 0.0,
+                "evaluation_time": 0.0,
+                "total_time": 0.0,
+                "success": False,
+                "error": str(e)
+            }
+
+    def run_optimization(self) -> Dict[str, Any]:
+        """
+        Run the complete optimization process.
+
+        Returns:
+            Dictionary with optimization results
+        """
+        logger.info(f"Starting {self.method_type} sample size optimization...")
+        logger.info(f"Model: {self.model_name}, Task: {self.task_name}, Layer: {self.layer}")
+        logger.info(f"Testing sample sizes: {self.sample_sizes}")
+        logger.info(f"Fixed test size: {self.test_size}")
+
+        # Run experiments for each sample size
+        for sample_size in self.sample_sizes:
+            result = self.run_single_experiment(sample_size)
+
+            if result["success"]:
+                self.results["sample_sizes"].append(sample_size)
+                self.results["accuracies"].append(result["accuracy"])
+                self.results["f1_scores"].append(result["f1_score"])
+                self.results["training_times"].append(result["training_time"])
+                self.results["evaluation_times"].append(result["evaluation_time"])
+
+                if self.verbose:
+                    print(f"\n✓ Tested {sample_size} samples: accuracy={result['accuracy']:.3f}, f1={result['f1_score']:.3f}")
+
+        # Find optimal sample size
+        optimal_idx, optimal_size = self.find_optimal_sample_size()
+
+        return {
+            "optimal_sample_size": optimal_size,
+            "optimal_accuracy": self.results["accuracies"][optimal_idx] if optimal_idx >= 0 else None,
+            "optimal_f1_score": self.results["f1_scores"][optimal_idx] if optimal_idx >= 0 else None,
+            "all_results": self.results,
+            "method_type": self.method_type,
+            "method_kwargs": self.method_kwargs
+        }
+
+    def find_optimal_sample_size(self) -> Tuple[int, int]:
+        """
+        Find the optimal sample size based on accuracy and efficiency.
+
+        Returns:
+            Tuple of (optimal_index, optimal_sample_size)
+        """
+        if not self.results["accuracies"]:
+            return -1, 0
+
+        accuracies = np.array(self.results["accuracies"])
+        sample_sizes = np.array(self.results["sample_sizes"])
+
+        # Find the point of diminishing returns
+        # We want the smallest sample size that achieves near-optimal accuracy
+        max_accuracy = np.max(accuracies)
+        threshold = max_accuracy * 0.95  # Within 95% of best accuracy
+
+        # Find indices where accuracy is above threshold
+        good_indices = np.where(accuracies >= threshold)[0]
+
+        if len(good_indices) > 0:
+            # Choose the smallest sample size among good ones
+            optimal_idx = good_indices[0]
+        else:
+            # If no good indices, choose the best accuracy
+            optimal_idx = np.argmax(accuracies)
+
+        return optimal_idx, sample_sizes[optimal_idx]
+
+    def plot_results(self, save_path: Optional[str] = None) -> None:
+        """
+        Plot the optimization results.
+
+        Args:
+            save_path: Path to save the plot
+        """
+        if not self.results["sample_sizes"]:
+            logger.warning("No results to plot")
+            return
+
+        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
+
+        # Plot accuracy and F1 score
+        ax1.plot(self.results["sample_sizes"], self.results["accuracies"],
+                 'b-o', label='Accuracy', markersize=8)
+        ax1.plot(self.results["sample_sizes"], self.results["f1_scores"],
+                 'r--s', label='F1 Score', markersize=8)
+
+        # Mark optimal point
+        optimal_idx, optimal_size = self.find_optimal_sample_size()
+        if optimal_idx >= 0:
+            ax1.axvline(x=optimal_size, color='g', linestyle=':', alpha=0.7,
+                        label=f'Optimal: {optimal_size}')
+
+        ax1.set_xlabel('Training Sample Size')
+        ax1.set_ylabel('Score')
+        ax1.set_title(f'{self.method_type.capitalize()} Performance vs Sample Size')
+        ax1.legend()
+        ax1.grid(True, alpha=0.3)
+        ax1.set_xscale('log')
+
+        # Plot training time
+        ax2.plot(self.results["sample_sizes"], self.results["training_times"],
+                 'g-^', label='Training Time', markersize=8)
+        ax2.set_xlabel('Training Sample Size')
+        ax2.set_ylabel('Time (seconds)')
+        ax2.set_title('Training Time vs Sample Size')
+        ax2.legend()
+        ax2.grid(True, alpha=0.3)
+        ax2.set_xscale('log')
+
+        plt.suptitle(f'Sample Size Optimization: {self.model_name} on {self.task_name}')
+        plt.tight_layout()
+
+        if save_path:
+            plt.savefig(save_path, dpi=150, bbox_inches='tight')
+            logger.info(f"Plot saved to {save_path}")
+        else:
+            plt.show()
+
+        plt.close()
+
+
+def optimize_sample_size(
+    model_name: str,
+    task_name: str,
+    layer: int,
+    method_type: str = "classification",
+    sample_sizes: Optional[List[int]] = None,
+    test_size: int = 200,
+    seed: int = 42,
+    verbose: bool = False,
+    save_plot: bool = False,
+    save_to_config: bool = True,
+    **method_kwargs
+) -> Dict[str, Any]:
+    """
+    Convenience function to run sample size optimization.
+
+    Args:
+        model_name: Model to optimize
+        task_name: Task to optimize for
+        layer: Layer to use
+        method_type: "classification" or "steering"
+        sample_sizes: Sample sizes to test
+        test_size: Fixed test set size
+        seed: Random seed
+        verbose: Verbose output
+        save_plot: Whether to save the plot
+        save_to_config: Whether to save results to model config
+        **method_kwargs: Method-specific arguments
+
+    Returns:
+        Optimization results
+    """
+    optimizer = SimplifiedSampleSizeOptimizer(
+        model_name=model_name,
+        task_name=task_name,
+        layer=layer,
+        method_type=method_type,
+        sample_sizes=sample_sizes,
+        test_size=test_size,
+        seed=seed,
+        verbose=verbose,
+        **method_kwargs
+    )
+
+    results = optimizer.run_optimization()
+
+    # Create plot if requested
+    if save_plot:
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        plot_dir = f"sample_size_optimization/{model_name}"
+        os.makedirs(plot_dir, exist_ok=True)
+        plot_path = os.path.join(
+            plot_dir,
+            f"{task_name}_{method_type}_layer{layer}_{timestamp}.png"
+        )
+        optimizer.plot_results(plot_path)
+
+    # Save to config if requested
+    if save_to_config and results["optimal_sample_size"] > 0:
+        try:
+            config_manager = ModelConfigManager()
+
+            # For now, just log the optimal sample size
+            # TODO: Implement save_optimal_sample_size in ModelConfigManager
+            logger.info(
+                f"Optimal {method_type} sample size for {model_name} on {task_name}: "
+                f"{results['optimal_sample_size']} (accuracy: {results.get('optimal_accuracy', 'N/A')})"
+            )
+
+            if verbose:
+                print(f"\nšŸ’” Note: To use this optimal sample size, add --limit {results['optimal_sample_size']} to your commands")
+        except Exception as e:
+            logger.warning(f"Could not save to config: {e}")
+
+    return results
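For orientation, a minimal usage sketch of the optimizer module above (not part of the diff: the import path follows the file-list entry wisent/core/sample_size_optimizer_v2.py, and the model name, task name, and layer are placeholders):

    from wisent.core.sample_size_optimizer_v2 import optimize_sample_size

    # Placeholder model/task; layer and sample sizes are illustrative only.
    results = optimize_sample_size(
        model_name="meta-llama/Llama-3.1-8B-Instruct",
        task_name="winogrande",
        layer=15,
        method_type="classification",
        sample_sizes=[10, 50, 200],
        test_size=200,
        save_plot=True,
        classifier_type="logistic",  # forwarded through **method_kwargs
    )
    print(results["optimal_sample_size"], results["optimal_accuracy"])

Note the selection rule in find_optimal_sample_size: it returns the smallest training size whose accuracy is within 95% of the best observed, so the reported optimum favors cheaper runs over marginal accuracy gains.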
wisent/core/save_results.py
@@ -0,0 +1,277 @@
+"""
+Functions for saving wisent-guard evaluation results in various formats.
+"""
+
+import os
+import json
+import csv
+import logging
+from typing import Dict, Any
+
+logger = logging.getLogger(__name__)
+
+
+def save_results_json(results: Dict[str, Any], output_path: str) -> None:
+    """Save results to JSON file."""
+    try:
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+        with open(output_path, 'w') as f:
+            json.dump(results, f, indent=2, default=str)
+
+        logger.info(f"Results saved to {output_path}")
+
+    except Exception as e:
+        logger.error(f"Failed to save results to {output_path}: {e}")
+
+
+def save_results_csv(results: Dict[str, Any], output_path: str) -> None:
+    """Save results to CSV file."""
+    try:
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+        # Flatten results for CSV
+        rows = []
+        for task_name, task_results in results.items():
+            if isinstance(task_results, dict):
+                row = {"task": task_name}
+                row.update(task_results)
+                rows.append(row)
+
+        if rows:
+            with open(output_path, 'w', newline='') as f:
+                writer = csv.DictWriter(f, fieldnames=rows[0].keys())
+                writer.writeheader()
+                writer.writerows(rows)
+
+        logger.info(f"CSV results saved to {output_path}")
+
+    except Exception as e:
+        logger.error(f"Failed to save CSV to {output_path}: {e}")
+
+
+def save_classification_results_csv(results: Dict[str, Any], output_path: str) -> None:
+    """
+    Save detailed classification results to CSV file for manual evaluation.
+
+    Exports one row per response with:
+    - For single-layer: question, response, token_scores, overall_prediction, ground_truth
+    - For multi-layer: question, response, token_scores_layer_X (for each layer),
+      aggregated_score_layer_X (for each layer), overall_prediction_layer_X (for each layer), ground_truth
+    """
+    try:
+        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+        csv_rows = []
+        all_layers = set()  # Track all layers for multi-layer mode
+        is_multi_layer = False
+
+        # First pass: determine if we have multi-layer data and collect all layers
+        for task_name, task_results in results.items():
+            if not isinstance(task_results, dict) or 'sample_responses' not in task_results:
+                continue
+
+            # Skip steering mode results (they don't have classification data)
+            if task_results.get('steering_mode', False):
+                continue
+
+            sample_responses = task_results['sample_responses']
+
+            for response_data in sample_responses:
+                layer_results = response_data.get('layer_results', {})
+                if layer_results:
+                    is_multi_layer = True
+                    all_layers.update(layer_results.keys())
+
+        # Sort layers for consistent column ordering
+        sorted_layers = sorted(all_layers) if all_layers else []
+
+        # Second pass: create CSV rows
+        for task_name, task_results in results.items():
+            if not isinstance(task_results, dict) or 'sample_responses' not in task_results:
+                continue
+
+            # Skip steering mode results (they don't have classification data)
+            if task_results.get('steering_mode', False):
+                continue
+
+            sample_responses = task_results['sample_responses']
+
+            for response_data in sample_responses:
+                layer_results = response_data.get('layer_results', {})
+
+                # Create base row
+                csv_row = {
+                    'question': response_data.get('question', ''),
+                    'response': response_data.get('response', ''),
+                    'ground_truth': ''  # Empty for user to fill
+                }
+
+                if is_multi_layer and layer_results:
+                    # Multi-layer mode: create columns for each layer
+                    for layer in sorted_layers:
+                        layer_data = layer_results.get(layer, {})
+
+                        # Format token scores as pipe-separated values
+                        token_scores_str = ""
+                        if layer_data.get('token_scores'):
+                            token_scores_formatted = [f"{score:.6f}" for score in layer_data['token_scores']]
+                            token_scores_str = "|".join(token_scores_formatted)
+
+                        # Add layer-specific columns
+                        csv_row[f'token_scores_layer_{layer}'] = token_scores_str
+                        csv_row[f'aggregated_score_layer_{layer}'] = f"{layer_data.get('aggregated_score', 0.0):.6f}"
+                        csv_row[f'overall_prediction_layer_{layer}'] = layer_data.get('classification', 'UNKNOWN')
+
+                elif not is_multi_layer:
+                    # Single-layer mode: use original format
+                    token_scores_str = ""
+                    if response_data.get('token_scores'):
+                        token_scores_formatted = [f"{score:.6f}" for score in response_data['token_scores']]
+                        token_scores_str = "|".join(token_scores_formatted)
+
+                    csv_row['token_scores'] = token_scores_str
+                    csv_row['aggregated_score'] = f"{response_data.get('aggregated_score', 0.0):.6f}"
+                    csv_row['overall_prediction'] = response_data.get('classification', 'UNKNOWN')
+
+                csv_rows.append(csv_row)
+
+        # Only save if we have classification data
+        if csv_rows:
+            # Determine fieldnames based on mode
+            if is_multi_layer:
+                fieldnames = ['question', 'response']
+                for layer in sorted_layers:
+                    fieldnames.extend([
+                        f'token_scores_layer_{layer}',
+                        f'aggregated_score_layer_{layer}',
+                        f'overall_prediction_layer_{layer}'
+                    ])
+                fieldnames.append('ground_truth')
+            else:
+                fieldnames = ['question', 'response', 'token_scores', 'aggregated_score', 'overall_prediction', 'ground_truth']
+
+            with open(output_path, 'w', newline='', encoding='utf-8') as f:
+                writer = csv.DictWriter(f, fieldnames=fieldnames)
+                writer.writeheader()
+                writer.writerows(csv_rows)
+
+            logger.info(f"Classification results CSV saved to {output_path}")
+            print(f"\nšŸ“Š Classification results saved to: {output_path}")
+            print(f" • {len(csv_rows)} responses exported")
+            if is_multi_layer:
+                print(f" • Multi-layer format with columns for layers: {sorted_layers}")
+                print(f" • Token scores, aggregated scores, and predictions saved per layer")
+            else:
+                print(f" • Single-layer format")
+            print(f" • Fill in the 'ground_truth' column with: 'TRUTHFUL' or 'HALLUCINATION'")
+            print(f" • Use for manual evaluation and classifier optimization")
+        else:
+            logger.info("No classification results to export (steering mode or empty results)")

+    except Exception as e:
+        logger.error(f"Failed to save classification CSV to {output_path}: {e}")
+
+
+def create_evaluation_report(results: Dict[str, Any], output_path: str) -> None:
+    """Create a comprehensive evaluation report in markdown format."""
+    try:
+        with open(output_path, 'w') as f:
+            f.write("# Wisent-Guard Evaluation Report\n\n")
+
+            # Summary table
+            f.write("## Summary\n\n")
+            f.write("| Task | Training Accuracy | Evaluation Accuracy | Optimization |\n")
+            f.write("|------|------------------|--------------------|--------------|\n")
+
+            for task_name, task_results in results.items():
+                if task_results is None:
+                    f.write(f"| {task_name} | NULL | NULL | N/A |\n")
+                elif isinstance(task_results, dict) and "error" in task_results:
+                    f.write(f"| {task_name} | ERROR | ERROR | N/A |\n")
+                elif isinstance(task_results, dict):
+                    train_acc = task_results.get("training_results", {}).get("accuracy", "N/A")
+                    eval_acc = task_results.get("evaluation_results", {}).get("accuracy", "N/A")
+                    optimized = "Yes" if task_results.get("optimization_performed", False) else "No"
+
+                    if isinstance(train_acc, float):
+                        train_acc = f"{train_acc:.2%}"
+                    if isinstance(eval_acc, float):
+                        eval_acc = f"{eval_acc:.2%}"
+
+                    f.write(f"| {task_name} | {train_acc} | {eval_acc} | {optimized} |\n")
+
+            # Detailed results for each task
+            for task_name, task_results in results.items():
+                f.write(f"\n## {task_name}\n\n")
+
+                if task_results is None:
+                    f.write(f"**Error**: Task results are None\n")
+                elif isinstance(task_results, dict) and "error" in task_results:
+                    f.write(f"**Error**: {task_results['error']}\n")
+                elif isinstance(task_results, dict):
+                    # Configuration
+                    f.write("### Configuration\n")
+                    f.write(f"- **Model**: {task_results.get('model_name', 'Unknown')}\n")
+                    f.write(f"- **Layer**: {task_results.get('layer', 'Unknown')}\n")
+                    f.write(f"- **Classifier**: {task_results.get('classifier_type', 'Unknown')}\n")
+                    f.write(f"- **Token Aggregation**: {task_results.get('token_aggregation', 'Unknown')}\n")
+                    f.write(f"- **Ground Truth Method**: {task_results.get('ground_truth_method', 'Unknown')}\n")
+
+                    # Training results
+                    if "training_results" in task_results:
+                        train_results = task_results["training_results"]
+                        f.write("\n### Training Results\n")
+                        train_acc = train_results.get('accuracy', 'N/A')
+                        if isinstance(train_acc, float):
+                            f.write(f"- **Accuracy**: {train_acc:.2%}\n")
+                        else:
+                            f.write(f"- **Accuracy**: {train_acc}\n")
+
+                        train_prec = train_results.get('precision', 'N/A')
+                        if isinstance(train_prec, float):
+                            f.write(f"- **Precision**: {train_prec:.2f}\n")
+                        else:
+                            f.write(f"- **Precision**: {train_prec}\n")
+
+                        train_recall = train_results.get('recall', 'N/A')
+                        if isinstance(train_recall, float):
+                            f.write(f"- **Recall**: {train_recall:.2f}\n")
+                        else:
+                            f.write(f"- **Recall**: {train_recall}\n")
+
+                        train_f1 = train_results.get('f1', 'N/A')
+                        if isinstance(train_f1, float):
+                            f.write(f"- **F1 Score**: {train_f1:.2f}\n")
+                        else:
+                            f.write(f"- **F1 Score**: {train_f1}\n")
+
+                    # Evaluation results
+                    if "evaluation_results" in task_results:
+                        eval_results = task_results["evaluation_results"]
+                        f.write("\n### Evaluation Results\n")
+                        eval_acc = eval_results.get('accuracy', 'N/A')
+                        if isinstance(eval_acc, float):
+                            f.write(f"- **Accuracy**: {eval_acc:.2%}\n")
+                        else:
+                            f.write(f"- **Accuracy**: {eval_acc}\n")
+                        f.write(f"- **Total Predictions**: {eval_results.get('total_predictions', 'N/A')}\n")
+                        f.write(f"- **Correct Predictions**: {eval_results.get('correct_predictions', 'N/A')}\n")
+
+                    # Optimization results
+                    if task_results.get("optimization_performed", False):
+                        f.write("\n### Optimization Results\n")
+                        f.write(f"- **Best Layer**: {task_results.get('best_layer', 'Unknown')}\n")
+                        f.write(f"- **Best Aggregation**: {task_results.get('best_aggregation', 'Unknown')}\n")
+                        best_acc = task_results.get('best_accuracy', 'Unknown')
+                        if isinstance(best_acc, float):
+                            f.write(f"- **Best Accuracy**: {best_acc:.2%}\n")
+                        else:
+                            f.write(f"- **Best Accuracy**: {best_acc}\n")
+
+            f.write(f"\n---\n\n*Report generated on {__import__('datetime').datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*\n")
+
+        logger.info(f"Evaluation report saved to {output_path}")
+
+    except Exception as e:
+        logger.error(f"Failed to create report at {output_path}: {e}")