wisent 0.1.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic.

Files changed (237)
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.1.dist-info/METADATA +67 -0
  215. wisent-0.5.1.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.1.dist-info}/top_level.txt +0 -0
wisent/core/user_model_config.py (new file)
@@ -0,0 +1,158 @@
+ """
+ User-defined model configuration storage and retrieval.
+ Handles models that aren't explicitly supported by storing user-provided configurations.
+ """
+
+ import json
+ import os
+ from pathlib import Path
+ from typing import Dict, Optional, Any
+ from enum import Enum
+
+
+ class ModelArchitecture(Enum):
+     """Supported model architectures for layer access."""
+     LLAMA_STYLE = "llama_style"  # model.layers.{idx}
+     GPT2_STYLE = "gpt2_style"    # transformer.h.{idx}
+     MPT_STYLE = "mpt_style"      # transformer.blocks.{idx}
+     CUSTOM = "custom"            # User provides full path template
+
+
+ class UserModelConfig:
+     """Manages user-defined model configurations."""
+
+     def __init__(self):
+         # Store config in user's home directory
+         self.config_dir = Path.home() / ".wisent-guard"
+         self.config_file = self.config_dir / "user_model_configs.json"
+         self.configs = self._load_configs()
+
+     def _load_configs(self) -> Dict[str, Any]:
+         """Load existing configurations from file."""
+         if self.config_file.exists():
+             try:
+                 with open(self.config_file, 'r') as f:
+                     return json.load(f)
+             except Exception:
+                 return {}
+         return {}
+
+     def _save_configs(self) -> None:
+         """Save configurations to file."""
+         self.config_dir.mkdir(exist_ok=True)
+         with open(self.config_file, 'w') as f:
+             json.dump(self.configs, f, indent=2)
+
+     def has_config(self, model_name: str) -> bool:
+         """Check if we have a configuration for this model."""
+         return model_name in self.configs
+
+     def get_config(self, model_name: str) -> Optional[Dict[str, Any]]:
+         """Get configuration for a model."""
+         return self.configs.get(model_name)
+
+     def save_config(self, model_name: str, config: Dict[str, Any]) -> None:
+         """Save configuration for a model."""
+         self.configs[model_name] = config
+         self._save_configs()
+
+     def get_prompt_tokens(self, model_name: str) -> Optional[Dict[str, str]]:
+         """Get user and assistant tokens for a model."""
+         config = self.get_config(model_name)
+         if config:
+             return {
+                 "user_token": config.get("user_token"),
+                 "assistant_token": config.get("assistant_token"),
+                 "system_token": config.get("system_token"),  # Optional
+                 "format_template": config.get("format_template")  # Optional custom template
+             }
+         return None
+
+     def get_layer_access_info(self, model_name: str) -> Optional[Dict[str, Any]]:
+         """Get layer access information for a model."""
+         config = self.get_config(model_name)
+         if config:
+             return {
+                 "architecture": config.get("architecture"),
+                 "layer_path_template": config.get("layer_path_template"),
+                 "custom_layer_accessor": config.get("custom_layer_accessor")
+             }
+         return None
+
+     def prompt_and_save_config(self, model_name: str) -> Dict[str, Any]:
+         """
+         Interactively prompt user for model configuration.
+         This should be called from the CLI when an unknown model is encountered.
+         """
+         print(f"\n⚠️ Model '{model_name}' is not recognized.")
+         print("We need some information to properly support this model.\n")
+
+         config = {"model_name": model_name}
+
+         # Prompt for tokens
+         print("1. Chat Format Tokens")
+         print("   These are the special tokens your model uses to distinguish user and assistant messages.")
+         print("   Examples:")
+         print("   - Llama 3: <|start_header_id|>user<|end_header_id|> and <|start_header_id|>assistant<|end_header_id|>")
+         print("   - ChatGPT: <|im_start|>user and <|im_start|>assistant")
+         print("   - Alpaca: ### Human: and ### Assistant:")
+
+         config["user_token"] = input("\n   Enter the user token/prefix: ").strip()
+         config["assistant_token"] = input("   Enter the assistant token/prefix: ").strip()
+
+         # Optional system token
+         system_token = input("   Enter the system token/prefix (press Enter to skip): ").strip()
+         if system_token:
+             config["system_token"] = system_token
+
+         # Model architecture for layer access
+         print("\n2. Model Architecture")
+         print("   How are the transformer layers accessed in this model?")
+         print("   1. Llama-style: model.layers.{idx}")
+         print("   2. GPT2-style: transformer.h.{idx}")
+         print("   3. MPT-style: transformer.blocks.{idx}")
+         print("   4. Custom (you'll provide the template)")
+
+         while True:
+             choice = input("\n   Select architecture (1-4): ").strip()
+             if choice == "1":
+                 config["architecture"] = ModelArchitecture.LLAMA_STYLE.value
+                 config["layer_path_template"] = "model.layers.{idx}"
+                 break
+             elif choice == "2":
+                 config["architecture"] = ModelArchitecture.GPT2_STYLE.value
+                 config["layer_path_template"] = "transformer.h.{idx}"
+                 break
+             elif choice == "3":
+                 config["architecture"] = ModelArchitecture.MPT_STYLE.value
+                 config["layer_path_template"] = "transformer.blocks.{idx}"
+                 break
+             elif choice == "4":
+                 config["architecture"] = ModelArchitecture.CUSTOM.value
+                 template = input("   Enter the layer path template (use {idx} for layer index): ").strip()
+                 config["layer_path_template"] = template
+                 break
+             else:
+                 print("   Invalid choice. Please enter 1, 2, 3, or 4.")
+
+         # Optional: custom format template
+         print("\n3. Custom Format Template (Optional)")
+         print("   If your model requires a specific prompt format beyond simple token prefixes,")
+         print("   you can provide a template. Use {user_message} and {assistant_message} as placeholders.")
+         print("   Example: '<|system|>\\nYou are a helpful assistant\\n{user_message}\\n{assistant_message}'")
+
+         custom_template = input("\n   Enter custom template (press Enter to skip): ").strip()
+         if custom_template:
+             config["format_template"] = custom_template
+
+         # Save the configuration
+         self.save_config(model_name, config)
+
+         print(f"\n✅ Configuration saved for {model_name}")
+         print(f"   Config location: {self.config_file}")
+
+         return config
+
+
+ # Global instance for easy access
+ user_model_configs = UserModelConfig()
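Below is a minimal, non-interactive usage sketch for the new user model configuration store. The model name and token values are illustrative, and the wisent.core.user_model_config import path is inferred from the file listing above; per the code in the hunk, configs persist to ~/.wisent-guard/user_model_configs.json.

from wisent.core.user_model_config import user_model_configs  # path assumed from the file listing

name = "my-org/custom-chat-model"  # illustrative model name
if not user_model_configs.has_config(name):
    # Register chat-format tokens and layer access info without going through the CLI prompt.
    user_model_configs.save_config(name, {
        "model_name": name,
        "user_token": "### Human:",
        "assistant_token": "### Assistant:",
        "architecture": "llama_style",
        "layer_path_template": "model.layers.{idx}",
    })

print(user_model_configs.get_prompt_tokens(name))      # chat-format tokens for this model
print(user_model_configs.get_layer_access_info(name))  # how its transformer layers are addressed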
wisent/opti/core/atoms.py (new file)
@@ -0,0 +1,175 @@
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Any, Literal
+
+ import optuna
+
+ __all__ = [
+     "Direction",
+     "HPOConfig",
+     "HPORun",
+     "BaseOptimizer",
+ ]
+
+ Direction = Literal["maximize", "minimize"]
+
+
+ @dataclass(slots=True, frozen=True)
+ class HPOConfig:
+     """
+     Configuration for hyperparameter optimization (HPO) using Optuna.
+
+     attributes:
+         n_trials:
+             number of trials (ignored if timeout is reached).
+         direction:
+             global default direction ("maximize" or "minimize").
+         sampler:
+             one of {"tpe", "random", "cmaes"} or None for Optuna default.
+         pruner:
+             one of {"nop", "median", "sha", "asha", "hyperband"} or None for default.
+         timeout:
+             optional global seconds budget.
+         study_name:
+             optional persistent study name.
+         storage:
+             Optuna storage URL (e.g., sqlite:///file.db) for persistence.
+         seed:
+             sampler seed for reproducibility.
+         load_if_exists:
+             reuse persisted study if it already exists (when storage+study_name set).
+     """
+     n_trials: int = 100
+     direction: Direction = "maximize"
+     sampler: str | None = "tpe"
+     pruner: str | None = "asha"
+     timeout: int | None = None
+     storage: str | None = None
+     study_name: str | None = None
+     seed: int | None = 42
+     load_if_exists: bool = True
+
+
+ @dataclass(slots=True, frozen=True)
+ class HPORun:
+     """
+     Result of an HPO run.
+     """
+     study: optuna.Study
+     best_params: dict[str, Any]
+     best_value: float
+
+
+ class BaseOptimizer(ABC):
+     """
+     Base class for building task-agnostic Optuna optimizers.
+
+     Subclasses must implement '_objective(trial)' and return a float objective.
+     This class wires up samplers/pruners and runs 'study.optimize(...)'.
+     """
+
+     name: str = "base-optimizer"
+     direction: Direction = "maximize"
+
+     def optimize(self, cfg: HPOConfig) -> HPORun:
+         """
+         Run the optimization process.
+
+         arguments:
+             cfg:
+                 HPOConfig object with optimization settings.
+
+         returns:
+             HPORun object with the results of the optimization.
+         """
+         sampler = self._make_sampler(cfg)
+         pruner = self._make_pruner(cfg)
+         direction: Direction = getattr(self, "direction", cfg.direction)
+
+         study = optuna.create_study(
+             direction=direction,
+             sampler=sampler,
+             pruner=pruner,
+             storage=cfg.storage,
+             study_name=cfg.study_name,
+             load_if_exists=bool(cfg.storage and cfg.study_name and cfg.load_if_exists),
+         )
+
+         study.optimize(self._objective, n_trials=cfg.n_trials, timeout=cfg.timeout, show_progress_bar=False)
+         return HPORun(study=study, best_params=study.best_params, best_value=study.best_value)
+
+     @abstractmethod
+     def _objective(self, trial: optuna.Trial) -> float:
+         """
+         Implement one trial; return objective value.
+         """
+         raise NotImplementedError
+
+     def _make_sampler(self, cfg: HPOConfig) -> optuna.samplers.BaseSampler | None:
+         """
+         Create an Optuna sampler based on the config.
+
+         arguments:
+             cfg: HPOConfig object.
+
+         returns:
+             An Optuna sampler instance or None for default.
+
+         raises:
+             ValueError if the sampler name is unknown.
+         """
+         if cfg.sampler is None:
+             return None
+         s = cfg.sampler.lower()
+         if s == "tpe":
+             return optuna.samplers.TPESampler(seed=cfg.seed)
+         if s == "random":
+             return optuna.samplers.RandomSampler(seed=cfg.seed)
+         if s == "cmaes":
+             return optuna.samplers.CmaEsSampler(seed=cfg.seed)
+         raise ValueError(f"Unknown sampler: {cfg.sampler!r}")
+
+     def _make_pruner(self, cfg: HPOConfig) -> optuna.pruners.BasePruner | None:
+         """
+         Create an Optuna pruner based on the config.
+
+         arguments:
+             cfg: HPOConfig object.
+
+         returns:
+             An Optuna pruner instance or None for default.
+
+         raises:
+             ValueError if the pruner name is unknown.
+         """
+         if cfg.pruner is None:
+             return None
+         p = cfg.pruner.lower()
+         if p == "nop":
+             return optuna.pruners.NopPruner()
+         if p in {"sha", "asha"}:
+             return optuna.pruners.SuccessiveHalvingPruner()
+         if p == "median":
+             return optuna.pruners.MedianPruner()
+         if p == "hyperband":
+             return optuna.pruners.HyperbandPruner()
+         raise ValueError(f"Unknown pruner: {cfg.pruner!r}")
+
+     @staticmethod
+     def report_and_maybe_prune(trial: optuna.Trial, value: float, step: int) -> None:
+         """
+         Report an intermediate metric and prune if the pruner suggests it.
+
+         arguments:
+             trial:
+                 Optuna trial object.
+             value:
+                 Metric value to report.
+             step:
+                 Step number (e.g., epoch).
+         """
+         trial.report(float(value), step=step)
+         if trial.should_prune():
+             raise optuna.exceptions.TrialPruned()
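A minimal sketch of subclassing the new BaseOptimizer: the quadratic objective and the parameter name "x" are illustrative only, and the wisent.opti.core.atoms import path is taken from the file listing (the package's own modules import the same classes as wisent_guard.opti.core.atoms).

import optuna

from wisent.opti.core.atoms import BaseOptimizer, HPOConfig  # path assumed from the file listing

class QuadraticOptimizer(BaseOptimizer):
    # Toy optimizer: find x that minimizes (x - 2)^2.
    name = "quadratic-demo"
    direction = "minimize"

    def _objective(self, trial: optuna.Trial) -> float:
        x = trial.suggest_float("x", -10.0, 10.0)
        return (x - 2.0) ** 2

run = QuadraticOptimizer().optimize(HPOConfig(n_trials=25, direction="minimize", pruner="nop"))
print(run.best_params, run.best_value)  # expect x close to 2 and a value near 0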
wisent/opti/methods/opti_classificator.py (new file)
@@ -0,0 +1,172 @@
+ from __future__ import annotations
+
+ from dataclasses import replace
+ from typing import Callable
+
+ import optuna
+
+ from wisent_guard.opti.core.atoms import BaseOptimizer
+ from wisent_guard.classifiers.core.atoms import BaseClassifier, ClassifierTrainConfig
+
+ __all__ = ["ClassificationOptimizer"]
+
+ class ClassificationOptimizer(BaseOptimizer):
+     """
+     Optuna optimizer for binary classifiers.
+
+     arguments:
+         make_classifier:
+             callable that returns a new instance of a BaseClassifier subclass. This is important
+             to ensure each trial gets a fresh model.
+         X, Y:
+             training data and binary labels (0/1).
+         base_config:
+             base training configuration; individual trials can override parameters.
+         model_space:
+             callable that takes an Optuna trial and returns a dictionary of model hyperparameters
+             to pass to BaseClassifier.fit(..., **model_params), which in turn passes them to
+             BaseClassifier.build_model(...).
+         training_space:
+             callable that takes an Optuna trial and returns a dictionary of training hyperparameters
+             to pass to BaseClassifier.fit(..., **training_params). Supported keys are:
+                 num_epochs:
+                     int, number of training epochs
+                 batch_size:
+                     int, training batch size
+                 learning_rate:
+                     float, learning rate for the optimizer
+                 monitor:
+                     str, metric to monitor for early stopping and pruning
+                 optimizer:
+                     torch.optim.Optimizer subclass or instance
+                 lr:
+                     learning rate scheduler instance
+                 optimizer_kwargs:
+                     dict, extra kwargs to pass to the optimizer constructor
+                 criterion:
+                     loss function instance (subclass of torch.nn.modules.loss._Loss)
+         objective_metric:
+             str, metric to optimize (must be one of the metrics reported by the classifier).
+
+     returns:
+         HPORun with the study, best params, and best value.
+
+     example:
+         >>> from wisent_guard.classifiers.models.logistic import LogisticClassifier
+         >>> from wisent_guard.classifiers.core.atoms import ClassifierTrainConfig
+         >>> from wisent_guard.opti.methods.opti_classificator import ClassificationOptimizer
+         >>> import numpy as np
+         >>> import torch
+         >>> # Create some synthetic data
+         >>> rng = np.random.default_rng(42)
+         >>> X = rng.normal(size=(1000, 20)).astype(np.float32)
+         >>> w = rng.normal(size=(20, 1)).astype(np.float32)
+         >>> logits = X @ w + 0.1 * rng.normal(size=(1000, 1)).astype(np.float32)
+         >>> Y = (logits > 0).astype(np.float32).squeeze()
+         >>> # Define base training configuration
+         >>> train_config = ClassifierTrainConfig(
+         ...     test_size=0.2,
+         ...     num_epochs=20,
+         ...     batch_size=32,
+         ...     learning_rate=1e-3,
+         ...     monitor='accuracy',
+         ...     random_state=42
+         ... )
+         >>> # Define model hyperparameter search space
+         >>> def model_space(trial):
+         ...     return {
+         ...         "hidden_dim": trial.suggest_categorical("hidden_dim", [16, 32, 64]),
+         ...         "dropout": trial.suggest_float("dropout", 0.0, 0.5)
+         ...     }
+         >>> # Define training hyperparameter search space
+         >>> def training_space(trial):
+         ...     return {
+         ...         "num_epochs": trial.suggest_int("num_epochs", 10, 50),
+         ...         "batch_size": trial.suggest_categorical("batch_size", [16, 32, 64]),
+         ...         "learning_rate": trial.suggest_loguniform("learning_rate", 1e-4, 1e-2),
+         ...         "monitor": "accuracy"
+         ...     }
+         >>> # Create the optimizer
+         >>> optimizer = ClassificationOptimizer(
+         ...     make_classifier=lambda: LogisticClassifier(threshold=0.5, device='cpu'),
+         ...     X=X,
+         ...     Y=Y,
+         ...     base_config=train_config,
+         ...     model_space=model_space,
+         ...     training_space=training_space,
+         ...     objective_metric="accuracy"
+         ... )
+         >>> # Run optimization
+         >>> result = optimizer.optimize(
+         ...     HPOConfig(n_trials=10, direction="maximize", seed=42)
+         ... )
+         >>> print("Best params:", result.best_params)
+         Best params: {'hidden_dim': 16, 'dropout': 0.123456, 'num_epochs': 30, 'batch_size': 32, 'learning_rate': 0.00123456}
+         >>> print("Best accuracy:", result.best_value)
+         Best accuracy: 0.92
+
+     """
+
+     name = "classification-optimizer"
+     direction = "maximize"
+
+     def __init__(
+         self,
+         make_classifier: Callable[[], BaseClassifier],
+         X,
+         Y,
+         base_config: ClassifierTrainConfig,
+         model_space: Callable[[optuna.Trial], dict],
+         training_space: Callable[[optuna.Trial], dict] | None = None,
+         objective_metric: str = "accuracy",
+     ) -> None:
+         self._make_classifier = make_classifier
+         self._X = X
+         self._Y = Y
+         self._cfg0 = base_config
+         self._model_space = model_space
+         self._training_space = training_space or (lambda trial: {})
+         self._metric = objective_metric
+
+
+     def _objective(self, trial: optuna.Trial) -> float:
+         """
+         One trial: build model, train, and return the objective metric.
+         This is called by the parent class 'optimize(...)'.
+
+         arguments:
+             trial: Optuna trial object.
+
+         returns:
+             float, value of the objective metric to optimize.
+         """
+         mparams = self._model_space(trial)
+         tparams = self._training_space(trial)
+
+         cfg = replace(
+             self._cfg0,
+             num_epochs=tparams.get("num_epochs", self._cfg0.num_epochs),
+             batch_size=tparams.get("batch_size", self._cfg0.batch_size),
+             learning_rate=tparams.get("learning_rate", self._cfg0.learning_rate),
+             monitor=tparams.get("monitor", self._cfg0.monitor),
+         )
+
+         clf = self._make_classifier()
+
+         def on_epoch_end(epoch: int, metrics: dict[str, float]) -> None:
+             val = float(metrics.get(self._metric, metrics.get("accuracy", 0.0)))
+             BaseOptimizer.report_and_maybe_prune(trial, val, step=epoch)
+
+         report = clf.fit(
+             self._X, self._Y,
+             config=cfg,
+             optimizer=tparams.get("optimizer"),
+             lr=tparams.get("lr"),
+             optimizer_kwargs=tparams.get("optimizer_kwargs"),
+             criterion=tparams.get("criterion"),
+             on_epoch_end=on_epoch_end,
+             **mparams,  # -> build_model(...)
+         )
+
+         final = getattr(report.final, self._metric)
+         return float(final)
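Note that the doctest above uses HPOConfig without importing it; under this release's layout it would presumably come from the same module as the optimizer base class, e.g. from wisent_guard.opti.core.atoms import HPOConfig.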
wisent/opti/methods/opti_steering.py (new file)
@@ -0,0 +1,138 @@
+ from __future__ import annotations
+
+ import random
+ from dataclasses import dataclass
+ from typing import Any, Callable, Sequence
+
+ import optuna
+
+ from wisent_guard.opti.core.atoms import BaseOptimizer
+ from wisent_guard.core.prompts.core.atom import ChatMessage
+ from wisent_guard.core.contrastive_pairs.core.pair import ContrastivePair
+ from wisent_guard.core.contrastive_pairs.core.set import ContrastivePairSet
+
+ from wisent_guard.core.models.wisent_model import WisentModel
+
+ __all__ = [
+     "Prompt",
+     "build_judge_prompt",
+     "SteeringActivationOptimizer"
+ ]
+
+
+
+ def build_judge_prompt(contrastive_pair: ContrastivePair) -> list[ChatMessage]:
+
+     question = contrastive_pair.prompt
+     option_a = contrastive_pair.positive_response.model_response
+     option_b = contrastive_pair.negative_response.model_response
+     trait_description = contrastive_pair.trait_description
+     return [
+         {
+             "role": "system",
+             "content": (
+                 f"""You are a strict evaluator. Pick which option better aligns with the desired trait.
+ Trait description: {trait_description}."""
+                 "Reply with a single letter: 'A' or 'B' (no explanation)."
+             ),
+         },
+         {
+             "role": "user",
+             "content": f"Question:\n{question}\n\nOption A:\n{option_a}\n\nOption B:\n{option_b}\n\nYour choice (A/B):",
+         },
+     ]
+
+ class SteeringOptimizer(BaseOptimizer):
+
+     name = "steering"
+     direction = "maximize"
+
+     def __init__(
+         self,
+         wm: WisentModel,
+         judge_wm: WisentModel,
+         val_prompts: ContrastivePairSet,
+         vectors_by_layer: dict[str | int, Any],
+         judge_prompt_builder: Callable[[str, str, str], list[ChatMessage]] = build_judge_prompt,
+         alpha_range: tuple[float, float] = (-3.0, 3.0),
+         candidate_layers: Sequence[str | int] | None = None,
+         sample_size: int = 64,
+         batch_size: int = 16,
+         normalize_vectors: bool = True,
+         gen_kwargs: dict[str, Any] | None = None,
+         judge_kwargs: dict[str, Any] | None = None,
+     ) -> None:
+         self.wm = wm
+         self.judge_wm = judge_wm
+         self.vectors_by_layer = {str(k): v for k, v in vectors_by_layer.items()}
+         self.judge_prompt_builder = judge_prompt_builder
+         self.val_prompts = val_prompts
+
+         L = int(getattr(wm, "num_layers"))
+         valid = {str(i) for i in range(1, L + 1)}
+         if candidate_layers is None:
+             self.candidate_layers = sorted(valid.intersection(self.vectors_by_layer.keys()), key=lambda s: int(s))
+         else:
+             self.candidate_layers = [str(x) for x in candidate_layers if str(x) in valid]
+         if not self.candidate_layers:
+             raise ValueError("No valid candidate layers to optimize.")
+
+         self.alpha_lo, self.alpha_hi = alpha_range
+         self.sample_size = int(sample_size)
+         self.batch_size = max(1, int(batch_size))
+         self.normalize_vectors = bool(normalize_vectors)
+         self.gen_kwargs = dict(gen_kwargs or {})
+         self.judge_kwargs = dict(judge_kwargs or {"max_new_tokens": 8})
+
+     def _objective(self, trial: optuna.Trial) -> float:
+         layer = trial.suggest_categorical("layer", self.candidate_layers)
+         alpha = trial.suggest_float("alpha", self.alpha_lo, self.alpha_hi)
+         vec = self.vectors_by_layer[str(layer)]
+
+         # Sample a subset and build a batched DataLoader (shuffle for robustness).
+         subset_contrastive_pairs = ContrastivePairSet(
+             name=self.val_prompts.name,
+             pairs=random.sample(self.val_prompts.pairs, min(self.sample_size, len(self.val_prompts.pairs))),
+             task_type=self.val_prompts.task_type,
+         )
+
+         wins = 0
+         seen = 0
+
+         for batch in range(0, len(subset_contrastive_pairs), self.batch_size):
+             batch = subset_contrastive_pairs.pairs[batch : batch + self.batch_size]
+
+             # BASELINE
+             base_out = self.wm.generate(batch, use_steering=False, **self.gen_kwargs)
+
+             # STEERED
+             self.wm.set_steering_from_raw({str(layer): vec}, scale=float(alpha), normalize=self.normalize_vectors)
+             try:
+                 steered_out = self.wm.generate(batch, use_steering=True, **self.gen_kwargs)
+             finally:
+                 self.wm.clear_steering()
+
+             judge_prompts: list[list[ChatMessage]] = []
+             flips = []
+             for p, A, B in zip(batch, base_out, steered_out):
+                 q = next((m["content"] for m in p if m.get("role") == "user"), "")
+                 flip = random.random() < 0.5
+                 if flip:
+                     jp = self.judge_prompt_builder(q, B, A)
+                 else:
+                     jp = self.judge_prompt_builder(q, A, B)
+                 judge_prompts.append(jp)
+                 flips.append(flip)
+
+             votes = self.judge_wm.generate(judge_prompts, use_steering=False, **self.judge_kwargs)
+
+             for flip, vote in zip(flips, votes):
+                 v = str(vote).strip().upper()
+                 choose_b = ("B" in v) and ("A" not in v)
+                 steered_wins = (not flip and choose_b) or (flip and not choose_b)
+                 wins += 1 if steered_wins else 0
+                 seen += 1
+
+             BaseOptimizer.report_and_maybe_prune(trial, wins / max(seen, 1), step=seen)
+
+         return wins / max(seen, 1)
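Because of the random A/B position flip, the win-rate bookkeeping in SteeringOptimizer._objective is easy to misread; the standalone sketch below isolates just that counting logic (plain Python, no wisent imports, with illustrative inputs).

def count_steered_wins(votes: list[str], flips: list[bool]) -> float:
    """Steered win rate given raw judge replies and the per-pair A/B position flips."""
    wins = 0
    for vote, flip in zip(votes, flips):
        v = vote.strip().upper()
        choose_b = ("B" in v) and ("A" not in v)  # judge picked option B unambiguously
        # flip == False: the steered answer was shown as option B; flip == True: as option A.
        if (not flip and choose_b) or (flip and not choose_b):
            wins += 1
    return wins / max(len(votes), 1)

print(count_steered_wins(["B", "A"], [False, True]))  # -> 1.0 (steered response wins both rounds)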