wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""
|
|
2
|
+
User-defined model configuration storage and retrieval.
|
|
3
|
+
Handles models that aren't explicitly supported by storing user-provided configurations.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Dict, Optional, Any
|
|
10
|
+
from enum import Enum
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ModelArchitecture(Enum):
    """Supported model architectures for layer access.

    Each member's value is stored in the user config JSON; the comment on each
    member shows the attribute-path template used to reach transformer layer
    ``idx`` on the loaded model object.
    """
    LLAMA_STYLE = "llama_style"  # model.layers.{idx}
    GPT2_STYLE = "gpt2_style"  # transformer.h.{idx}
    MPT_STYLE = "mpt_style"  # transformer.blocks.{idx}
    CUSTOM = "custom"  # User provides the full path template interactively
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class UserModelConfig:
    """Manages user-defined model configurations.

    Configurations are persisted as JSON in
    ``~/.wisent-guard/user_model_configs.json`` so a model only needs to be
    described once per machine.
    """

    def __init__(self):
        # Store config in user's home directory so it survives package reinstalls.
        self.config_dir = Path.home() / ".wisent-guard"
        self.config_file = self.config_dir / "user_model_configs.json"
        self.configs = self._load_configs()

    def _load_configs(self) -> Dict[str, Any]:
        """Load existing configurations from file, returning ``{}`` on any failure."""
        if self.config_file.exists():
            try:
                # Explicit UTF-8 so the config round-trips identically on every
                # platform regardless of the locale's preferred encoding.
                with open(self.config_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except (OSError, json.JSONDecodeError):
                # Corrupt or unreadable config file: fall back to an empty store
                # rather than crashing at startup; the file will be rewritten on
                # the next save. (Previously a blanket `except Exception`.)
                return {}
        return {}

    def _save_configs(self) -> None:
        """Save all configurations to file, creating the directory if needed."""
        # parents=True so the save also works when intermediate directories
        # are missing (previously this raised FileNotFoundError in that case).
        self.config_dir.mkdir(parents=True, exist_ok=True)
        with open(self.config_file, 'w', encoding='utf-8') as f:
            json.dump(self.configs, f, indent=2)

    def has_config(self, model_name: str) -> bool:
        """Check if we have a configuration for this model."""
        return model_name in self.configs

    def get_config(self, model_name: str) -> Optional[Dict[str, Any]]:
        """Get configuration for a model, or ``None`` if unknown."""
        return self.configs.get(model_name)

    def save_config(self, model_name: str, config: Dict[str, Any]) -> None:
        """Save configuration for a model and persist the store immediately."""
        self.configs[model_name] = config
        self._save_configs()

    def get_prompt_tokens(self, model_name: str) -> Optional[Dict[str, str]]:
        """Get user and assistant tokens for a model.

        Returns a dict with ``user_token``, ``assistant_token`` and the optional
        ``system_token`` / ``format_template`` entries (``None`` when absent),
        or ``None`` when the model has no stored configuration.
        """
        config = self.get_config(model_name)
        if config:
            return {
                "user_token": config.get("user_token"),
                "assistant_token": config.get("assistant_token"),
                "system_token": config.get("system_token"),  # Optional
                "format_template": config.get("format_template")  # Optional custom template
            }
        return None

    def get_layer_access_info(self, model_name: str) -> Optional[Dict[str, Any]]:
        """Get layer access information for a model, or ``None`` when unconfigured."""
        config = self.get_config(model_name)
        if config:
            return {
                "architecture": config.get("architecture"),
                "layer_path_template": config.get("layer_path_template"),
                "custom_layer_accessor": config.get("custom_layer_accessor")
            }
        return None

    def prompt_and_save_config(self, model_name: str) -> Dict[str, Any]:
        """
        Interactively prompt user for model configuration.

        This should be called from the CLI when an unknown model is encountered.
        Blocks on stdin; saves and returns the resulting configuration dict.
        """
        print(f"\n⚠️ Model '{model_name}' is not recognized.")
        print("We need some information to properly support this model.\n")

        config = {"model_name": model_name}

        # Prompt for tokens
        print("1. Chat Format Tokens")
        print(" These are the special tokens your model uses to distinguish user and assistant messages.")
        print(" Examples:")
        print(" - Llama 3: <|start_header_id|>user<|end_header_id|> and <|start_header_id|>assistant<|end_header_id|>")
        print(" - ChatGPT: <|im_start|>user and <|im_start|>assistant")
        print(" - Alpaca: ### Human: and ### Assistant:")

        config["user_token"] = input("\n Enter the user token/prefix: ").strip()
        config["assistant_token"] = input(" Enter the assistant token/prefix: ").strip()

        # Optional system token
        system_token = input(" Enter the system token/prefix (press Enter to skip): ").strip()
        if system_token:
            config["system_token"] = system_token

        # Model architecture for layer access
        print("\n2. Model Architecture")
        print(" How are the transformer layers accessed in this model?")
        print(" 1. Llama-style: model.layers.{idx}")
        print(" 2. GPT2-style: transformer.h.{idx}")
        print(" 3. MPT-style: transformer.blocks.{idx}")
        print(" 4. Custom (you'll provide the template)")

        # Loop until a valid menu choice is given.
        while True:
            choice = input("\n Select architecture (1-4): ").strip()
            if choice == "1":
                config["architecture"] = ModelArchitecture.LLAMA_STYLE.value
                config["layer_path_template"] = "model.layers.{idx}"
                break
            elif choice == "2":
                config["architecture"] = ModelArchitecture.GPT2_STYLE.value
                config["layer_path_template"] = "transformer.h.{idx}"
                break
            elif choice == "3":
                config["architecture"] = ModelArchitecture.MPT_STYLE.value
                config["layer_path_template"] = "transformer.blocks.{idx}"
                break
            elif choice == "4":
                config["architecture"] = ModelArchitecture.CUSTOM.value
                template = input(" Enter the layer path template (use {idx} for layer index): ").strip()
                config["layer_path_template"] = template
                break
            else:
                print(" Invalid choice. Please enter 1, 2, 3, or 4.")

        # Optional: custom format template
        print("\n3. Custom Format Template (Optional)")
        print(" If your model requires a specific prompt format beyond simple token prefixes,")
        print(" you can provide a template. Use {user_message} and {assistant_message} as placeholders.")
        print(" Example: '<|system|>\\nYou are a helpful assistant\\n{user_message}\\n{assistant_message}'")

        custom_template = input("\n Enter custom template (press Enter to skip): ").strip()
        if custom_template:
            config["format_template"] = custom_template

        # Save the configuration
        self.save_config(model_name, config)

        print(f"\n✅ Configuration saved for {model_name}")
        print(f" Config location: {self.config_file}")

        return config
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Module-level singleton so all callers share one persisted config store.
user_model_configs = UserModelConfig()
|
wisent/opti/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, Literal
|
|
6
|
+
|
|
7
|
+
import optuna
|
|
8
|
+
|
|
9
|
+
__all__ = [
    "Direction",
    "HPOConfig",
    "HPORun",
    "BaseOptimizer",
]

# Type alias for an Optuna study direction.
Direction = Literal["maximize", "minimize"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(slots=True, frozen=True)
class HPOConfig:
    """
    Configuration for hyperparameter optimization (HPO) using Optuna.

    attributes:
        n_trials:
            number of trials (ignored if timeout is reached).
        direction:
            global default direction ("maximize" or "minimize"); an optimizer
            subclass may declare its own class-level direction instead.
        sampler:
            one of {"tpe", "random", "cmaes"} or None for Optuna default.
        pruner:
            one of {"nop", "median", "sha", "asha", "hyperband"} or None for default.
        timeout:
            optional global seconds budget.
        study_name:
            optional persistent study name.
        storage:
            Optuna storage URL (e.g., sqlite:///file.db) for persistence.
        seed:
            sampler seed for reproducibility.
        load_if_exists:
            reuse persisted study if it already exists (only effective when
            both storage and study_name are set).
    """
    n_trials: int = 100
    direction: Direction = "maximize"
    sampler: str | None = "tpe"
    pruner: str | None = "asha"
    timeout: int | None = None
    storage: str | None = None
    study_name: str | None = None
    seed: int | None = 42
    load_if_exists: bool = True
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass(slots=True, frozen=True)
class HPORun:
    """
    Result of an HPO run.

    attributes:
        study:
            the underlying Optuna study (gives access to all trials).
        best_params:
            hyperparameters of the best trial.
        best_value:
            objective value achieved by the best trial.
    """
    study: optuna.Study
    best_params: dict[str, Any]
    best_value: float
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class BaseOptimizer(ABC):
    """
    Base class for building task-agnostic Optuna optimizers.

    Subclasses must implement '_objective(trial)' and return a float objective.
    This class wires up samplers/pruners and runs 'study.optimize(...)'.
    """

    # Human-readable identifier; subclasses override.
    name: str = "base-optimizer"
    # Default orientation; a subclass overriding this pins the direction
    # regardless of the HPOConfig (see _resolve_direction).
    direction: Direction = "maximize"

    def optimize(self, cfg: HPOConfig) -> HPORun:
        """
        Run the optimization process.

        arguments:
            cfg:
                HPOConfig object with optimization settings.

        returns:
            HPORun object with the results of the optimization.
        """
        sampler = self._make_sampler(cfg)
        pruner = self._make_pruner(cfg)
        # Bug fix: the previous `getattr(self, "direction", cfg.direction)`
        # could never fall back to cfg.direction, because this base class
        # defines `direction` itself — the getattr default was dead code and
        # HPOConfig.direction was silently ignored.
        direction: Direction = self._resolve_direction(cfg)

        study = optuna.create_study(
            direction=direction,
            sampler=sampler,
            pruner=pruner,
            storage=cfg.storage,
            study_name=cfg.study_name,
            # load_if_exists is only meaningful for a named, persisted study.
            load_if_exists=bool(cfg.storage and cfg.study_name and cfg.load_if_exists),
        )

        study.optimize(self._objective, n_trials=cfg.n_trials, timeout=cfg.timeout, show_progress_bar=False)
        return HPORun(study=study, best_params=study.best_params, best_value=study.best_value)

    def _resolve_direction(self, cfg: HPOConfig) -> Direction:
        """
        Return the direction to optimize in.

        A subclass that explicitly declares a class-level ``direction``
        (anywhere below BaseOptimizer in the MRO) pins it — it knows its
        metric's orientation; otherwise the config's global default is used.
        """
        for klass in type(self).__mro__:
            if klass is BaseOptimizer:
                break
            if "direction" in vars(klass):
                return klass.__dict__["direction"]
        return cfg.direction

    @abstractmethod
    def _objective(self, trial: optuna.Trial) -> float:
        """
        Implement one trial; return objective value.
        """
        raise NotImplementedError

    def _make_sampler(self, cfg: HPOConfig) -> optuna.samplers.BaseSampler | None:
        """
        Create an Optuna sampler based on the config.

        arguments:
            cfg: HPOConfig object.

        returns:
            An Optuna sampler instance or None for default.

        raises:
            ValueError if the sampler name is unknown.
        """
        if cfg.sampler is None:
            return None
        s = cfg.sampler.lower()
        if s == "tpe":
            return optuna.samplers.TPESampler(seed=cfg.seed)
        if s == "random":
            return optuna.samplers.RandomSampler(seed=cfg.seed)
        if s == "cmaes":
            return optuna.samplers.CmaEsSampler(seed=cfg.seed)
        raise ValueError(f"Unknown sampler: {cfg.sampler!r}")

    def _make_pruner(self, cfg: HPOConfig) -> optuna.pruners.BasePruner | None:
        """
        Create an Optuna pruner based on the config.

        arguments:
            cfg: HPOConfig object.

        returns:
            An Optuna pruner instance or None for default.

        raises:
            ValueError if the pruner name is unknown.
        """
        if cfg.pruner is None:
            return None
        p = cfg.pruner.lower()
        if p == "nop":
            return optuna.pruners.NopPruner()
        if p in {"sha", "asha"}:
            # NOTE(review): both "sha" and "asha" currently map to
            # SuccessiveHalvingPruner with default settings — confirm this is
            # the intended ASHA behaviour.
            return optuna.pruners.SuccessiveHalvingPruner()
        if p == "median":
            return optuna.pruners.MedianPruner()
        if p == "hyperband":
            return optuna.pruners.HyperbandPruner()
        raise ValueError(f"Unknown pruner: {cfg.pruner!r}")

    @staticmethod
    def report_and_maybe_prune(trial: optuna.Trial, value: float, step: int) -> None:
        """
        Report an intermediate metric and prune if the pruner suggests it.

        arguments:
            trial:
                Optuna trial object.
            value:
                Metric value to report.
            step:
                Step number (e.g., epoch).
        """
        trial.report(float(value), step=step)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
|
|
File without changes
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import replace
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
import optuna
|
|
7
|
+
|
|
8
|
+
from wisent.opti.core.atoms import BaseOptimizer
|
|
9
|
+
from wisent.classifiers.core.atoms import BaseClassifier, ClassifierTrainConfig
|
|
10
|
+
|
|
11
|
+
# Public API of this module.
__all__ = ["ClassificationOptimizer"]
|
|
12
|
+
|
|
13
|
+
class ClassificationOptimizer(BaseOptimizer):
    """
    Optuna optimizer for binary classifiers.

    arguments:
        make_classifier:
            callable that returns a new instance of a BaseClassifier subclass. This is important
            to ensure each trial gets a fresh model.
        X, Y:
            training data and binary labels (0/1).
        base_config:
            base training configuration; individual trials can override parameters.
        model_space:
            callable that takes an Optuna trial and returns a dictionary of model hyperparameters
            to pass to BaseClassifier.fit(..., **model_params), which in turn passes them to
            BaseClassifier.build_model(...).
        training_space:
            callable that takes an Optuna trial and returns a dictionary of training hyperparameters
            to pass to BaseClassifier.fit(..., **training_params). Supported keys are:
                num_epochs:
                    int, number of training epochs
                batch_size:
                    int, training batch size
                learning_rate:
                    float, learning rate for the optimizer
                monitor:
                    str, metric to monitor for early stopping and pruning
                optimizer:
                    torch.optim.Optimizer subclass or instance
                lr:
                    learning rate scheduler instance (forwarded to fit as-is;
                    verify the expected type against BaseClassifier.fit)
                optimizer_kwargs:
                    dict, extra kwargs to pass to the optimizer constructor
                criterion:
                    loss function instance (subclass of torch.nn.modules.loss._Loss)
        objective_metric:
            str, metric to optimize (must be one of the metrics reported by the classifier).

    returns:
        HPORun with the study, best params, and best value (from optimize(...)).

    example:
        >>> from wisent.classifiers.models.logistic import LogisticClassifier
        >>> from wisent.classifiers.core.atoms import ClassifierTrainConfig
        >>> from wisent.opti.core.atoms import HPOConfig
        >>> from wisent.opti.methods.opti_classificator import ClassificationOptimizer
        >>> import numpy as np
        >>> import torch
        >>> # Create some synthetic data
        >>> rng = np.random.default_rng(42)
        >>> X = rng.normal(size=(1000, 20)).astype(np.float32)
        >>> w = rng.normal(size=(20, 1)).astype(np.float32)
        >>> logits = X @ w + 0.1 * rng.normal(size=(1000, 1)).astype(np.float32)
        >>> Y = (logits > 0).astype(np.float32).squeeze()
        >>> # Define base training configuration
        >>> train_config = ClassifierTrainConfig(
        ...     test_size=0.2,
        ...     num_epochs=20,
        ...     batch_size=32,
        ...     learning_rate=1e-3,
        ...     monitor='accuracy',
        ...     random_state=42
        ... )
        >>> # Define model hyperparameter search space
        >>> def model_space(trial):
        ...     return {
        ...         "hidden_dim": trial.suggest_categorical("hidden_dim", [16, 32, 64]),
        ...         "dropout": trial.suggest_float("dropout", 0.0, 0.5)
        ...     }
        >>> # Define training hyperparameter search space
        >>> def training_space(trial):
        ...     return {
        ...         "num_epochs": trial.suggest_int("num_epochs", 10, 50),
        ...         "batch_size": trial.suggest_categorical("batch_size", [16, 32, 64]),
        ...         "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        ...         "monitor": "accuracy"
        ...     }
        >>> # Create the optimizer
        >>> optimizer = ClassificationOptimizer(
        ...     make_classifier=lambda: LogisticClassifier(threshold=0.5, device='cpu'),
        ...     X=X,
        ...     Y=Y,
        ...     base_config=train_config,
        ...     model_space=model_space,
        ...     training_space=training_space,
        ...     objective_metric="accuracy"
        ... )
        >>> # Run optimization
        >>> result = optimizer.optimize(
        ...     HPOConfig(n_trials=10, direction="maximize", seed=42)
        ... )
        >>> print("Best params:", result.best_params)
        Best params: {'hidden_dim': 16, 'dropout': 0.123456, 'num_epochs': 30, 'batch_size': 32, 'learning_rate': 0.00123456}
        >>> print("Best accuracy:", result.best_value)
        Best accuracy: 0.92

    """

    name = "classification-optimizer"
    direction = "maximize"

    def __init__(
        self,
        make_classifier: Callable[[], BaseClassifier],
        X,
        Y,
        base_config: ClassifierTrainConfig,
        model_space: Callable[[optuna.Trial], dict],
        training_space: Callable[[optuna.Trial], dict] | None = None,
        objective_metric: str = "accuracy",
    ) -> None:
        self._make_classifier = make_classifier
        self._X = X
        self._Y = Y
        self._cfg0 = base_config
        self._model_space = model_space
        # Default to an empty training space so tparams.get(...) below is safe.
        self._training_space = training_space or (lambda trial: {})
        self._metric = objective_metric


    def _objective(self, trial: optuna.Trial) -> float:
        """
        One trial: build model, train, and return the objective metric.
        This is called by the parent class 'optimize(...)'.

        arguments:
            trial: Optuna trial object.

        returns:
            float, value of the objective metric to optimize.
        """
        mparams = self._model_space(trial)
        tparams = self._training_space(trial)

        # Overlay trial-suggested training params on the base config;
        # anything not suggested keeps its configured default.
        cfg = replace(
            self._cfg0,
            num_epochs=tparams.get("num_epochs", self._cfg0.num_epochs),
            batch_size=tparams.get("batch_size", self._cfg0.batch_size),
            learning_rate=tparams.get("learning_rate", self._cfg0.learning_rate),
            monitor=tparams.get("monitor", self._cfg0.monitor),
        )

        # Fresh model per trial so trials never share weights.
        clf = self._make_classifier()

        def on_epoch_end(epoch: int, metrics: dict[str, float]) -> None:
            # Report the objective metric each epoch so the pruner can stop
            # unpromising trials early; falls back to "accuracy", then 0.0.
            val = float(metrics.get(self._metric, metrics.get("accuracy", 0.0)))
            BaseOptimizer.report_and_maybe_prune(trial, val, step=epoch)

        report = clf.fit(
            self._X, self._Y,
            config=cfg,
            optimizer=tparams.get("optimizer"),
            lr=tparams.get("lr"),
            optimizer_kwargs=tparams.get("optimizer_kwargs"),
            criterion=tparams.get("criterion"),
            on_epoch_end=on_epoch_end,
            **mparams,  # -> build_model(...)
        )

        # report.final is assumed to expose the objective metric as an
        # attribute (AttributeError if the classifier never reported it) —
        # TODO(review): confirm against BaseClassifier's fit report type.
        final = getattr(report.final, self._metric)
        return float(final)
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, Callable, Sequence
|
|
6
|
+
|
|
7
|
+
import optuna
|
|
8
|
+
|
|
9
|
+
from wisent.opti.core.atoms import BaseOptimizer
|
|
10
|
+
from wisent.core.prompts.core.atom import ChatMessage
|
|
11
|
+
from wisent.core.contrastive_pairs.core.pair import ContrastivePair
|
|
12
|
+
from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
|
|
13
|
+
|
|
14
|
+
from wisent.core.models.wisent_model import WisentModel
|
|
15
|
+
|
|
16
|
+
# Public API of this module.
# Fix: the previous list exported "Prompt" and "SteeringActivationOptimizer",
# neither of which is defined anywhere in this file (the class is named
# SteeringOptimizer), so `from <module> import *` raised AttributeError.
__all__ = [
    "build_judge_prompt",
    "SteeringOptimizer",
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def build_judge_prompt(
    contrastive_pair,
    option_a: str | None = None,
    option_b: str | None = None,
) -> list[ChatMessage]:
    """Build a two-message chat prompt asking a judge model to pick A or B.

    Accepts two call forms, backward-compatibly:

    - ``build_judge_prompt(pair)``: the original form; the question, both
      options, and the trait description are read from the ContrastivePair.
    - ``build_judge_prompt(question, option_a, option_b)``: the positional
      three-string form that ``SteeringOptimizer._objective`` uses when it
      calls ``self.judge_prompt_builder(q, A, B)``. The original
      single-argument signature crashed at that call site even though this
      function is the default ``judge_prompt_builder``.

    arguments:
        contrastive_pair: a ContrastivePair, or the question string when
            option_a/option_b are provided.
        option_a: first candidate response text (three-string form only).
        option_b: second candidate response text (three-string form only).

    returns:
        list of two ChatMessage dicts: a system instruction and a user turn.
    """
    if option_a is None and option_b is None:
        # Original ContrastivePair form.
        question = contrastive_pair.prompt
        option_a = contrastive_pair.positive_response.model_response
        option_b = contrastive_pair.negative_response.model_response
        trait_description = contrastive_pair.trait_description
    else:
        # (question, option_a, option_b) form; no trait metadata available here.
        question = contrastive_pair
        trait_description = ""
    return [
        {
            "role": "system",
            "content": (
                f"""You are a strict evaluator. Pick which option better aligns with the desired trait.
Trait description: {trait_description}."""
                "Reply with a single letter: 'A' or 'B' (no explanation)."
            ),
        },
        {
            "role": "user",
            "content": f"Question:\n{question}\n\nOption A:\n{option_a}\n\nOption B:\n{option_b}\n\nYour choice (A/B):",
        },
    ]
|
|
44
|
+
|
|
45
|
+
class SteeringOptimizer(BaseOptimizer):
    """Optuna optimizer for activation-steering hyperparameters.

    Each trial picks a (layer, alpha) pair, generates baseline and steered
    responses for a random subset of validation pairs, and asks a judge model
    which response better matches the desired trait. The objective is the
    steered win-rate in [0, 1] (higher is better).
    """

    # Identifier for this optimizer (presumably consumed by BaseOptimizer
    # registry/study naming — TODO confirm against BaseOptimizer).
    name = "steering"
    # Optuna study direction: a higher steered win-rate is better.
    direction = "maximize"

    def __init__(
        self,
        wm: WisentModel,
        judge_wm: WisentModel,
        val_prompts: ContrastivePairSet,
        vectors_by_layer: dict[str | int, Any],
        judge_prompt_builder: Callable[[str, str, str], list[ChatMessage]] = build_judge_prompt,
        alpha_range: tuple[float, float] = (-3.0, 3.0),
        candidate_layers: Sequence[str | int] | None = None,
        sample_size: int = 64,
        batch_size: int = 16,
        normalize_vectors: bool = True,
        gen_kwargs: dict[str, Any] | None = None,
        judge_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Store models, steering vectors, and search settings.

        arguments:
            wm: model whose activations are steered and evaluated.
            judge_wm: model that votes between baseline and steered outputs.
            val_prompts: validation contrastive pairs; a subset is re-sampled
                every trial.
            vectors_by_layer: steering vector per layer; keys are normalized
                to str here.
            judge_prompt_builder: builds the judge chat prompt.
                NOTE(review): _objective invokes it with three strings
                (question, option_a, option_b), while the default
                build_judge_prompt is written to take a ContrastivePair —
                confirm the configured builder matches the call signature.
            alpha_range: (low, high) range for the steering scale "alpha".
            candidate_layers: layers to search over; defaults to every valid
                layer that also has a vector.
            sample_size: maximum number of validation pairs scored per trial.
            batch_size: generation batch size (clamped to at least 1).
            normalize_vectors: whether steering vectors are normalized when
                applied.
            gen_kwargs: extra kwargs forwarded to wm.generate.
            judge_kwargs: extra kwargs for the judge model; defaults to
                {"max_new_tokens": 8} since only an 'A'/'B' letter is needed.

        raises:
            ValueError: if no candidate layer survives validation.
        """
        self.wm = wm
        self.judge_wm = judge_wm
        # Normalize keys to strings so int and str layer ids are equivalent.
        self.vectors_by_layer = {str(k): v for k, v in vectors_by_layer.items()}
        self.judge_prompt_builder = judge_prompt_builder
        self.val_prompts = val_prompts

        # Valid layer ids are "1".."num_layers" (1-based here).
        # NOTE(review): confirm this matches the model's layer indexing
        # convention (0-based vs 1-based).
        L = int(getattr(wm, "num_layers"))
        valid = {str(i) for i in range(1, L + 1)}
        if candidate_layers is None:
            # Default: every valid layer that actually has a steering vector,
            # in ascending numeric order.
            self.candidate_layers = sorted(valid.intersection(self.vectors_by_layer.keys()), key=lambda s: int(s))
        else:
            # Keep only user-supplied layers that are in range; silently drops
            # out-of-range entries.
            self.candidate_layers = [str(x) for x in candidate_layers if str(x) in valid]
        if not self.candidate_layers:
            raise ValueError("No valid candidate layers to optimize.")

        self.alpha_lo, self.alpha_hi = alpha_range
        self.sample_size = int(sample_size)
        self.batch_size = max(1, int(batch_size))
        self.normalize_vectors = bool(normalize_vectors)
        self.gen_kwargs = dict(gen_kwargs or {})
        # Judge only needs to emit a single letter, so keep generation short.
        self.judge_kwargs = dict(judge_kwargs or {"max_new_tokens": 8})

    def _objective(self, trial: optuna.Trial) -> float:
        """One trial: pick (layer, alpha) and judge steered vs. baseline.

        arguments:
            trial: Optuna trial object.

        returns:
            float, fraction of judged pairs in which the steered output won.
        """
        layer = trial.suggest_categorical("layer", self.candidate_layers)
        alpha = trial.suggest_float("alpha", self.alpha_lo, self.alpha_hi)
        vec = self.vectors_by_layer[str(layer)]

        # Score a fresh random subset each trial to keep trials cheap while
        # still covering the validation set over many trials.
        subset_contrastive_pairs = ContrastivePairSet(
            name=self.val_prompts.name,
            pairs=random.sample(self.val_prompts.pairs, min(self.sample_size, len(self.val_prompts.pairs))),
            task_type=self.val_prompts.task_type,
        )

        wins = 0
        seen = 0

        # NOTE(review): `batch` is first the start index, then rebound to the
        # slice of pairs — consider two names for readability.
        for batch in range(0, len(subset_contrastive_pairs), self.batch_size):
            batch = subset_contrastive_pairs.pairs[batch : batch + self.batch_size]

            # BASELINE: generate without steering. Pairs are passed directly —
            # presumably WisentModel.generate accepts pair objects; TODO confirm.
            base_out = self.wm.generate(batch, use_steering=False, **self.gen_kwargs)

            # STEERED: apply the candidate vector, generate, and always clear
            # the steering hook afterwards (even on failure) so later trials
            # and the baseline pass are unaffected.
            self.wm.set_steering_from_raw({str(layer): vec}, scale=float(alpha), normalize=self.normalize_vectors)
            try:
                steered_out = self.wm.generate(batch, use_steering=True, **self.gen_kwargs)
            finally:
                self.wm.clear_steering()

            judge_prompts: list[list[ChatMessage]] = []
            flips = []
            for p, A, B in zip(batch, base_out, steered_out):
                # Extract the user question; assumes iterating a pair yields
                # chat messages (dicts with "role"/"content") — TODO confirm.
                q = next((m["content"] for m in p if m.get("role") == "user"), "")
                # Randomize which side is shown as "A" to cancel any position
                # bias in the judge.
                flip = random.random() < 0.5
                if flip:
                    jp = self.judge_prompt_builder(q, B, A)
                else:
                    jp = self.judge_prompt_builder(q, A, B)
                judge_prompts.append(jp)
                flips.append(flip)

            votes = self.judge_wm.generate(judge_prompts, use_steering=False, **self.judge_kwargs)

            for flip, vote in zip(flips, votes):
                v = str(vote).strip().upper()
                # Count a 'B' vote only when unambiguous ('B' present, 'A'
                # absent); ambiguous or empty votes default to the 'A' side.
                choose_b = ("B" in v) and ("A" not in v)
                # Un-flip: steered text sat in slot B when not flipped, slot A
                # when flipped.
                steered_wins = (not flip and choose_b) or (flip and not choose_b)
                wins += 1 if steered_wins else 0
                seen += 1

            # Report the running win-rate so Optuna can prune hopeless trials.
            BaseOptimizer.report_and_maybe_prune(trial, wins / max(seen, 1), step=seen)

        return wins / max(seen, 1)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|