wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,754 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Synthetic Classifier Option System
|
|
3
|
+
|
|
4
|
+
Creates custom classifiers from automatically discovered traits using synthetic contrastive pairs.
|
|
5
|
+
The model analyzes prompts to determine relevant traits for responses, then creates classifiers for those traits.
|
|
6
|
+
The actual response is NEVER analyzed as text - only its activations are classified.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
import time
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import List, Tuple
|
|
13
|
+
|
|
14
|
+
from wisent.core.classifier.classifier import ActivationClassifier
|
|
15
|
+
|
|
16
|
+
from ....core.agent.budget import ResourceType, calculate_max_tasks_for_time_budget, get_budget_manager
|
|
17
|
+
from ....core.contrastive_pairs.generate_synthetically import SyntheticContrastivePairGenerator
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class TraitDiscoveryResult:
    """Container for the outcome of one automatic trait-discovery run."""

    # Natural-language trait descriptions produced by the discovery step.
    traits_discovered: List[str]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class SyntheticClassifierResult:
    """Outcome of building one synthetic classifier and applying it to a response."""

    # Trait the classifier was built for.
    trait_description: str
    # NOTE(review): presumably a quality score for the classifier itself - confirm against the producer.
    classifier_confidence: float
    # Label predicted for the response.
    prediction: int
    # Confidence attached to `prediction`.
    confidence_score: float
    # Number of contrastive pairs behind the classifier.
    training_pairs_count: int
    # NOTE(review): presumably elapsed seconds - confirm against the producer.
    generation_time: float
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class AutomaticTraitDiscovery:
    """Automatically discovers relevant traits for prompt response analysis.

    The model is asked to complete a numbered ``TRAIT_n:`` list for a prompt;
    the completed lines are parsed back into plain trait descriptions.
    """

    # Traits returned when generation or parsing raises.
    FALLBACK_TRAITS = ("accuracy and truthfulness", "helpfulness", "safety")
    # Upper bound on the number of traits requested, whatever the budget allows.
    MAX_TRAITS = 5

    def __init__(self, model):
        """Store the model used to generate the trait analysis."""
        self.model = model

    def discover_relevant_traits(self, prompt: str, time_budget_minutes: float) -> TraitDiscoveryResult:
        """
        Analyze a prompt to automatically discover relevant quality traits for responses.

        Args:
            prompt: The prompt/question to analyze for trait discovery
            time_budget_minutes: Time budget for classifier creation in minutes

        Returns:
            TraitDiscoveryResult with at most ``max_traits`` traits, where
            ``max_traits`` is derived from the time budget and clamped to
            ``1..MAX_TRAITS``. If generation or parsing raises, the fallback
            traits are returned (also truncated to ``max_traits``). The list may
            be empty when the model output contains no parseable ``TRAIT_`` line.
        """
        # Calculate max traits based on time budget, clamped to 1..MAX_TRAITS
        max_traits = calculate_max_tasks_for_time_budget("classifier_training", time_budget_minutes)
        max_traits = max(1, min(max_traits, self.MAX_TRAITS))
        logging.info(f"Budget system: {time_budget_minutes:.1f} min budget → max {max_traits} traits")

        # One empty "TRAIT_n:" slot per requested trait for the model to fill in
        trait_lines = "\n".join(f"TRAIT_{i + 1}:" for i in range(max_traits))
        discovery_prompt = (
            f"USER PROMPT: {prompt}\n"
            "\n"
            f"List {max_traits} quality traits for responses:\n"
            f"{trait_lines}"
        )

        try:
            analysis, _ = self.model.generate(
                discovery_prompt, layer_index=15, max_new_tokens=200, temperature=0.7, do_sample=True
            )
            logging.info(f"Model generated analysis: {analysis[:200]}...")
            traits = self._parse_discovery_result(analysis).traits_discovered
        except Exception as e:
            # Best-effort: discovery must not abort the caller, so fall back to
            # general traits, but keep the failure visible with its traceback.
            logging.warning(f"Error in trait discovery: {e}", exc_info=True)
            traits = list(self.FALLBACK_TRAITS)

        # Never hand back more traits than the budget allows, whether they come
        # from an over-eager model answer or from the fallback list.
        return TraitDiscoveryResult(traits_discovered=traits[:max_traits])

    def _parse_discovery_result(self, analysis: str) -> TraitDiscoveryResult:
        """Parse the model's trait discovery response."""
        return TraitDiscoveryResult(traits_discovered=self._extract_traits(analysis))

    @staticmethod
    def _extract_traits(analysis: str) -> List[str]:
        """Return the descriptions found on ``TRAIT_<n>: <description>`` lines.

        Lines that do not start with ``TRAIT_``, have no colon, or whose
        description is three characters or shorter are ignored.
        """
        traits = []
        for raw_line in analysis.split("\n"):
            line = raw_line.strip()
            if not line.startswith("TRAIT_"):
                continue
            # Split on the first colon only, so descriptions may contain colons
            _, separator, description = line.partition(":")
            description = description.strip()
            if separator and len(description) > 3:
                traits.append(description)
        return traits
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class SyntheticClassifierFactory:
    """Creates custom classifiers from trait descriptions using synthetic contrastive pairs."""

    # Layer whose activations are used when the caller does not pick one.
    DEFAULT_LAYER_INDEX = 15
    # Fewer generated pairs than this is treated as a failed generation.
    MIN_GENERATED_PAIRS = 3
    # Fewer pairs with both activations extracted than this cannot be trained on.
    MIN_TRAINING_PAIRS = 2

    def __init__(self, model):
        """Store the model and build the synthetic pair generator around it."""
        self.model = model
        self.pair_generator = SyntheticContrastivePairGenerator(model)

    def create_classifier_from_trait(
        self, trait_description: str, num_pairs: int = 15, layer_index: int = DEFAULT_LAYER_INDEX
    ) -> Tuple[ActivationClassifier, int]:
        """
        Create a classifier for a specific trait using synthetic contrastive pairs.

        Args:
            trait_description: Natural language description of the trait
            num_pairs: Number of contrastive pairs to generate
            layer_index: Index of the layer activations are extracted from

        Returns:
            Tuple of (trained classifier, number of pairs actually used for
            training). Pairs for which either activation could not be extracted
            are dropped and not counted.

        Raises:
            ValueError: If fewer than ``MIN_GENERATED_PAIRS`` pairs are generated
                or fewer than ``MIN_TRAINING_PAIRS`` pairs yield activations.
            Exception: Any error from pair generation or training is logged with
                its traceback and re-raised unchanged.
        """
        try:
            # Generate synthetic contrastive pairs for this trait
            pair_set = self.pair_generator.generate_contrastive_pair_set(
                trait_description=trait_description,
                num_pairs=num_pairs,
                name=f"synthetic_{trait_description[:20].replace(' ', '_')}",
            )
            pairs = pair_set.pairs
            if len(pairs) < self.MIN_GENERATED_PAIRS:
                raise ValueError(f"Insufficient training pairs generated: {len(pairs)}")

            # Imported at call time, as the original implementation did.
            from wisent.core.activations import Activations
            from wisent.core.layer import Layer

            # A single Layer object serves both extraction and wrapping
            layer_obj = Layer(index=layer_index, type="transformer")
            logging.info(f"Created Layer object: index={layer_obj.index}, type={layer_obj.type}")

            logging.info(f"Extracting activations from {len(pairs)} pairs...")
            positive_activations, negative_activations = self._extract_pair_activations(pairs, layer_obj)
            usable_pairs = len(positive_activations)

            logging.info("ACTIVATION EXTRACTION SUMMARY:")
            logging.info(f"Positive activations collected: {len(positive_activations)}")
            logging.info(f"Negative activations collected: {len(negative_activations)}")
            logging.info(f"Total pairs processed: {len(pairs)}")
            logging.info(f"Success rate: {(usable_pairs / len(pairs) * 100):.1f}%")

            if usable_pairs < self.MIN_TRAINING_PAIRS:
                raise ValueError(
                    f"Insufficient activation data for training: {len(positive_activations)} positive, "
                    f"{len(negative_activations)} negative"
                )

            logging.info(
                f"Training classifier on {len(positive_activations)} positive, "
                f"{len(negative_activations)} negative activations..."
            )

            # Positive examples go in the first ("harmful") argument of
            # train_on_activations, negative ones in the second ("harmless").
            harmful_activations = [
                self._as_activations(act, layer_obj, Activations) for act in positive_activations
            ]
            harmless_activations = [
                self._as_activations(act, layer_obj, Activations) for act in negative_activations
            ]
            logging.info(
                f"Converted to Activations objects: {len(harmful_activations)} harmful, "
                f"{len(harmless_activations)} harmless"
            )

            classifier = ActivationClassifier()
            training_result = classifier.train_on_activations(harmful_activations, harmless_activations)
            logging.info(f"Classifier training completed successfully! Result: {training_result}")

            return classifier, usable_pairs

        except Exception as e:
            # Single logging point for every failure; the traceback is attached.
            logging.exception(f"Error creating classifier for trait '{trait_description}': {e}")
            raise

    def _extract_pair_activations(self, pairs, layer_obj):
        """Extract activations for each pair, keeping only pairs where both sides succeed.

        Returns two lists of equal length: entry ``k`` of each list comes from
        the same contrastive pair.
        """
        positive_activations = []
        negative_activations = []
        total = len(pairs)

        for i, pair in enumerate(pairs):
            logging.debug(f"Processing pair {i + 1}/{total}...")
            try:
                logging.debug(f"Extracting positive activations for: {pair.positive_response.text[:100]!r}")
                pos_activations = self.model.extract_activations(pair.positive_response.text, layer_obj)
                logging.debug(f"Positive activations shape: {getattr(pos_activations, 'shape', 'N/A')}")

                logging.debug(f"Extracting negative activations for: {pair.negative_response.text[:100]!r}")
                neg_activations = self.model.extract_activations(pair.negative_response.text, layer_obj)
                logging.debug(f"Negative activations shape: {getattr(neg_activations, 'shape', 'N/A')}")
            except Exception as e:
                # Best-effort per pair: skip it, but leave a visible trace.
                logging.warning(f"Error extracting activations for pair {i + 1}: {e}", exc_info=True)
                continue

            # Append only once both extractions succeeded, so the lists stay paired.
            positive_activations.append(pos_activations)
            negative_activations.append(neg_activations)
            logging.debug(f"Successfully processed pair {i + 1}")

        return positive_activations, negative_activations

    @staticmethod
    def _as_activations(activation, layer_obj, activations_cls):
        """Wrap objects exposing ``shape`` with ``activations_cls``; return anything else unchanged."""
        if hasattr(activation, "shape"):  # treated as a raw tensor
            return activations_cls(activation, layer_obj)
        return activation
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class SyntheticClassifierSystem:
|
|
245
|
+
"""
|
|
246
|
+
Creates synthetic classifiers based on prompt analysis and applies them to response activations.
|
|
247
|
+
|
|
248
|
+
Analyzes prompts to discover relevant traits, creates classifiers using synthetic
|
|
249
|
+
contrastive pairs, and applies them to response activations only.
|
|
250
|
+
"""
|
|
251
|
+
|
|
252
|
+
def __init__(self, model):
    """Keep the model and build the two collaborators that share it."""
    self.model = model
    # Trait discovery is constructed first, then the classifier factory.
    discovery = AutomaticTraitDiscovery(model)
    self.trait_discovery = discovery
    factory = SyntheticClassifierFactory(model)
    self.classifier_factory = factory
|
|
256
|
+
|
|
257
|
+
def create_classifiers_for_prompt(
    self, prompt: str, time_budget_minutes: float, pairs_per_trait: int = 12
) -> Tuple[List[ActivationClassifier], TraitDiscoveryResult]:
    """
    Create synthetic classifiers for a prompt by discovering relevant traits.
    Uses budget-aware planning to make intelligent decisions about what operations to perform.

    Args:
        prompt: The prompt to analyze and create classifiers for
        time_budget_minutes: Time budget for classifier creation in minutes
        pairs_per_trait: Number of contrastive pairs per trait

    Returns:
        Tuple of (list of trained classifiers, trait discovery result).
        The classifier list is empty when the budget cannot cover even one
        minimal classifier or when no traits are discovered. Traits whose
        classifier fails or no longer fits the remaining budget are skipped
        but still listed in the returned discovery result.
    """
    # Smallest pair count worth attempting; the budget check and the per-trait
    # reduction below both use it.
    min_pairs = 3

    logging.info(f"Creating synthetic classifiers for prompt (budget: {time_budget_minutes:.1f} minutes)...")

    # Cost estimates, in seconds. Trait discovery is a fixed estimate
    # (simple text generation ~10s) in both the benchmark and fallback paths.
    trait_discovery_cost = 10.0
    try:
        from ..budget import estimate_task_time_direct

        data_generation_cost = estimate_task_time_direct("data_generation", 1)  # Per pair
        classifier_training_cost = (
            estimate_task_time_direct("classifier_training", 100) / 100
        )  # Per classifier (benchmark is per 100)

        logging.info("Cost estimates per unit:")
        logging.info(f"• Trait discovery: ~{trait_discovery_cost:.0f}s")
        logging.info(f"• Data generation: ~{data_generation_cost:.0f}s per pair")
        logging.info(f"• Classifier training: ~{classifier_training_cost:.0f}s per classifier")

    except Exception as e:
        # Deliberate best-effort: planning continues on fixed estimates when
        # benchmark data is unavailable.
        logging.info(f"Could not get benchmark data: {e}")
        logging.info("Using fallback estimates")
        data_generation_cost = 30.0  # Per pair
        classifier_training_cost = 180.0  # Per classifier (3 minutes)

    budget_seconds = time_budget_minutes * 60.0

    # Step 1: Budget-aware trait discovery
    logging.info("Discovering relevant traits for this prompt...")

    # Estimate if we have enough budget for even basic operations
    min_required_time = trait_discovery_cost + (data_generation_cost * min_pairs) + classifier_training_cost

    if budget_seconds < min_required_time:
        logging.info(f"Budget ({budget_seconds:.0f}s) too small for full classifier training")
        logging.info(f"Minimum required: {min_required_time:.0f}s")
        logging.info("Falling back to simple trait analysis only...")

        # Just do trait discovery without training classifiers
        discovery_result = self.trait_discovery.discover_relevant_traits(prompt, time_budget_minutes)
        logging.info(
            f"Discovered {len(discovery_result.traits_discovered)} traits: {discovery_result.traits_discovered}"
        )
        logging.info("Skipping classifier training due to budget constraints")
        return [], discovery_result

    # Calculate how many traits we can afford
    cost_per_trait = (data_generation_cost * pairs_per_trait) + classifier_training_cost
    available_for_traits = budget_seconds - trait_discovery_cost
    if cost_per_trait > 0:
        max_affordable_traits = max(1, int(available_for_traits / cost_per_trait))
    else:
        # Zero estimated cost: the budget places no limit on the trait count
        # (None as a slice bound keeps every discovered trait).
        max_affordable_traits = None

    logging.info("Budget analysis:")
    logging.info(f"• Available time: {budget_seconds:.0f}s")
    logging.info(f"• Cost per trait ({pairs_per_trait} pairs): {cost_per_trait:.0f}s")
    logging.info(f"• Max affordable traits: {max_affordable_traits}")

    discovery_result = self.trait_discovery.discover_relevant_traits(prompt, time_budget_minutes)

    if not discovery_result.traits_discovered:
        logging.info("No traits discovered, cannot create classifiers")
        return [], discovery_result

    # Limit traits to what we can afford
    affordable_traits = discovery_result.traits_discovered[:max_affordable_traits]
    if len(affordable_traits) < len(discovery_result.traits_discovered):
        logging.info(
            f"Budget limiting to {len(affordable_traits)}/{len(discovery_result.traits_discovered)} traits"
        )

    logging.info(f"Processing {len(affordable_traits)} traits: {affordable_traits}")

    # Step 2: Create classifiers for affordable traits with smart resource allocation
    classifiers = []
    remaining_budget = budget_seconds - trait_discovery_cost

    for i, trait_description in enumerate(affordable_traits):
        logging.info(f"Creating classifier {i + 1}/{len(affordable_traits)}: {trait_description}")
        logging.info(f"Remaining budget: {remaining_budget:.0f}s")

        # Estimate cost for this specific classifier
        total_estimated_cost = (data_generation_cost * pairs_per_trait) + classifier_training_cost

        if total_estimated_cost > remaining_budget:
            # Try with fewer pairs. The count is NOT clamped up to min_pairs:
            # when even min_pairs do not fit, the trait must be skipped.
            budget_for_pairs = remaining_budget - classifier_training_cost
            if data_generation_cost > 0 and budget_for_pairs > 0:
                max_affordable_pairs = int(budget_for_pairs / data_generation_cost)
            else:
                max_affordable_pairs = 0

            if max_affordable_pairs < min_pairs:
                logging.info(f"Insufficient budget ({remaining_budget:.0f}s) for training, skipping")
                continue
            logging.info(f"Reducing pairs from {pairs_per_trait} to {max_affordable_pairs} to fit budget")
            actual_pairs = max_affordable_pairs
        else:
            actual_pairs = pairs_per_trait

        logging.info(f"Creating classifier with {actual_pairs} pairs...")
        started = time.perf_counter()
        classifier = None
        try:
            classifier, pairs_count = self.classifier_factory.create_classifier_from_trait(
                trait_description=trait_description, num_pairs=actual_pairs
            )
        except Exception as e:
            # One failed trait must not prevent the remaining ones from being tried
            logging.warning(f"Error creating classifier for trait '{trait_description}': {e}")

        # Charge the time actually spent, whether or not the attempt succeeded
        actual_time = time.perf_counter() - started
        remaining_budget -= actual_time

        if classifier is None:
            continue

        # Store trait info in classifier for later reference
        classifier._trait_description = trait_description
        classifier._pairs_count = pairs_count
        classifiers.append(classifier)

        logging.info(f"Classifier created with {pairs_count} training pairs ({actual_time:.0f}s)")

    logging.info(f"Created {len(classifiers)} synthetic classifiers within budget")

    # Update discovery result to reflect what we actually processed
    final_discovery = TraitDiscoveryResult(traits_discovered=affordable_traits)
    return classifiers, final_discovery
|
|
398
|
+
|
|
399
|
+
def apply_classifiers_to_response(
    self, response_text: str, classifiers: List[ActivationClassifier], trait_discovery: TraitDiscoveryResult
) -> List[SyntheticClassifierResult]:
    """
    Apply pre-trained synthetic classifiers to a response.

    Activations are extracted from the response once and then passed to
    every classifier. A classifier that raises is logged and skipped, so
    the returned list may be shorter than ``classifiers``.

    Args:
        response_text: The response to analyze (only used for activation extraction)
        classifiers: List of trained classifiers to apply
        trait_discovery: Original trait discovery result for context
            (not read by this method)

    Returns:
        List of classification results, one per classifier that ran
        successfully; an empty list if activation extraction fails.
    """

    def _confidence_to_float(confidence) -> float:
        """Collapse a scalar or a sequence of probabilities into one float."""
        try:
            size = len(confidence)
        except TypeError:
            # Plain numbers and 0-d arrays/tensors have no length.
            return float(confidence)
        if size > 1:
            return float(max(confidence))
        if size == 1:
            # Unwrap single-element containers; float([x]) would raise.
            return _confidence_to_float(confidence[0])
        raise ValueError("Classifier returned an empty confidence value")

    logging.info(f"Applying {len(classifiers)} synthetic classifiers to response...")

    # Extract activations from the response ONCE
    logging.info("Extracting activations from response...")
    try:
        # NOTE(review): create_classifier_from_trait_description in this file
        # calls model.extract_activations(text, layer_obj) with a Layer object
        # and uses the return value directly, while this call passes an int
        # keyword and unpacks a 2-tuple - confirm which matches the model API.
        response_activations, _ = self.model.extract_activations(response_text, layer=15)
    except Exception as e:
        # Logged as an error so the failure is visible under the default
        # WARNING threshold; the caller still receives an empty result list.
        logging.error(f"Error extracting response activations: {e}")
        return []

    results = []

    for i, classifier in enumerate(classifiers):
        # Metadata is attached to the classifier at creation time; fall back
        # to placeholders when it is missing.
        trait_description = getattr(classifier, "_trait_description", f"trait_{i}")
        pairs_count = getattr(classifier, "_pairs_count", 0)

        logging.info(f"Applying classifier {i + 1}/{len(classifiers)}: {trait_description}")

        try:
            start_time = time.time()

            # Apply classifier to response activations
            prediction = classifier.predict(response_activations)
            confidence = classifier.predict_proba(response_activations)

            # Handle confidence score (could be array or scalar)
            confidence_score = _confidence_to_float(confidence)

            generation_time = time.time() - start_time

            result = SyntheticClassifierResult(
                trait_description=trait_description,
                classifier_confidence=confidence_score,
                prediction=int(prediction),
                confidence_score=confidence_score,
                training_pairs_count=pairs_count,
                generation_time=generation_time,
            )

            results.append(result)

            logging.info(f"Result: prediction={prediction}, confidence={confidence_score:.3f}")

        except Exception as e:
            # Best effort: one failing classifier must not abort the others.
            logging.error(f"Error applying classifier for trait '{trait_description}': {e}")
            continue

    logging.info(f"Applied {len(results)} classifiers successfully")
    return results
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def get_time_budget_from_manager() -> float:
    """Return the remaining time budget held by the global budget manager.

    The manager's ``remaining_budget`` for ``ResourceType.TIME`` is divided
    by 60 before being returned (the conversion to minutes).

    Raises:
        ValueError: If the budget manager has no (truthy) time budget.
    """
    time_budget = get_budget_manager().get_budget(ResourceType.TIME)
    if time_budget:
        return time_budget.remaining_budget / 60.0  # Convert to minutes
    raise ValueError("No time budget set in budget manager. Call set_time_budget(minutes) first.")
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
# Main interface functions
|
|
477
|
+
def create_synthetic_classifier_system(model) -> SyntheticClassifierSystem:
    """Build and return a new ``SyntheticClassifierSystem`` for ``model``."""
    system = SyntheticClassifierSystem(model)
    return system
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def create_classifiers_for_prompt(
    model, prompt: str, pairs_per_trait: int = 12
) -> Tuple[List[ActivationClassifier], TraitDiscoveryResult]:
    """
    Create synthetic classifiers for a prompt in a single call.

    The time budget is read from the global budget manager first, then a
    fresh classifier system is built for ``model`` and asked to create the
    classifiers.

    Args:
        model: The language model instance
        prompt: Prompt to analyze and create classifiers for
        pairs_per_trait: Number of contrastive pairs per trait

    Returns:
        Tuple of (trained classifiers, trait discovery result)
    """
    # Read the budget before building the system so a missing budget fails early.
    budget_minutes = get_time_budget_from_manager()
    return create_synthetic_classifier_system(model).create_classifiers_for_prompt(
        prompt, budget_minutes, pairs_per_trait
    )
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def apply_classifiers_to_response(
    model, response_text: str, classifiers: List[ActivationClassifier], trait_discovery: TraitDiscoveryResult
) -> List[SyntheticClassifierResult]:
    """
    Apply already-trained classifiers to a response in a single call.

    Builds a fresh classifier system for ``model`` and delegates to its
    ``apply_classifiers_to_response`` method.

    Args:
        model: The language model instance
        response_text: Response to analyze
        classifiers: Pre-trained classifiers
        trait_discovery: Original trait discovery result

    Returns:
        List of classification results
    """
    return create_synthetic_classifier_system(model).apply_classifiers_to_response(
        response_text, classifiers, trait_discovery
    )
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
def create_classifier_from_trait_description(
    model, trait_description: str, num_pairs: int = 15
) -> ActivationClassifier:
    """
    Direct function to create a classifier from a trait description.

    Generates synthetic contrastive pairs for the trait, extracts
    activations (Layer index 15) for every positive and negative response,
    and trains an ``ActivationClassifier`` on them. Every step is printed
    and appended to a timestamped ``synthetic_classifier_debug_*.log`` file
    opened by relative path (i.e. in the current working directory).

    Args:
        model: The language model instance
        trait_description: Natural language description of the trait (e.g., "accuracy and truthfulness")
        num_pairs: Number of contrastive pairs to generate for training

    Returns:
        Trained ActivationClassifier

    Raises:
        ValueError: If fewer than 3 pairs are generated, or if fewer than 2
            positive or fewer than 2 negative activations were extracted.
        Exception: Any error raised during classifier training is logged
            with its traceback and re-raised unchanged.
    """
    import datetime
    import traceback

    # NOTE(review): several log prefixes in this function appear as
    # mis-decoded symbols whose original glyph cannot be recovered from the
    # text; they are left untouched. Only the literals that the garbling had
    # split across two lines were rejoined, with the check mark restored.

    # Setup logging to file
    log_file = f"synthetic_classifier_debug_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

    def log_and_print(message):
        print(message)
        # Messages contain non-ASCII symbols: pin the encoding so the write
        # does not depend on the platform's locale default.
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(f"{datetime.datetime.now().isoformat()}: {message}\n")

    log_and_print(f"šÆ Creating classifier for trait: '{trait_description}'")
    log_and_print(f"š Parameters: num_pairs={num_pairs}")

    # Create synthetic contrastive pair generator
    log_and_print("š Creating SyntheticContrastivePairGenerator...")
    pair_generator = SyntheticContrastivePairGenerator(model)
    log_and_print("✅ SyntheticContrastivePairGenerator created successfully")

    # Generate contrastive pairs for this trait
    log_and_print(f"š Generating {num_pairs} contrastive pairs...")
    pair_set = pair_generator.generate_contrastive_pair_set(
        trait_description=trait_description,
        num_pairs=num_pairs,
        name=f"synthetic_{trait_description[:20].replace(' ', '_')}",
    )

    pairs = pair_set.pairs
    total_pairs = len(pairs)
    log_and_print(f"✅ Generated {total_pairs} pairs total")

    # Log all generated pairs in detail
    log_and_print("=" * 80)
    log_and_print("DETAILED PAIR ANALYSIS:")
    log_and_print("=" * 80)

    for i, pair in enumerate(pairs):
        log_and_print(f"\n--- PAIR {i + 1}/{total_pairs} ---")
        log_and_print(f"Prompt: {pair.prompt!r}")
        log_and_print(f"Positive Response: {pair.positive_response.text!r}")
        log_and_print(f"Negative Response: {pair.negative_response.text!r}")
        log_and_print(f"Positive Response Type: {type(pair.positive_response)}")
        log_and_print(f"Negative Response Type: {type(pair.negative_response)}")
        # ``.text`` was already read above, so no hasattr guard is needed here.
        log_and_print(f"Positive Response Length: {len(pair.positive_response.text)}")
        log_and_print(f"Negative Response Length: {len(pair.negative_response.text)}")

        # Check for any special attributes
        if hasattr(pair, "_prompt_pair"):
            log_and_print(f"Has _prompt_pair: {pair._prompt_pair}")
        if hasattr(pair, "_prompt_strategy"):
            log_and_print(f"Has _prompt_strategy: {pair._prompt_strategy}")

    log_and_print("=" * 80)

    if total_pairs < 3:
        error_msg = f"Insufficient training pairs generated: {total_pairs}"
        log_and_print(f"ā ERROR: {error_msg}")
        raise ValueError(error_msg)

    # Extract activations for training
    positive_activations = []
    negative_activations = []

    log_and_print(f"š§ Extracting activations from {total_pairs} pairs...")

    # Create Layer object for activation extraction
    from wisent.core.layer import Layer

    layer_obj = Layer(index=15, type="transformer")
    log_and_print(f"š§ Created Layer object: index={layer_obj.index}, type={layer_obj.type}")

    for i, pair in enumerate(pairs):
        log_and_print(f"\nš Processing pair {i + 1}/{total_pairs}...")
        try:
            # Get activations for positive response
            log_and_print(f" š Extracting positive activations for: {pair.positive_response.text[:100]!r}")
            pos_activations = model.extract_activations(pair.positive_response.text, layer_obj)
            log_and_print(
                f" ✅ Positive activations shape: {pos_activations.shape if hasattr(pos_activations, 'shape') else 'N/A'}"
            )
            positive_activations.append(pos_activations)

            # Get activations for negative response
            log_and_print(f" š Extracting negative activations for: {pair.negative_response.text[:100]!r}")
            neg_activations = model.extract_activations(pair.negative_response.text, layer_obj)
            log_and_print(
                f" ✅ Negative activations shape: {neg_activations.shape if hasattr(neg_activations, 'shape') else 'N/A'}"
            )
            negative_activations.append(neg_activations)

            log_and_print(f" ✅ Successfully processed pair {i + 1}")

        except Exception as e:
            # Best effort: a failing pair is reported and skipped. A positive
            # activation appended before the failure is kept, so the two
            # lists may differ in length.
            log_and_print(f" ā ļø Error extracting activations for pair {i + 1}: {e}")
            error_details = traceback.format_exc()
            log_and_print(f" š Full error traceback:\n{error_details}")
            continue

    log_and_print("\nš ACTIVATION EXTRACTION SUMMARY:")
    log_and_print(f" Positive activations collected: {len(positive_activations)}")
    log_and_print(f" Negative activations collected: {len(negative_activations)}")
    log_and_print(f" Total pairs processed: {total_pairs}")
    # total_pairs is at least 3 here, so the division cannot hit zero.
    log_and_print(f" Success rate: {(len(positive_activations) / total_pairs * 100):.1f}%")

    if len(positive_activations) < 2 or len(negative_activations) < 2:
        error_msg = f"Insufficient activation data for training: {len(positive_activations)} positive, {len(negative_activations)} negative"
        log_and_print(f"ā ERROR: {error_msg}")
        raise ValueError(error_msg)

    # Train classifier on activations
    log_and_print(
        f"šļø Training classifier on {len(positive_activations)} positive, {len(negative_activations)} negative activations..."
    )

    log_and_print("š§ Creating ActivationClassifier instance...")
    classifier = ActivationClassifier()
    log_and_print("✅ ActivationClassifier created")

    log_and_print("šÆ Starting classifier training...")
    try:
        # Convert activations to the format expected by train_on_activations method
        from wisent.core.activations import Activations

        def _as_activation_objects(raw_items):
            # Anything exposing ``.shape`` is treated as a raw tensor and
            # wrapped; other items are passed through unchanged.
            return [
                Activations(item, layer_obj) if hasattr(item, "shape") else item
                for item in raw_items
            ]

        # Positive responses are passed as the first ("harmful") argument,
        # negative ones as the second ("harmless").
        harmful_activations = _as_activation_objects(positive_activations)
        harmless_activations = _as_activation_objects(negative_activations)

        log_and_print(
            f"š§ Converted to Activations objects: {len(harmful_activations)} harmful, {len(harmless_activations)} harmless"
        )

        # Train using the correct method
        training_result = classifier.train_on_activations(harmful_activations, harmless_activations)
        log_and_print(f"✅ Classifier training completed successfully! Result: {training_result}")
    except Exception as e:
        log_and_print(f"ā ERROR during classifier training: {e}")
        error_details = traceback.format_exc()
        log_and_print(f"š Full training error traceback:\n{error_details}")
        raise

    # Store metadata
    # The stored count is the number of generated pairs, including any whose
    # activation extraction failed.
    classifier._trait_description = trait_description
    classifier._pairs_count = total_pairs
    log_and_print(f"š Stored metadata: trait='{trait_description}', pairs_count={total_pairs}")

    log_and_print("š Classifier creation completed successfully!")
    log_and_print(f"š Debug log saved to: {log_file}")

    return classifier
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def evaluate_response_with_trait_classifier(
    model, response_text: str, trait_classifier: ActivationClassifier
) -> SyntheticClassifierResult:
    """
    Evaluate a response using a trait-specific classifier.

    Args:
        model: The language model instance
        response_text: Response to analyze
        trait_classifier: Pre-trained classifier for a specific trait

    Returns:
        Classification result. ``generation_time`` covers only the
        classifier calls, not the activation extraction.

    Raises:
        ValueError: If activations cannot be extracted from the response
            (the original exception is chained as the cause), or if the
            classifier returns an empty confidence value.
    """

    def _confidence_to_float(confidence) -> float:
        """Collapse a scalar or a sequence of probabilities into one float."""
        try:
            size = len(confidence)
        except TypeError:
            # Plain numbers and 0-d arrays/tensors have no length.
            return float(confidence)
        if size > 1:
            return float(max(confidence))
        if size == 1:
            # Unwrap single-element containers; float([x]) would raise.
            return _confidence_to_float(confidence[0])
        raise ValueError("Classifier returned an empty confidence value")

    # Metadata is attached to the classifier at creation time; fall back to
    # placeholders when it is missing.
    trait_description = getattr(trait_classifier, "_trait_description", "unknown_trait")
    pairs_count = getattr(trait_classifier, "_pairs_count", 0)

    logging.info(f"Evaluating response with '{trait_description}' classifier...")

    # Extract activations from response
    try:
        # NOTE(review): create_classifier_from_trait_description in this file
        # calls model.extract_activations(text, layer_obj) with a Layer object
        # and uses the return value directly, while this call passes an int
        # keyword and unpacks a 2-tuple - confirm which matches the model API.
        response_activations, _ = model.extract_activations(response_text, layer=15)
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Error extracting response activations: {e}") from e

    # Apply classifier
    start_time = time.time()
    prediction = trait_classifier.predict(response_activations)
    confidence = trait_classifier.predict_proba(response_activations)

    # Handle confidence score (could be array or scalar)
    confidence_score = _confidence_to_float(confidence)

    generation_time = time.time() - start_time

    result = SyntheticClassifierResult(
        trait_description=trait_description,
        classifier_confidence=confidence_score,
        prediction=int(prediction),
        confidence_score=confidence_score,
        training_pairs_count=pairs_count,
        generation_time=generation_time,
    )

    logging.info(f"Result: prediction={prediction}, confidence={confidence_score:.3f}")
    return result
|