wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/mixed_benchmark_sampler.py
@@ -0,0 +1,364 @@
+ """
+ Mixed Benchmark Sampler for tag-based random sampling across multiple benchmarks.
+
+ This module enables training and evaluation on random samples from multiple benchmarks
+ that share common tags (e.g., 'coding', 'reasoning', 'math').
+ """
+
+ import random
+ import logging
+ from typing import List, Dict, Any, Optional, Set, Tuple
+ from dataclasses import dataclass
+ from collections import defaultdict
+
+ # Suppress BigCode debug output
+ import builtins
+ _original_print = getattr(builtins, '_original_print', builtins.print)
+
+ def _quiet_print(*args, **kwargs):
+     """Filter out BigCode debug messages."""
+     message = ' '.join(str(arg) for arg in args)
+     if any(x in message for x in ['DEBUG', 'Available tasks:', 'ERROR extracting', 'bigcode_eval']):
+         return
+     _original_print(*args, **kwargs)
+
+ # Store original print and patch
+ builtins._original_print = builtins.print
+ builtins.print = _quiet_print
+
+ try:
+     from .lm_harness_integration.only_benchmarks import CORE_BENCHMARKS
+ except ImportError:
+     # Try alternative import path
+     import sys
+     import os
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     sys.path.insert(0, os.path.join(current_dir, "lm-harness-integration"))
+     from only_benchmarks import CORE_BENCHMARKS
+
+ from .contrastive_pairs import ContrastivePairSet
+ from .managed_cached_benchmarks import ManagedCachedBenchmarks, get_managed_cache
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class BenchmarkSample:
+     """A single sample from a benchmark."""
+     benchmark_name: str
+     sample_data: Dict[str, Any]
+     tags: List[str]
+
+
+ class MixedBenchmarkSampler:
+     """
+     Samples randomly from multiple benchmarks based on tags.
+
+     This creates more robust classifiers by training on diverse data
+     from multiple sources rather than a single benchmark.
+     """
+
+     def __init__(self, cache_dir: str = "./benchmark_cache"):
+         """
+         Initialize the mixed benchmark sampler.
+
+         Args:
+             cache_dir: Directory for cached benchmark data
+         """
+         self.cache_dir = cache_dir
+         self.managed_cache = get_managed_cache(cache_dir)
+         self._benchmark_registry = self._build_benchmark_registry()
+
+     def _build_benchmark_registry(self) -> Dict[str, List[str]]:
+         """Build a registry mapping tags to benchmark names."""
+         tag_to_benchmarks = defaultdict(list)
+
+         for benchmark_name, config in CORE_BENCHMARKS.items():
+             tags = config.get("tags", [])
+             for tag in tags:
+                 tag_to_benchmarks[tag].append(benchmark_name)
+
+         return dict(tag_to_benchmarks)
+
+     def get_benchmarks_by_tag(self, tag: str) -> List[str]:
+         """Get all benchmarks that have a specific tag."""
+         return self._benchmark_registry.get(tag, [])
+
+     def get_benchmarks_by_tags(self, tags: List[str], mode: str = "any") -> List[str]:
+         """
+         Get benchmarks that match the given tags.
+
+         Args:
+             tags: List of tags to match
+             mode: "any" (benchmark has at least one tag) or "all" (benchmark has all tags)
+
+         Returns:
+             List of benchmark names matching the criteria
+         """
+         if mode == "any":
+             # Get benchmarks that have ANY of the specified tags
+             matching_benchmarks = set()
+             for tag in tags:
+                 matching_benchmarks.update(self.get_benchmarks_by_tag(tag))
+             return list(matching_benchmarks)
+
+         elif mode == "all":
+             # Get benchmarks that have ALL of the specified tags
+             if not tags:
+                 return []
+
+             # Start with benchmarks that have the first tag
+             matching_benchmarks = set(self.get_benchmarks_by_tag(tags[0]))
+
+             # Intersect with benchmarks for each additional tag
+             for tag in tags[1:]:
+                 matching_benchmarks &= set(self.get_benchmarks_by_tag(tag))
+
+             return list(matching_benchmarks)
+
+         else:
+             raise ValueError(f"Invalid mode: {mode}. Use 'any' or 'all'")
+
+     def sample_mixed_dataset(
+         self,
+         tags: List[str],
+         total_samples: int,
+         split_ratio: float = 0.8,
+         random_seed: Optional[int] = None,
+         tag_mode: str = "any",
+         benchmark_weights: Optional[Dict[str, float]] = None
+     ) -> Tuple[List[BenchmarkSample], List[BenchmarkSample]]:
+         """
+         Sample a mixed dataset from benchmarks matching the given tags.
+
+         Args:
+             tags: Tags to filter benchmarks (e.g., ["coding", "python"])
+             total_samples: Total number of samples to collect
+             split_ratio: Train/test split ratio
+             random_seed: Random seed for reproducibility
+             tag_mode: "any" or "all" for tag matching
+             benchmark_weights: Optional weights for sampling probability per benchmark
+
+         Returns:
+             Tuple of (train_samples, test_samples)
+         """
+         if random_seed is not None:
+             random.seed(random_seed)
+
+         # Get matching benchmarks
+         matching_benchmarks = self.get_benchmarks_by_tags(tags, mode=tag_mode)
+
+         if not matching_benchmarks:
+             raise ValueError(f"No benchmarks found with tags {tags} (mode={tag_mode})")
+
+         logger.info(f"Found {len(matching_benchmarks)} benchmarks matching tags {tags}")
+         logger.info(f"Matching benchmarks: {matching_benchmarks[:10]}...") # Show first 10
+
+         # Collect all available samples from matching benchmarks
+         all_samples = []
+         benchmark_sample_counts = {}
+
+         # Skip benchmarks that require code execution permission
+         code_execution_benchmarks = {"apps", "ds1000", "mercury"}
+
+         for benchmark_name in matching_benchmarks:
+             # Skip benchmarks that require code execution for safety
+             if benchmark_name in code_execution_benchmarks:
+                 logger.info(f"Skipping {benchmark_name} (requires code execution permission)")
+                 continue
+
+             try:
+                 # Get samples from this benchmark
+                 samples_per_benchmark = max(10, total_samples // len(matching_benchmarks))
+
+                 cached_samples = self.managed_cache.get_task_samples(
+                     task_name=benchmark_name,
+                     limit=samples_per_benchmark,
+                     force_fresh=False
+                 )
+
+                 # Convert to BenchmarkSample objects
+                 for sample in cached_samples:
+                     benchmark_sample = BenchmarkSample(
+                         benchmark_name=benchmark_name,
+                         sample_data=sample,
+                         tags=CORE_BENCHMARKS[benchmark_name].get("tags", [])
+                     )
+                     all_samples.append(benchmark_sample)
+
+                 benchmark_sample_counts[benchmark_name] = len(cached_samples)
+
+             except Exception as e:
+                 logger.warning(f"Failed to load samples from {benchmark_name}: {e}")
+                 continue
+
+         if not all_samples:
+             raise ValueError(f"No samples could be loaded from any benchmark with tags {tags}")
+
+         logger.info(f"Collected {len(all_samples)} total samples from {len(benchmark_sample_counts)} benchmarks")
+         for benchmark, count in benchmark_sample_counts.items():
+             logger.debug(f" {benchmark}: {count} samples")
+
+         # Apply benchmark weights if provided
+         if benchmark_weights:
+             weighted_samples = []
+             for sample in all_samples:
+                 weight = benchmark_weights.get(sample.benchmark_name, 1.0)
+                 # Duplicate samples based on weight (simple approach)
+                 weighted_samples.extend([sample] * int(weight))
+             all_samples = weighted_samples
+
+         # Randomly sample and shuffle
+         if len(all_samples) > total_samples:
+             all_samples = random.sample(all_samples, total_samples)
+         else:
+             # If we have fewer samples than requested, use all and log warning
+             logger.warning(f"Only {len(all_samples)} samples available, requested {total_samples}")
+
+         random.shuffle(all_samples)
+
+         # Split into train/test
+         split_point = int(len(all_samples) * split_ratio)
+         train_samples = all_samples[:split_point]
+         test_samples = all_samples[split_point:]
+
+         # Log distribution
+         train_dist = defaultdict(int)
+         test_dist = defaultdict(int)
+
+         for sample in train_samples:
+             train_dist[sample.benchmark_name] += 1
+
+         for sample in test_samples:
+             test_dist[sample.benchmark_name] += 1
+
+         logger.info(f"Train set: {len(train_samples)} samples from {len(train_dist)} benchmarks")
+         logger.info(f"Test set: {len(test_samples)} samples from {len(test_dist)} benchmarks")
+
+         return train_samples, test_samples
+
+     def extract_contrastive_pairs_from_mixed_samples(
+         self,
+         samples: List[BenchmarkSample]
+     ) -> List[Dict[str, Any]]:
+         """
+         Extract contrastive pairs from mixed benchmark samples.
+
+         Args:
+             samples: List of BenchmarkSample objects
+
+         Returns:
+             List of contrastive pairs with question, correct_answer, incorrect_answer
+         """
+         contrastive_pairs = []
+
+         for sample in samples:
+             try:
+                 # Each sample already has normalized QA pair from managed cache
+                 qa_pair = sample.sample_data.get("normalized", {})
+
+                 if qa_pair and all(k in qa_pair for k in ["question", "correct_answer", "incorrect_answer"]):
+                     # Add benchmark source info
+                     qa_pair["source_benchmark"] = sample.benchmark_name
+                     qa_pair["tags"] = sample.tags
+                     contrastive_pairs.append(qa_pair)
+                 else:
+                     logger.warning(f"Invalid QA pair from {sample.benchmark_name}")
+
+             except Exception as e:
+                 logger.warning(f"Failed to extract pair from {sample.benchmark_name}: {e}")
+                 continue
+
+         logger.info(f"Extracted {len(contrastive_pairs)} contrastive pairs from mixed samples")
+
+         return contrastive_pairs
+
+     def create_mixed_contrastive_pair_set(
+         self,
+         tags: List[str],
+         total_samples: int,
+         name: Optional[str] = None,
+         **kwargs
+     ) -> ContrastivePairSet:
+         """
+         Create a ContrastivePairSet from mixed benchmark samples.
+
+         Args:
+             tags: Tags to filter benchmarks
+             total_samples: Number of samples to include
+             name: Name for the pair set (auto-generated if not provided)
+             **kwargs: Additional arguments for sample_mixed_dataset
+
+         Returns:
+             ContrastivePairSet ready for training
+         """
+         # Sample mixed dataset
+         train_samples, test_samples = self.sample_mixed_dataset(
+             tags=tags,
+             total_samples=total_samples,
+             **kwargs
+         )
+
+         # Extract contrastive pairs
+         all_samples = train_samples + test_samples
+         contrastive_pairs = self.extract_contrastive_pairs_from_mixed_samples(all_samples)
+
+         # Create name if not provided
+         if name is None:
+             name = f"mixed_{'_'.join(tags)}_{total_samples}_samples"
+
+         # Create ContrastivePairSet
+         return ContrastivePairSet.from_contrastive_pairs(
+             name=name,
+             contrastive_pairs=contrastive_pairs,
+             task_type="mixed_benchmark"
+         )
+
+
+ def sample_benchmarks_by_tag(
+     tag: str,
+     samples_per_benchmark: int = 10,
+     max_benchmarks: Optional[int] = None,
+     random_seed: Optional[int] = None
+ ) -> Dict[str, List[Dict[str, Any]]]:
+     """
+     Convenience function to sample from all benchmarks with a specific tag.
+
+     Args:
+         tag: Tag to filter benchmarks (e.g., "coding")
+         samples_per_benchmark: Number of samples from each benchmark
+         max_benchmarks: Maximum number of benchmarks to sample from
+         random_seed: Random seed for reproducibility
+
+     Returns:
+         Dictionary mapping benchmark names to their samples
+     """
+     sampler = MixedBenchmarkSampler()
+
+     # Get all benchmarks with the tag
+     benchmarks = sampler.get_benchmarks_by_tag(tag)
+
+     if max_benchmarks and len(benchmarks) > max_benchmarks:
+         if random_seed is not None:
+             random.seed(random_seed)
+         benchmarks = random.sample(benchmarks, max_benchmarks)
+
+     # Sample from each benchmark
+     results = {}
+     cache = get_managed_cache()
+
+     for benchmark_name in benchmarks:
+         try:
+             samples = cache.get_task_samples(
+                 task_name=benchmark_name,
+                 limit=samples_per_benchmark,
+                 force_fresh=False
+             )
+             results[benchmark_name] = samples
+             logger.info(f"Sampled {len(samples)} from {benchmark_name}")
+
+         except Exception as e:
+             logger.warning(f"Failed to sample from {benchmark_name}: {e}")
+             continue
+
+     return results
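
The sampler above is driven entirely by the tag metadata in CORE_BENCHMARKS plus the managed benchmark cache. As a rough sketch of the intended call flow (assumptions: the module is importable as wisent.core.mixed_benchmark_sampler, cached benchmark data is already available locally, and the "coding"/"reasoning" tags shown actually exist in the registry):

    from wisent.core.mixed_benchmark_sampler import MixedBenchmarkSampler

    sampler = MixedBenchmarkSampler(cache_dir="./benchmark_cache")

    # Benchmarks carrying either of the two tags
    print(sampler.get_benchmarks_by_tags(["coding", "reasoning"], mode="any"))

    # 200 samples drawn across the matching benchmarks, 80/20 train/test split
    train, test = sampler.sample_mixed_dataset(
        tags=["coding", "reasoning"],
        total_samples=200,
        split_ratio=0.8,
        random_seed=42,
    )

    # Or go straight to a ContrastivePairSet for classifier training
    pair_set = sampler.create_mixed_contrastive_pair_set(tags=["coding"], total_samples=100)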
wisent/core/model_config_manager.py
@@ -0,0 +1,330 @@
+ """
+ Model Configuration Manager for storing and retrieving optimal parameters per model.
+ """
+
+ import os
+ import json
+ import logging
+ from typing import Dict, Any, Optional, List
+ from datetime import datetime
+ import hashlib
+ import numpy as np
+
+
+ class NumpyEncoder(json.JSONEncoder):
+     """Custom JSON encoder to handle numpy types."""
+     def default(self, obj):
+         if isinstance(obj, (np.integer, np.int64)):
+             return int(obj)
+         if isinstance(obj, (np.floating, np.float64)):
+             return float(obj)
+         if isinstance(obj, np.ndarray):
+             return obj.tolist()
+         return super().default(obj)
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelConfigManager:
+     """Manages model-specific configuration files for optimal parameters."""
+
+     def __init__(self, config_dir: Optional[str] = None):
+         """
+         Initialize the ModelConfigManager.
+
+         Args:
+             config_dir: Directory to store config files. If None, uses default location.
+         """
+         if config_dir is None:
+             # Use ~/.wisent-guard/model_configs/ as default
+             home_dir = os.path.expanduser("~")
+             self.config_dir = os.path.join(home_dir, ".wisent-guard", "model_configs")
+         else:
+             self.config_dir = config_dir
+
+         # Create directory if it doesn't exist
+         os.makedirs(self.config_dir, exist_ok=True)
+
+     def _sanitize_model_name(self, model_name: str) -> str:
+         """
+         Convert model name to a safe filename.
+
+         Args:
+             model_name: Original model name (e.g., "meta-llama/Llama-3.1-8B-Instruct")
+
+         Returns:
+             Sanitized filename (e.g., "meta-llama_Llama-3.1-8B-Instruct")
+         """
+         # Replace problematic characters
+         sanitized = model_name.replace("/", "_").replace("\\", "_").replace(":", "_")
+         # Remove any other problematic characters
+         sanitized = "".join(c for c in sanitized if c.isalnum() or c in "._-")
+         return sanitized
+
+     def _get_config_path(self, model_name: str) -> str:
+         """Get the full path to the config file for a model."""
+         sanitized_name = self._sanitize_model_name(model_name)
+         return os.path.join(self.config_dir, f"{sanitized_name}.json")
+
+     def save_model_config(
+         self,
+         model_name: str,
+         classification_layer: int,
+         steering_layer: Optional[int] = None,
+         token_aggregation: str = "average",
+         detection_threshold: float = 0.6,
+         optimization_method: str = "manual",
+         optimization_metrics: Optional[Dict[str, Any]] = None,
+         task_specific_overrides: Optional[Dict[str, Dict[str, Any]]] = None
+     ) -> str:
+         """
+         Save optimal parameters for a model.
+
+         Args:
+             model_name: Name/path of the model
+             classification_layer: Optimal layer for classification
+             steering_layer: Optimal layer for steering (defaults to classification_layer)
+             token_aggregation: Token aggregation method
+             detection_threshold: Detection threshold
+             optimization_method: How these parameters were determined
+             optimization_metrics: Metrics from optimization process
+             task_specific_overrides: Task-specific parameter overrides
+
+         Returns:
+             Path to the saved config file
+         """
+         if steering_layer is None:
+             steering_layer = classification_layer
+
+         config_data = {
+             "model_name": model_name,
+             "created_date": datetime.now().isoformat(),
+             "optimization_method": optimization_method,
+             "optimal_parameters": {
+                 "classification_layer": classification_layer,
+                 "steering_layer": steering_layer,
+                 "token_aggregation": token_aggregation,
+                 "detection_threshold": detection_threshold
+             },
+             "task_specific_overrides": task_specific_overrides or {},
+             "optimization_metrics": optimization_metrics or {},
+             "config_version": "1.0"
+         }
+
+         config_path = self._get_config_path(model_name)
+
+         try:
+             with open(config_path, 'w') as f:
+                 json.dump(config_data, f, indent=2, cls=NumpyEncoder)
+
+             logger.info(f"✅ Model configuration saved: {config_path}")
+             logger.info(f" • Classification layer: {classification_layer}")
+             logger.info(f" • Steering layer: {steering_layer}")
+             logger.info(f" • Token aggregation: {token_aggregation}")
+             logger.info(f" • Detection threshold: {detection_threshold}")
+
+             return config_path
+
+         except Exception as e:
+             logger.error(f"❌ Failed to save model configuration: {e}")
+             raise
+
+     def load_model_config(self, model_name: str) -> Optional[Dict[str, Any]]:
+         """
+         Load optimal parameters for a model.
+
+         Args:
+             model_name: Name/path of the model
+
+         Returns:
+             Configuration dictionary if found, None otherwise
+         """
+         config_path = self._get_config_path(model_name)
+
+         if not os.path.exists(config_path):
+             return None
+
+         try:
+             with open(config_path, 'r') as f:
+                 config_data = json.load(f)
+
+             logger.debug(f"📄 Loaded model configuration: {config_path}")
+             return config_data
+
+         except Exception as e:
+             logger.warning(f"⚠️ Failed to load model configuration: {e}")
+             return None
+
+     def has_model_config(self, model_name: str) -> bool:
+         """Check if a model has a saved configuration."""
+         config_path = self._get_config_path(model_name)
+         return os.path.exists(config_path)
+
+     def update_model_config(self, model_name: str, config_data: Dict[str, Any]) -> str:
+         """
+         Update an existing model configuration.
+
+         Args:
+             model_name: Name/path of the model
+             config_data: Updated configuration dictionary
+
+         Returns:
+             Path to the saved config file
+         """
+         config_path = self._get_config_path(model_name)
+
+         # Update timestamp
+         config_data["updated_date"] = datetime.now().isoformat()
+
+         try:
+             with open(config_path, 'w') as f:
+                 json.dump(config_data, f, indent=2, cls=NumpyEncoder)
+
+             logger.info(f"✅ Model configuration updated: {config_path}")
+             return config_path
+
+         except Exception as e:
+             logger.error(f"❌ Failed to update model configuration: {e}")
+             raise
+
+     def get_optimal_parameters(
+         self,
+         model_name: str,
+         task_name: Optional[str] = None
+     ) -> Optional[Dict[str, Any]]:
+         """
+         Get optimal parameters for a model, with optional task-specific overrides.
+
+         Args:
+             model_name: Name/path of the model
+             task_name: Specific task name for overrides
+
+         Returns:
+             Dictionary of optimal parameters or None if no config exists
+         """
+         config = self.load_model_config(model_name)
+         if not config:
+             return None
+
+         # Start with base optimal parameters
+         optimal_params = config.get("optimal_parameters", {}).copy()
+
+         # Apply task-specific overrides if available
+         if task_name and "task_specific_overrides" in config:
+             task_overrides = config["task_specific_overrides"].get(task_name, {})
+             optimal_params.update(task_overrides)
+
+         return optimal_params
+
+     def get_optimal_sample_size(
+         self,
+         model_name: str,
+         task_name: str,
+         layer: int
+     ) -> Optional[int]:
+         """
+         Get optimal sample size for a specific task and layer.
+
+         Args:
+             model_name: Name/path of the model
+             task_name: Task name
+             layer: Layer index
+
+         Returns:
+             Optimal sample size or None if not found
+         """
+         config = self.load_model_config(model_name)
+         if not config:
+             return None
+
+         # Check if optimal_sample_sizes exists
+         if "optimal_sample_sizes" not in config:
+             return None
+
+         # Navigate the nested structure: optimal_sample_sizes[task][layer]
+         task_sizes = config["optimal_sample_sizes"].get(task_name, {})
+         sample_size = task_sizes.get(str(layer), None)
+
+         return sample_size
+
+     def list_model_configs(self) -> List[Dict[str, Any]]:
+         """
+         List all available model configurations.
+
+         Returns:
+             List of configuration summaries
+         """
+         configs = []
+
+         if not os.path.exists(self.config_dir):
+             return configs
+
+         for filename in os.listdir(self.config_dir):
+             if filename.endswith('.json'):
+                 try:
+                     config_path = os.path.join(self.config_dir, filename)
+                     with open(config_path, 'r') as f:
+                         config_data = json.load(f)
+
+                     summary = {
+                         "model_name": config_data.get("model_name", "unknown"),
+                         "created_date": config_data.get("created_date", "unknown"),
+                         "optimization_method": config_data.get("optimization_method", "unknown"),
+                         "classification_layer": config_data.get("optimal_parameters", {}).get("classification_layer"),
+                         "steering_layer": config_data.get("optimal_parameters", {}).get("steering_layer"),
+                         "config_file": filename
+                     }
+                     configs.append(summary)
+
+                 except Exception as e:
+                     logger.warning(f"⚠️ Failed to read config file {filename}: {e}")
+
+         return configs
+
+     def remove_model_config(self, model_name: str) -> bool:
+         """
+         Remove a model configuration.
+
+         Args:
+             model_name: Name/path of the model
+
+         Returns:
+             True if removed successfully, False otherwise
+         """
+         config_path = self._get_config_path(model_name)
+
+         if not os.path.exists(config_path):
+             logger.warning(f"⚠️ No configuration found for model: {model_name}")
+             return False
+
+         try:
+             os.remove(config_path)
+             logger.info(f"✅ Removed model configuration: {config_path}")
+             return True
+
+         except Exception as e:
+             logger.error(f"❌ Failed to remove model configuration: {e}")
+             return False
+
+
+ # Convenience functions for easy access
+ _default_manager = None
+
+ def get_default_manager() -> ModelConfigManager:
+     """Get the default ModelConfigManager instance."""
+     global _default_manager
+     if _default_manager is None:
+         _default_manager = ModelConfigManager()
+     return _default_manager
+
+ def save_model_config(model_name: str, **kwargs) -> str:
+     """Save model configuration using default manager."""
+     return get_default_manager().save_model_config(model_name, **kwargs)
+
+ def load_model_config(model_name: str) -> Optional[Dict[str, Any]]:
+     """Load model configuration using default manager."""
+     return get_default_manager().load_model_config(model_name)
+
+ def get_optimal_parameters(model_name: str, task_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
+     """Get optimal parameters using default manager."""
+     return get_default_manager().get_optimal_parameters(model_name, task_name)
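
ModelConfigManager keeps one JSON file per model under ~/.wisent-guard/model_configs/, with the model name sanitized into the filename. A minimal usage sketch follows, assuming the module is importable as wisent.core.model_config_manager; the layer index, optimization method, and task name below are illustrative values, not package defaults:

    from wisent.core.model_config_manager import ModelConfigManager

    manager = ModelConfigManager()  # defaults to ~/.wisent-guard/model_configs/

    # Persist the parameters chosen for a model (values here are hypothetical)
    manager.save_model_config(
        model_name="meta-llama/Llama-3.1-8B-Instruct",
        classification_layer=15,
        token_aggregation="average",
        detection_threshold=0.6,
        optimization_method="optuna",
    )

    # Read them back later, with any task-specific overrides applied
    params = manager.get_optimal_parameters(
        "meta-llama/Llama-3.1-8B-Instruct", task_name="winogrande"
    )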