wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import itertools
|
|
3
|
+
from typing import Dict, List, Tuple, Any, Optional
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
import numpy as np
|
|
6
|
+
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
|
|
7
|
+
|
|
8
|
+
from .contrastive_pairs import ContrastivePairSet
|
|
9
|
+
from .steering import SteeringMethod, SteeringType
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def detect_model_layers(model) -> int:
    """
    Detect the number of layers in a model.

    Three strategies are tried in order, and the first one that yields a
    positive count wins:

    1. A layer-count attribute on ``config`` (``num_hidden_layers``,
       ``n_layer``, ``num_layers`` or ``n_layers``).
    2. The length of a known layer container (``model.layers``,
       ``transformer.h`` or ``encoder.layer``).
    3. The highest numeric index that directly follows a ``layers``, ``h``
       or ``layer`` segment in the names yielded by ``named_modules()``.

    Args:
        model: The model object to inspect. If it exposes an ``hf_model``
            attribute, that wrapped object is inspected instead.

    Returns:
        Number of layers in the model, or 32 when nothing could be detected
        or any error occurred while inspecting the model.
    """
    try:
        # Unwrap project wrappers that keep the underlying model in .hf_model
        hf_model = getattr(model, 'hf_model', model)

        # Method 1: Check config for common layer count attributes
        if hasattr(hf_model, 'config'):
            config = hf_model.config

            # Different models use different names for layer count
            layer_attrs = ['num_hidden_layers', 'n_layer', 'num_layers', 'n_layers']
            for attr in layer_attrs:
                if not hasattr(config, attr):
                    continue
                layer_count = getattr(config, attr)
                # bool is a subclass of int; a True flag must not be read as "1 layer"
                if isinstance(layer_count, bool):
                    continue
                if isinstance(layer_count, int) and layer_count > 0:
                    logger.info(f"Detected {layer_count} layers from config.{attr}")
                    return layer_count

        # Method 2: Count actual layer modules
        if hasattr(hf_model, 'model') and hasattr(hf_model.model, 'layers'):
            # Llama/Mistral style: model.layers
            layer_count = len(hf_model.model.layers)
            logger.info(f"Detected {layer_count} layers from model.layers")
            return layer_count
        elif hasattr(hf_model, 'transformer') and hasattr(hf_model.transformer, 'h'):
            # GPT style: transformer.h
            layer_count = len(hf_model.transformer.h)
            logger.info(f"Detected {layer_count} layers from transformer.h")
            return layer_count
        elif hasattr(hf_model, 'encoder') and hasattr(hf_model.encoder, 'layer'):
            # BERT style: encoder.layer
            layer_count = len(hf_model.encoder.layer)
            logger.info(f"Detected {layer_count} layers from encoder.layer")
            return layer_count

        # Method 3: Try to count by iterating through named modules
        layer_containers = ('layers', 'h', 'layer')
        layer_count = 0
        for name, _ in hf_model.named_modules():
            parts = name.split('.')
            # Only the index right after the container segment is a layer
            # number; deeper indices (e.g. "layers.0.experts.7") are not.
            for container, index in zip(parts, parts[1:]):
                if container in layer_containers and index.isdecimal():
                    layer_count = max(layer_count, int(index) + 1)

        if layer_count > 0:
            logger.info(f"Detected {layer_count} layers from module names")
            return layer_count

        # Fallback: Conservative default
        logger.warning("Could not detect layer count, using default of 32")
        return 32

    except Exception as e:
        # Best-effort detection: any failure while probing degrades to the default
        logger.warning(f"Error detecting layer count: {e}, using default of 32")
        return 32
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def get_default_layer_range(total_layers: int, use_all: bool = True) -> List[int]:
    """
    Get a reasonable default layer range for optimization.

    Args:
        total_layers: Total number of layers in the model
        use_all: If True, use all layers; if False, use middle layers only
            (the first and last quarter are skipped)

    Returns:
        List of 0-indexed layer indices to optimize over. The list is empty
        only when ``total_layers`` is not positive.
    """
    if use_all:
        # Use all layers (0-indexed)
        return list(range(total_layers))

    # Use middle layers (skip first and last quarter)
    start_layer = max(0, total_layers // 4)
    end_layer = min(total_layers, (3 * total_layers) // 4)
    middle_layers = list(range(start_layer, end_layer))

    # Integer division collapses the middle slice to nothing for a one-layer
    # model; fall back to the centre layer so there is something to search.
    if not middle_layers and total_layers > 0:
        return [total_layers // 2]
    return middle_layers
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class OptimizationConfig:
    """Configuration for hyperparameter optimization.

    The optimizer searches the Cartesian product of ``layer_range``,
    ``aggregation_methods``, ``threshold_range`` and ``classifier_types``,
    capped at ``max_combinations`` entries, and ranks results by ``metric``.
    """

    # Layer range to search (will be auto-detected if None)
    layer_range: Optional[List[int]] = None

    # Token aggregation methods to try
    aggregation_methods: List[str] = field(default_factory=lambda: ["average", "final", "first", "max", "min"])

    # Threshold range to search (for classification)
    threshold_range: List[float] = field(default_factory=lambda: [0.3, 0.4, 0.5, 0.6, 0.7, 0.8])

    # Classifier types to try
    classifier_types: List[str] = field(default_factory=lambda: ["logistic"])

    # Performance metric to optimize
    metric: str = "f1"  # Options: "accuracy", "f1", "precision", "recall", "auc"

    # Cross-validation folds (if 0, uses simple train/val split)
    # NOTE(review): HyperparameterOptimizer in this module does not read this field - confirm intended use
    cv_folds: int = 0

    # Validation split ratio (used when cv_folds=0)
    # NOTE(review): HyperparameterOptimizer in this module does not read this field - confirm intended use
    val_split: float = 0.2

    # Maximum number of combinations to try (for performance)
    max_combinations: int = 100

    # Random seed for reproducibility
    seed: int = 42
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass
class OptimizationResult:
    """Result of hyperparameter optimization.

    Holds the winning hyperparameter combination, its scores, and the raw
    per-combination result dictionaries collected during the search.
    """

    # Layer index of the best-scoring combination
    best_layer: int
    # Token aggregation method of the best-scoring combination
    best_aggregation: str
    # Classification threshold of the best-scoring combination
    best_threshold: float
    # Classifier type name of the best-scoring combination
    best_classifier_type: str
    # Value of the optimized metric for the best combination
    best_score: float
    # Metric name -> value for the best combination
    best_metrics: Dict[str, float]

    # All tested combinations and their scores
    all_results: List[Dict[str, Any]] = field(default_factory=list)

    # Configuration used for optimization
    config: Optional["OptimizationConfig"] = None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class HyperparameterOptimizer:
    """Optimizes hyperparameters for the guard system.

    Searches the grid of (layer, aggregation method, threshold, classifier
    type) described by an ``OptimizationConfig``, trains and scores a
    classifier for each combination, and reports the combination with the
    highest value of the configured metric.
    """

    # Metric names that _evaluate_combination reports and optimize() can rank by
    _SUPPORTED_METRICS = ("accuracy", "f1", "precision", "recall", "auc")

    def __init__(self, config: Optional["OptimizationConfig"] = None):
        """Store the configuration (defaults if omitted) and seed NumPy.

        Args:
            config: Search configuration; a default ``OptimizationConfig``
                is created when None.
        """
        self.config = config or OptimizationConfig()
        # Seeds NumPy's *global* generator, which the combination
        # subsampling in optimize() draws from.
        np.random.seed(self.config.seed)

    @staticmethod
    def _sample_combinations(
        combinations: List[Tuple[Any, ...]],
        max_count: int
    ) -> List[Tuple[Any, ...]]:
        """Return at most ``max_count`` combinations, sampled without replacement.

        Indices are sampled instead of the tuples themselves:
        ``np.random.choice`` only accepts 1-D input, and converting the
        tuples to an array would also coerce their mixed element types.
        The original grid order is preserved among the sampled entries.
        """
        if len(combinations) <= max_count:
            return combinations
        chosen = np.random.choice(len(combinations), size=max_count, replace=False)
        return [combinations[i] for i in sorted(chosen.tolist())]

    def optimize(
        self,
        model,
        train_pair_set: "ContrastivePairSet",
        test_pair_set: "ContrastivePairSet",
        device: Optional[str] = None,
        verbose: bool = False
    ) -> "OptimizationResult":
        """
        Optimize hyperparameters for the guard system.

        Args:
            model: The model to use for training
            train_pair_set: Training contrastive pairs
            test_pair_set: Test contrastive pairs for evaluation
            device: Device to run on
            verbose: Whether to print progress

        Returns:
            OptimizationResult with best hyperparameters and performance

        Raises:
            ValueError: If the configured metric is not supported, the layer
                range is empty, or no combination could be evaluated.
        """
        metric = self.config.metric
        # Fail before any training if the metric can never be looked up
        if metric not in self._SUPPORTED_METRICS:
            raise ValueError(
                f"Unsupported optimization metric {metric!r}; "
                f"expected one of {', '.join(self._SUPPORTED_METRICS)}"
            )

        # Auto-detect layer range if not provided
        layer_range = self.config.layer_range
        if layer_range is None:
            total_layers = detect_model_layers(model)
            layer_range = get_default_layer_range(total_layers, use_all=True)
            if verbose:
                print(f" • Auto-detected {total_layers} model layers")
                print(f" • Using all layers for optimization: {layer_range[0]}-{layer_range[-1]}")

        # The progress output below indexes the first and last layer
        if len(layer_range) == 0:
            raise ValueError("layer_range is empty; there are no layers to optimize over")

        if verbose:
            print(f"\n🔍 Starting hyperparameter optimization...")
            print(f" • Layers to test: {len(layer_range)} (range: {layer_range[0]}-{layer_range[-1]})")
            print(f" • Aggregation methods: {len(self.config.aggregation_methods)}")
            print(f" • Thresholds: {len(self.config.threshold_range)}")
            print(f" • Classifier types: {len(self.config.classifier_types)}")
            print(f" • Optimization metric: {self.config.metric}")

        # Generate all combinations of hyperparameters
        combinations = list(itertools.product(
            layer_range,
            self.config.aggregation_methods,
            self.config.threshold_range,
            self.config.classifier_types
        ))

        # Limit combinations if too many
        if len(combinations) > self.config.max_combinations:
            if verbose:
                print(f" • Too many combinations ({len(combinations)}), sampling {self.config.max_combinations}")
            combinations = self._sample_combinations(combinations, self.config.max_combinations)

        if verbose:
            print(f" • Testing {len(combinations)} combinations...")

        best_score = -np.inf
        best_result = None
        all_results = []

        for i, (layer, aggregation, threshold, classifier_type) in enumerate(combinations):
            try:
                if verbose and (i + 1) % 10 == 0:
                    print(f" • Progress: {i + 1}/{len(combinations)} combinations tested")

                # Train and evaluate this combination
                result = self._evaluate_combination(
                    model=model,
                    train_pair_set=train_pair_set,
                    test_pair_set=test_pair_set,
                    layer=layer,
                    aggregation=aggregation,
                    threshold=threshold,
                    classifier_type=classifier_type,
                    device=device
                )

                all_results.append(result)

                # Check if this is the best so far (ties keep the earlier combination)
                score = result[metric]
                if score > best_score:
                    best_score = score
                    best_result = result

                    if verbose:
                        print(f" • New best: layer={layer}, agg={aggregation}, thresh={threshold:.2f}, {self.config.metric}={score:.3f}")

            except Exception as e:
                # A failing combination is skipped so the rest of the grid still runs
                logger.warning(f"Failed to evaluate combination (layer={layer}, agg={aggregation}, thresh={threshold}, type={classifier_type}): {e}")
                continue

        if best_result is None:
            raise ValueError("No valid combinations found during optimization")

        # Create optimization result
        optimization_result = OptimizationResult(
            best_layer=best_result['layer'],
            best_aggregation=best_result['aggregation'],
            best_threshold=best_result['threshold'],
            best_classifier_type=best_result['classifier_type'],
            best_score=best_result[metric],
            best_metrics={
                'accuracy': best_result['accuracy'],
                'f1': best_result['f1'],
                'precision': best_result['precision'],
                'recall': best_result['recall'],
                'auc': best_result.get('auc', 0.0)
            },
            all_results=all_results,
            config=self.config
        )

        if verbose:
            print(f"\n✅ Optimization complete!")
            print(f" • Best layer: {optimization_result.best_layer}")
            print(f" • Best aggregation: {optimization_result.best_aggregation}")
            print(f" • Best threshold: {optimization_result.best_threshold:.2f}")
            print(f" • Best classifier: {optimization_result.best_classifier_type}")
            print(f" • Best {self.config.metric}: {optimization_result.best_score:.3f}")
            print(f" • Tested {len(all_results)} valid combinations")

        return optimization_result

    def _evaluate_combination(
        self,
        model,
        train_pair_set: "ContrastivePairSet",
        test_pair_set: "ContrastivePairSet",
        layer: int,
        aggregation: str,
        threshold: float,
        classifier_type: str,
        device: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Evaluate a single hyperparameter combination.

        Args:
            model: The model to use
            train_pair_set: Training data
            test_pair_set: Test data
            layer: Layer index to use
            aggregation: Token aggregation method
            threshold: Classification threshold
            classifier_type: Type of classifier
            device: Device to run on

        Returns:
            Dictionary with the combination's hyperparameters, its accuracy,
            f1, precision, recall and auc, plus the raw training and test
            results returned by the steering method.

        Raises:
            ValueError: If no test pair carried activations on both responses.
        """
        # NOTE(review): `model`, `layer` and `aggregation` are only recorded in
        # the returned dict; nothing in this method passes them to the
        # classifier, so every layer/aggregation pair trains on the same
        # activations already attached to the pair sets - confirm intended.

        # Train classifier with this combination ("logistic" -> LOGISTIC, anything else -> MLP)
        steering_type = SteeringType.LOGISTIC if classifier_type == "logistic" else SteeringType.MLP
        steering_method = SteeringMethod(method_type=steering_type, device=device)

        # Extract activations for training (this should be done by the activation collector)
        # For now, assume the pair set already has activations
        training_results = steering_method.train(train_pair_set)

        # Evaluate on test set
        test_results = steering_method.evaluate(test_pair_set)

        # Score every usable pair once; the probabilities feed both the
        # thresholded predictions and the AUC computation below.
        predictions = []
        true_labels = []
        prob_scores = []

        for pair in test_pair_set.pairs:
            if not (hasattr(pair.positive_response, 'activations') and hasattr(pair.negative_response, 'activations')):
                continue

            # Get classifier features for both responses
            pos_features = pair.positive_response.activations.extract_features_for_classifier()
            neg_features = pair.negative_response.activations.extract_features_for_classifier()

            # Predict probabilities
            # NOTE(review): assumes predict_proba(...)[0] is a scalar score for
            # the "harmful" class - verify against the classifier implementation
            pos_prob = steering_method.classifier.predict_proba([pos_features.numpy()])[0]
            neg_prob = steering_method.classifier.predict_proba([neg_features.numpy()])[0]

            # Apply threshold
            pos_pred = 1 if pos_prob > threshold else 0
            neg_pred = 1 if neg_prob > threshold else 0

            # Positive response should be classified as 0 (harmless)
            # Negative response should be classified as 1 (harmful)
            predictions.extend([pos_pred, neg_pred])
            true_labels.extend([0, 1])
            prob_scores.extend([pos_prob, neg_prob])

        if len(predictions) == 0:
            raise ValueError("No valid predictions generated")

        # Calculate metrics
        accuracy = accuracy_score(true_labels, predictions)
        f1 = f1_score(true_labels, predictions, zero_division=0)
        precision = precision_score(true_labels, predictions, zero_division=0)
        recall = recall_score(true_labels, predictions, zero_division=0)

        # Calculate AUC if possible; it is best-effort and falls back to 0.0
        try:
            auc = roc_auc_score(true_labels, prob_scores) if len(set(true_labels)) > 1 else 0.0
        except Exception as e:
            logger.debug(f"AUC could not be computed, using 0.0: {e}")
            auc = 0.0

        return {
            'layer': layer,
            'aggregation': aggregation,
            'threshold': threshold,
            'classifier_type': classifier_type,
            'accuracy': accuracy,
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'auc': auc,
            'training_results': training_results,
            'test_results': test_results
        }

    @staticmethod
    def from_config_dict(config_dict: Dict[str, Any]) -> 'HyperparameterOptimizer':
        """Create optimizer from configuration dictionary.

        The dictionary keys are passed as keyword arguments to
        ``OptimizationConfig``.
        """
        config = OptimizationConfig(**config_dict)
        return HyperparameterOptimizer(config)

    def save_results(self, result: "OptimizationResult", filepath: str):
        """Save optimization results to file.

        Writes the best hyperparameters, scores, a subset of this optimizer's
        configuration and all per-combination results as indented JSON.

        Args:
            result: The optimization result to serialize.
            filepath: Destination path; an existing file is overwritten.
        """
        import json

        # Convert result to serializable format
        # NOTE(review): all_results embeds the raw training/test results from
        # the steering method; json.dump raises TypeError if those are not
        # JSON-serializable - confirm what SteeringMethod returns.
        result_dict = {
            'best_hyperparameters': {
                'layer': result.best_layer,
                'aggregation': result.best_aggregation,
                'threshold': result.best_threshold,
                'classifier_type': result.best_classifier_type
            },
            'best_score': result.best_score,
            'best_metrics': result.best_metrics,
            'optimization_config': {
                'layer_range': self.config.layer_range,
                'aggregation_methods': self.config.aggregation_methods,
                'threshold_range': self.config.threshold_range,
                'classifier_types': self.config.classifier_types,
                'metric': self.config.metric,
                'max_combinations': self.config.max_combinations
            },
            'all_results': result.all_results
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(result_dict, f, indent=2)

        logger.info(f"Optimization results saved to {filepath}")
|