wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release: this version of wisent might be problematic.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
wisent/core/agent/diagnose/create_classifier.py (new file)
@@ -0,0 +1,1154 @@
"""
On-the-Fly Classifier Creation System for Autonomous Agent

This module handles:
- Dynamic training of new classifiers for specific issue types
- Automatic training data generation for different problem domains
- Classifier optimization and validation
- Integration with the autonomous agent system
"""

import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from wisent.core.classifier.classifier import ActivationClassifier, Classifier

from ...activations import Activations
from ...layer import Layer
from ...model import Model
from ...model_persistence import ModelPersistence, create_classifier_metadata


@dataclass
class TrainingConfig:
    """Configuration for classifier training."""

    issue_type: str
    layer: int
    classifier_type: str = "logistic"
    threshold: float = 0.5
    model_name: str = ""
    training_samples: int = 100
    test_split: float = 0.2
    optimization_metric: str = "f1"
    save_path: Optional[str] = None


@dataclass
class TrainingResult:
    """Result of classifier training."""

    classifier: Classifier
    config: TrainingConfig
    performance_metrics: Dict[str, float]
    training_time: float
    save_path: Optional[str] = None


class ClassifierCreator:
    """Creates new classifiers on demand for the autonomous agent."""

    def __init__(self, model: Model):
        """
        Initialize the classifier creator.

        Args:
            model: The language model to use for training
        """
        self.model = model

    def create_classifier_for_issue_type(
        self, issue_type: str, layer: int, config: Optional[TrainingConfig] = None
    ) -> TrainingResult:
        """
        Create a new classifier for a specific issue type.

        Args:
            issue_type: Type of issue to detect (e.g., "hallucination", "quality")
            layer: Model layer to use for activation extraction
            config: Optional training configuration

        Returns:
            TrainingResult with the trained classifier and metrics
        """
        print(f"🏋️ Creating classifier for {issue_type} at layer {layer}...")

        # Use provided config or create default
        if config is None:
            config = TrainingConfig(issue_type=issue_type, layer=layer, model_name=self.model.name)

        start_time = time.time()

        # Generate training data
        print(" 📊 Generating training data...")
        training_data = self._generate_training_data(issue_type, config.training_samples)

        # Extract activations
        print(" 🧠 Extracting activations...")
        harmful_activations, harmless_activations = self._extract_activations_from_data(training_data, layer)

        # Train classifier
        print(" 🎯 Training classifier...")
        classifier = self._train_classifier(harmful_activations, harmless_activations, config)

        # Evaluate performance
        print(" 📈 Evaluating performance...")
        metrics = self._evaluate_classifier(classifier, harmful_activations, harmless_activations)

        training_time = time.time() - start_time

        # Save classifier if path provided
        save_path = None
        if config.save_path:
            print(" 💾 Saving classifier...")
            save_path = self._save_classifier(classifier, config, metrics)

        result = TrainingResult(
            classifier=classifier.classifier,  # Return the base classifier
            config=config,
            performance_metrics=metrics,
            training_time=training_time,
            save_path=save_path,
        )

        print(
            f" ✅ Classifier created in {training_time:.2f}s "
            f"(F1: {metrics.get('f1', 0):.3f}, Accuracy: {metrics.get('accuracy', 0):.3f})"
        )

        return result

    def create_multi_layer_classifiers(
        self, issue_type: str, layers: List[int], save_base_path: Optional[str] = None
    ) -> Dict[int, TrainingResult]:
        """
        Create classifiers for multiple layers for the same issue type.

        Args:
            issue_type: Type of issue to detect
            layers: List of layers to create classifiers for
            save_base_path: Base path for saving classifiers

        Returns:
            Dictionary mapping layer indices to training results
        """
        print(f"🔄 Creating multi-layer classifiers for {issue_type}...")

        results = {}

        for layer in layers:
            config = TrainingConfig(
                issue_type=issue_type,
                layer=layer,
                model_name=self.model.name,
                save_path=f"{save_base_path}_layer_{layer}.pkl" if save_base_path else None,
            )

            result = self.create_classifier_for_issue_type(issue_type, layer, config)
            results[layer] = result

        print(f" ✅ Created {len(results)} classifiers across layers {layers}")
        return results

    def optimize_classifier_for_performance(
        self,
        issue_type: str,
        layer_range: Tuple[int, int] = None,
        classifier_types: List[str] = None,
        target_metric: str = "f1",
        min_target_score: float = 0.7,
    ) -> TrainingResult:
        """
        Optimize classifier by testing different configurations.

        Args:
            issue_type: Type of issue to detect
            layer_range: Range of layers to test (start, end). If None, auto-detect all model layers
            classifier_types: Types of classifiers to test
            target_metric: Metric to optimize for
            min_target_score: Minimum acceptable score

        Returns:
            Best performing classifier configuration
        """
        print(f"🎯 Optimizing classifier for {issue_type}...")

        if classifier_types is None:
            classifier_types = ["logistic", "mlp"]

        # Auto-detect layer range if not provided
        if layer_range is None:
            from ..hyperparameter_optimizer import detect_model_layers

            total_layers = detect_model_layers(self.model)
            layer_range = (0, total_layers - 1)
            print(f" 📊 Auto-detected {total_layers} layers, testing range {layer_range[0]}-{layer_range[1]}")

        best_result = None
        best_score = 0.0

        layers_to_test = range(layer_range[0], layer_range[1] + 1, 2)  # Test every 2nd layer

        for layer in layers_to_test:
            for classifier_type in classifier_types:
                config = TrainingConfig(
                    issue_type=issue_type, layer=layer, classifier_type=classifier_type, model_name=self.model.name
                )

                try:
                    result = self.create_classifier_for_issue_type(issue_type, layer, config)
                    score = result.performance_metrics.get(target_metric, 0.0)

                    print(f" Layer {layer}, {classifier_type}: {target_metric}={score:.3f}")

                    if score > best_score:
                        best_score = score
                        best_result = result

                        # Early stopping if we hit the target
                        if score >= min_target_score:
                            print(f" 🎉 Target score reached: {score:.3f}")
                            break

                except Exception as e:
                    print(f" ❌ Failed layer {layer}, {classifier_type}: {e}")
                    continue

            # Break outer loop if target reached
            if best_score >= min_target_score:
                break

        if best_result is None:
            raise RuntimeError(f"Failed to create any working classifier for {issue_type}")

        print(
            f" ✅ Best configuration: Layer {best_result.config.layer}, "
            f"{best_result.config.classifier_type}, {target_metric}={best_score:.3f}"
        )

        return best_result

    async def create_classifier_for_issue_with_benchmarks(
        self,
        issue_type: str,
        relevant_benchmarks: List[str],
        layer: int = 15,
        num_samples: int = 50,
        config: Optional[TrainingConfig] = None,
    ) -> TrainingResult:
        """
        Create a classifier using specific benchmarks for better contrastive pairs.

        Args:
            issue_type: Type of issue to detect (e.g., "hallucination", "quality")
            relevant_benchmarks: List of benchmark names to use for training data
            layer: Model layer to use for activation extraction (default: 15)
            num_samples: Number of training samples to generate
            config: Optional training configuration

        Returns:
            TrainingResult with the trained classifier and metrics
        """
        print(f"🎯 Creating {issue_type} classifier using benchmarks: {relevant_benchmarks}")

        # Use provided config or create default
        if config is None:
            config = TrainingConfig(
                issue_type=issue_type, layer=layer, model_name=self.model.name, training_samples=num_samples
            )

        start_time = time.time()

        # Generate training data using the provided benchmarks
        print(" 📊 Loading benchmark-specific training data...")
        training_data = []

        try:
            # Load data from the relevant benchmarks
            benchmark_data = self._load_benchmark_data(relevant_benchmarks, num_samples)
            training_data.extend(benchmark_data)
            print(f" ✅ Loaded {len(benchmark_data)} examples from benchmarks")
        except Exception as e:
            print(f" ⚠️ Failed to load benchmark data: {e}")

        # If we don't have enough data from benchmarks, supplement with synthetic data
        if len(training_data) < num_samples // 2:
            print(" 🧪 Supplementing with synthetic training data...")
            try:
                synthetic_data = self._generate_synthetic_training_data(issue_type, num_samples - len(training_data))
                training_data.extend(synthetic_data)
                print(f" ✅ Added {len(synthetic_data)} synthetic examples")
            except Exception as e:
                print(f" ⚠️ Failed to generate synthetic data: {e}")

        if not training_data:
            raise ValueError(f"No training data available for {issue_type}")

        print(f" 📈 Total training examples: {len(training_data)}")

        # Extract activations
        print(" 🧠 Extracting activations...")
        harmful_activations, harmless_activations = self._extract_activations_from_data(training_data, layer)

        # Train classifier
        print(" 🎯 Training classifier...")
        classifier = self._train_classifier(harmful_activations, harmless_activations, config)

        # Evaluate performance
        print(" 📈 Evaluating performance...")
        metrics = self._evaluate_classifier(classifier, harmful_activations, harmless_activations)

        training_time = time.time() - start_time

        # Save classifier if path provided
        save_path = None
        if config.save_path:
            print(" 💾 Saving classifier...")
            save_path = self._save_classifier(classifier, config, metrics)

        result = TrainingResult(
            classifier=classifier.classifier,  # Return the base classifier
            config=config,
            performance_metrics=metrics,
            training_time=training_time,
            save_path=save_path,
        )

        print(
            f" ✅ Benchmark-based classifier created in {training_time:.2f}s "
            f"(F1: {metrics.get('f1', 0):.3f}, Accuracy: {metrics.get('accuracy', 0):.3f})"
        )
        print(f" 📊 Used benchmarks: {relevant_benchmarks}")

        return result

    async def create_combined_benchmark_classifier(
        self, benchmark_names: List[str], classifier_params: "ClassifierParams", config: Optional[TrainingConfig] = None
    ) -> TrainingResult:
        """
        Create a classifier trained on combined data from multiple benchmarks.

        Args:
            benchmark_names: List of benchmark names to combine training data from
            classifier_params: Model-determined classifier parameters
            config: Optional training configuration

        Returns:
            TrainingResult with the trained combined classifier
        """
        print(f"🏗️ Creating combined classifier from {len(benchmark_names)} benchmarks...")
        print(f" 📊 Benchmarks: {benchmark_names}")
        print(f" 🧠 Using layer {classifier_params.optimal_layer}, {classifier_params.training_samples} samples")

        # Create config from classifier_params
        if config is None:
            config = TrainingConfig(
                issue_type=f"quality_combined_{'_'.join(sorted(benchmark_names))}",
                layer=classifier_params.optimal_layer,
                classifier_type=classifier_params.classifier_type,
                threshold=classifier_params.classification_threshold,
                training_samples=classifier_params.training_samples,
                model_name=self.model.name,
            )

        start_time = time.time()

        # Generate combined training data from all benchmarks
        print(" 📊 Loading and combining benchmark training data...")
        combined_training_data = await self._load_combined_benchmark_data(
            benchmark_names, classifier_params.training_samples
        )

        print(f" 📈 Loaded {len(combined_training_data)} combined training examples")

        # Extract activations
        print(" 🧠 Extracting activations...")
        harmful_activations, harmless_activations = self._extract_activations_from_data(
            combined_training_data, classifier_params.optimal_layer
        )

        # Train classifier
        print(" 🎯 Training combined classifier...")
        classifier = self._train_classifier(harmful_activations, harmless_activations, config)

        # Evaluate performance
        print(" 📈 Evaluating performance...")
        metrics = self._evaluate_classifier(classifier, harmful_activations, harmless_activations)

        training_time = time.time() - start_time

        # Save classifier if path provided
        save_path = None
        if config.save_path:
            print(" 💾 Saving combined classifier...")
            save_path = self._save_classifier(classifier, config, metrics)

        result = TrainingResult(
            classifier=classifier.classifier,
            config=config,
            performance_metrics=metrics,
            training_time=training_time,
            save_path=save_path,
        )

        print(
            f" ✅ Combined classifier created in {training_time:.2f}s "
            f"(F1: {metrics.get('f1', 0):.3f}, Accuracy: {metrics.get('accuracy', 0):.3f})"
        )

        return result

    async def _load_combined_benchmark_data(
        self, benchmark_names: List[str], total_samples: int
    ) -> List[Dict[str, Any]]:
        """
        Load and combine training data from multiple benchmarks.

        Args:
            benchmark_names: List of benchmark names to load data from
            total_samples: Total number of training samples to create

        Returns:
            Combined list of training examples with balanced sampling
        """
        combined_data = []
        samples_per_benchmark = max(1, total_samples // len(benchmark_names))

        print(f" 📊 Loading ~{samples_per_benchmark} samples per benchmark")

        for benchmark_name in benchmark_names:
            try:
                print(f" 🔄 Loading data from {benchmark_name}...")
                benchmark_data = self._load_benchmark_data([benchmark_name], samples_per_benchmark)
                combined_data.extend(benchmark_data)
                print(f" ✅ Loaded {len(benchmark_data)} samples from {benchmark_name}")

            except Exception as e:
                print(f" ⚠️ Failed to load {benchmark_name}: {e}")
                # Continue with other benchmarks
                continue

        # If we don't have enough samples, pad with synthetic data
        if len(combined_data) < total_samples:
            remaining_samples = total_samples - len(combined_data)
            print(f" 🔧 Generating {remaining_samples} synthetic samples to reach target")
            synthetic_data = self._generate_synthetic_training_data("quality", remaining_samples)
            combined_data.extend(synthetic_data)

        # Shuffle the combined data to ensure good mixing
        import random

        random.shuffle(combined_data)

        # Trim to exact target if we have too many
        combined_data = combined_data[:total_samples]

        print(f" ✅ Final combined dataset: {len(combined_data)} samples")
        return combined_data

    async def create_classifier_for_issue(self, issue_type: str, layer: int = 15) -> TrainingResult:
        """
        Create a classifier for an issue type (async version for compatibility).

        Args:
            issue_type: Type of issue to detect
            layer: Model layer to use for activation extraction

        Returns:
            TrainingResult with the trained classifier
        """
        return self.create_classifier_for_issue_type(issue_type, layer)

    def _generate_training_data(self, issue_type: str, num_samples: int) -> List[Dict[str, Any]]:
        """
        Generate training data dynamically for a specific issue type using relevant benchmarks.

        Args:
            issue_type: Type of issue to generate data for
            num_samples: Number of training samples to generate

        Returns:
            List of training examples with harmful/harmless pairs
        """
        print(f" 📊 Loading dynamic training data for {issue_type}...")

        # Try to find relevant benchmarks for the issue type (using default 5-minute budget)
        relevant_benchmarks = self._find_relevant_benchmarks(issue_type)

        if relevant_benchmarks:
            print(f" 🎯 Found {len(relevant_benchmarks)} relevant benchmarks: {relevant_benchmarks[:3]}...")
            return self._load_benchmark_data(relevant_benchmarks, num_samples)
        print(" 🤖 No specific benchmarks found, using synthetic generation...")
        return self._generate_synthetic_training_data(issue_type, num_samples)

    def _find_relevant_benchmarks(self, issue_type: str, time_budget_minutes: float = 5.0) -> List[str]:
        """Find relevant benchmarks for the given issue type based on time budget with priority-aware selection."""
        from ..budget import calculate_max_tasks_for_time_budget
        from .tasks.task_relevance import find_relevant_tasks

        try:
            # Calculate max tasks using budget system
            max_tasks = calculate_max_tasks_for_time_budget(
                task_type="benchmark_evaluation", time_budget_minutes=time_budget_minutes
            )

            print(f" 🕐 Time budget: {time_budget_minutes:.1f}min → max {max_tasks} tasks")

            # Use priority-aware intelligent benchmark selection
            try:
                # Import priority-aware selection function
                import os
                import sys

                sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "lm-harness-integration"))
                from only_benchmarks import find_most_relevant_benchmarks

                # Use priority-aware selection with time budget
                relevant_results = find_most_relevant_benchmarks(
                    prompt=issue_type,
                    top_k=max_tasks,
                    priority="all",
                    fast_only=False,
                    time_budget_minutes=time_budget_minutes,
                    prefer_fast=True,  # Prefer fast benchmarks for agent use
                )

                # Extract benchmark names
                relevant_benchmarks = [result["benchmark"] for result in relevant_results]

                if relevant_benchmarks:
                    print(f" 📊 Found {len(relevant_benchmarks)} priority-aware benchmarks for '{issue_type}':")
                    for i, result in enumerate(relevant_results[:3]):
                        priority_str = f" (priority: {result.get('priority', 'unknown')})"
                        loading_time_str = f" (loading time: {result.get('loading_time', 60.0):.1f}s)"
                        print(f" {i + 1}. {result['benchmark']}{priority_str}{loading_time_str}")
                    if len(relevant_benchmarks) > 3:
                        print(f" ... and {len(relevant_benchmarks) - 3} more")

                return relevant_benchmarks

            except Exception as priority_error:
                print(f" ⚠️ Priority-aware selection failed: {priority_error}")
                print(" 🔄 Falling back to legacy task relevance...")

                # Fallback to legacy system
                relevant_task_results = find_relevant_tasks(
                    query=issue_type, max_results=max_tasks, min_relevance_score=0.1
                )

                # Extract just the task names
                candidate_benchmarks = [task_name for task_name, score in relevant_task_results]

                # Use priority-aware budget optimization
                from ..budget import optimize_benchmarks_for_budget

                relevant_benchmarks = optimize_benchmarks_for_budget(
                    task_candidates=candidate_benchmarks,
                    time_budget_minutes=time_budget_minutes,
                    max_tasks=max_tasks,
                    prefer_fast=True,  # Agent prefers fast benchmarks
                )

                if relevant_benchmarks:
                    print(f" 📊 Found {len(relevant_benchmarks)} relevant benchmarks for '{issue_type}':")
                    # Show the scores for the selected benchmarks
                    for i, (task_name, score) in enumerate(relevant_task_results[:3]):
                        if task_name in relevant_benchmarks:
                            print(f" {i + 1}. {task_name} (relevance: {score:.3f})")
                    if len(relevant_benchmarks) > 3:
                        print(f" ... and {len(relevant_benchmarks) - 3} more")

                return relevant_benchmarks

        except Exception as e:
            print(f" ⚠️ Error finding relevant benchmarks: {e}")
            print(" ⚠️ Using fallback tasks")
            # Minimal fallback to high priority fast benchmarks
            return ["mmlu", "truthfulqa_mc1", "hellaswag"]

    def _extract_benchmark_concepts(self, benchmark_names: List[str]) -> Dict[str, List[str]]:
        """Extract semantic concepts from benchmark names."""
        concepts = {}

        for name in benchmark_names:
            # Extract concepts from benchmark name
            name_concepts = []
            name_lower = name.lower()

            # Split on common separators and extract meaningful tokens
            tokens = name_lower.replace("_", " ").replace("-", " ").split()

            # Filter out common non-semantic tokens
            semantic_tokens = []
            skip_tokens = {
                "the",
                "and",
                "or",
                "of",
                "in",
                "on",
                "at",
                "to",
                "for",
                "with",
                "by",
                "from",
                "as",
                "is",
                "are",
                "was",
                "were",
                "be",
                "been",
                "being",
                "have",
                "has",
                "had",
                "do",
                "does",
                "did",
                "will",
                "would",
                "could",
                "should",
                "may",
                "might",
                "can",
                "light",
                "full",
                "val",
                "test",
                "dev",
                "mc1",
                "mc2",
                "mt",
                "cot",
                "fewshot",
                "zeroshot",
                "generate",
                "until",
                "multiple",
                "choice",
                "group",
                "subset",
            }

            for token in tokens:
                if len(token) > 2 and token not in skip_tokens and token.isalpha():
                    semantic_tokens.append(token)

            # Extract domain-specific concepts
            domain_concepts = self._extract_domain_concepts(name_lower, semantic_tokens)
            name_concepts.extend(domain_concepts)

            concepts[name] = list(set(name_concepts))  # Remove duplicates

        return concepts

    def _extract_domain_concepts(self, benchmark_name: str, tokens: List[str]) -> List[str]:
        """Extract domain-specific concepts directly from benchmark name components."""
        concepts = []

        # Add all meaningful tokens as concepts
        for token in tokens:
            if len(token) > 2:
                concepts.append(token)

        # Extract compound concept meanings from token combinations
        name_parts = benchmark_name.lower().split("_")

        # Generate concept combinations
        for i, part in enumerate(name_parts):
            if len(part) > 2:
                concepts.append(part)

            # Look for meaningful compound concepts
            if i < len(name_parts) - 1:
                next_part = name_parts[i + 1]
                if len(next_part) > 2:
                    compound = f"{part}_{next_part}"
                    concepts.append(compound)

        # Extract semantic root words
        for token in tokens:
            root_concepts = self._extract_semantic_roots(token)
            concepts.extend(root_concepts)

        return list(set(concepts))  # Remove duplicates

    def _extract_semantic_roots(self, word: str) -> List[str]:
        """Extract semantic root concepts from a word."""
        roots = []

        # Simple morphological analysis
        # Remove common suffixes to find roots
        suffixes = [
            "ing",
            "tion",
            "sion",
            "ness",
            "ment",
            "able",
            "ible",
            "ful",
            "less",
            "ly",
            "al",
            "ic",
            "ous",
            "ive",
        ]

        root = word
        for suffix in suffixes:
            if word.endswith(suffix) and len(word) > len(suffix) + 2:
                root = word[: -len(suffix)]
                break

        if root != word and len(root) > 2:
            roots.append(root)

        # Add the original word
        roots.append(word)

        return roots

    def _calculate_benchmark_relevance(self, issue_type: str, benchmark_concepts: Dict[str, List[str]]) -> List[str]:
        """Calculate relevance scores using semantic similarity."""
        # Calculate relevance scores
        benchmark_scores = []

        for benchmark_name, concepts in benchmark_concepts.items():
            score = self._calculate_semantic_similarity(issue_type, benchmark_name, concepts)

            if score > 0:
                benchmark_scores.append((benchmark_name, score))

        # Sort by relevance score
        benchmark_scores.sort(key=lambda x: x[1], reverse=True)

        return [name for name, score in benchmark_scores]

    def _calculate_semantic_similarity(self, issue_type: str, benchmark_name: str, concepts: List[str]) -> float:
        """Calculate semantic similarity between issue type and benchmark."""
        issue_lower = issue_type.lower()
        benchmark_lower = benchmark_name.lower()

        score = 0.0

        # Direct name matching (highest weight)
        if issue_lower in benchmark_lower or benchmark_lower in issue_lower:
            score += 5.0

        # Concept matching
        for concept in concepts:
            concept_lower = concept.lower()

            # Exact concept match
            if issue_lower == concept_lower:
                score += 4.0
            # Partial concept match
            elif issue_lower in concept_lower or concept_lower in issue_lower:
                score += 2.0
            # Semantic similarity check
            elif self._are_semantically_similar(issue_lower, concept_lower):
                score += 1.5

        # Token-level similarity in benchmark name
        benchmark_tokens = benchmark_lower.replace("_", " ").replace("-", " ").split()
        issue_tokens = issue_lower.replace("_", " ").replace("-", " ").split()

        for issue_token in issue_tokens:
            for benchmark_token in benchmark_tokens:
                if len(issue_token) > 2 and len(benchmark_token) > 2:
                    if issue_token == benchmark_token:
                        score += 3.0
                    elif issue_token in benchmark_token or benchmark_token in issue_token:
                        score += 1.0
                    elif self._are_semantically_similar(issue_token, benchmark_token):
                        score += 0.5

        return score

    def _are_semantically_similar(self, term1: str, term2: str) -> bool:
        """Check if two terms are semantically similar using algorithmic methods."""
        if len(term1) < 3 or len(term2) < 3:
            return False

        # Character-level similarity
        overlap = len(set(term1) & set(term2))
        min_len = min(len(term1), len(term2))
        char_similarity = overlap / min_len

        # Substring similarity
        longer, shorter = (term1, term2) if len(term1) > len(term2) else (term2, term1)
        substring_match = shorter in longer

        # Prefix/suffix similarity
        prefix_len = 0
        suffix_len = 0

        for i in range(min(len(term1), len(term2))):
            if term1[i] == term2[i]:
                prefix_len += 1
            else:
                break

        for i in range(1, min(len(term1), len(term2)) + 1):
            if term1[-i] == term2[-i]:
                suffix_len += 1
            else:
                break

        affix_similarity = (prefix_len + suffix_len) / max(len(term1), len(term2))

        # Combined similarity score
        return char_similarity > 0.6 or substring_match or affix_similarity > 0.4 or prefix_len >= 3 or suffix_len >= 3

    def _prioritize_benchmarks(self, relevant_benchmarks: List[str]) -> List[str]:
        """Prioritize benchmarks algorithmically based on naming patterns and characteristics."""
        benchmark_scores = []

        for benchmark in relevant_benchmarks:
            score = self._calculate_benchmark_quality_score(benchmark)
            benchmark_scores.append((benchmark, score))

        # Sort by quality score (higher is better)
        benchmark_scores.sort(key=lambda x: x[1], reverse=True)
        return [benchmark for benchmark, score in benchmark_scores]

    def _calculate_benchmark_quality_score(self, benchmark_name: str) -> float:
        """Calculate quality score for a benchmark based on naming patterns and characteristics."""
        score = 0.0
        benchmark_lower = benchmark_name.lower()

        # Length heuristic - moderate length names tend to be well-established
        name_length = len(benchmark_name)
        if 8 <= name_length <= 25:
            score += 2.0
        elif name_length < 8:
            score += 0.5  # Very short names might be too simple
        else:
            score += 1.0  # Very long names might be overly specific

        # Component analysis
        parts = benchmark_lower.split("_")
        num_parts = len(parts)

        # Well-structured benchmarks often have 2-3 parts
        if 2 <= num_parts <= 3:
            score += 2.0
        elif num_parts == 1:
            score += 1.5  # Simple names can be good too
        else:
            score += 0.5  # Too many parts might indicate over-specification

        # Indicator of established benchmarks (avoid hardcoding specific names)
        quality_indicators = [
            # Multiple choice indicators (often well-validated)
            ("mc1", 1.5),
            ("mc2", 1.5),
            ("multiple_choice", 1.5),
            # Evaluation methodology indicators
            ("eval", 1.0),
            ("test", 1.0),
            ("benchmark", 1.0),
            # Language understanding indicators
            ("language", 1.0),
            ("understanding", 1.0),
            ("comprehension", 1.0),
            # Logic and reasoning indicators
            ("logic", 1.0),
            ("reasoning", 1.0),
            ("deduction", 1.0),
            # Knowledge assessment indicators
            ("knowledge", 1.0),
            ("question", 1.0),
            ("answer", 1.0),
        ]

        for indicator, points in quality_indicators:
            if indicator in benchmark_lower:
                score += points

        # Penalize very specialized or experimental indicators
        experimental_indicators = [
            "experimental",
            "pilot",
            "demo",
            "sample",
            "tiny",
            "mini",
            "subset",
            "light",
            "debug",
            "test_only",
        ]

        for indicator in experimental_indicators:
            if indicator in benchmark_lower:
                score -= 1.0

        # Bonus for domain diversity indicators
        domain_indicators = ["multilingual", "global", "cross", "multi", "diverse"]

        for indicator in domain_indicators:
            if indicator in benchmark_lower:
                score += 0.5

        return max(0.0, score)  # Ensure non-negative score

    def _load_benchmark_data(self, benchmarks: List[str], num_samples: int) -> List[Dict[str, Any]]:
        """Load training data from multiple relevant benchmarks."""
        from .tasks import TaskManager

        training_data = []
        samples_per_benchmark = max(1, num_samples // len(benchmarks))

        # Create task manager instance
        task_manager = TaskManager()

        for benchmark in benchmarks:
            try:
                print(f" 🔄 Loading from {benchmark}...")

                # Load benchmark task using TaskManager
                task_data = task_manager.load_task(benchmark, limit=samples_per_benchmark * 3)
                docs = task_manager.split_task_data(task_data, split_ratio=1.0)[0]

                # Extract QA pairs using existing system
                from ...contrastive_pairs.contrastive_pair_set import ContrastivePairSet

                qa_pairs = ContrastivePairSet.extract_qa_pairs_from_task_docs(benchmark, task_data, docs)

                # Convert to training format
                for pair in qa_pairs[:samples_per_benchmark]:
                    if self._is_valid_pair(pair):
                        training_data.append(
                            {
                                "prompt": pair.get("question", f"Context from {benchmark}"),
                                "harmful_response": pair.get("incorrect_answer", ""),
                                "harmless_response": pair.get("correct_answer", ""),
                                "source": benchmark,
                            }
                        )

                print(f" ✅ Loaded {len(qa_pairs[:samples_per_benchmark])} examples from {benchmark}")

                # Stop if we have enough data
                if len(training_data) >= num_samples:
                    break

            except Exception as e:
                print(f" ⚠️ Failed to load {benchmark}: {e}")
                continue

        if not training_data:
            print(" ❌ Failed to load from any benchmarks, falling back to synthetic...")
            return self._generate_synthetic_training_data("unknown", num_samples)

        print(
            f" ✅ Loaded {len(training_data)} examples from {len([b for b in benchmarks if any(b in item.get('source', '') for item in training_data)])} benchmarks"
        )
        return training_data[:num_samples]

    def _is_valid_pair(self, pair: Dict[str, Any]) -> bool:
        """Check if a QA pair is valid for training."""
        return (
            pair.get("correct_answer")
            and pair.get("incorrect_answer")
            and len(pair.get("correct_answer", "")) > 5
            and len(pair.get("incorrect_answer", "")) > 5
            and pair.get("correct_answer") != pair.get("incorrect_answer")
        )

    def _generate_synthetic_training_data(self, issue_type: str, num_samples: int) -> List[Dict[str, Any]]:
        """Generate synthetic training data for issue types."""
        from ...contrastive_pairs.generate_synthetically import SyntheticContrastivePairGenerator

        try:
            generator = SyntheticContrastivePairGenerator(self.model)

            # Create trait description based on issue type
            trait_descriptions = {
                "quality": "providing clear, accurate, and well-structured responses",
                "harmful": "giving safe, helpful, and constructive advice",
                "bias": "responding with fairness and avoiding stereotypes",
                "coherence": "maintaining logical flow and clear communication",
                "hallucination": "providing factually accurate information",
            }

            trait_description = trait_descriptions.get(issue_type, f"avoiding {issue_type} issues in responses")

            # Generate synthetic pairs
            synthetic_pairs = generator.generate_contrastive_pair_set(
                trait_description=trait_description, num_pairs=num_samples, name=f"synthetic_{issue_type}"
            )

            # Convert to training format
            training_data = []
            for pair in synthetic_pairs.pairs[:num_samples]:
                training_data.append(
                    {
                        "prompt": pair.prompt or f"Context for {issue_type} detection",
                        "harmful_response": pair.negative_response,
                        "harmless_response": pair.positive_response,
                    }
                )

            print(f" ✅ Generated {len(training_data)} synthetic examples for {issue_type}")
            return training_data

        except Exception as e:
            print(f" ❌ Failed to generate synthetic data: {e}")
            raise ValueError(f"Cannot generate training data for issue type: {issue_type}")

    def _extract_activations_from_data(
        self, training_data: List[Dict[str, Any]], layer: int
    ) -> Tuple[List[Activations], List[Activations]]:
        """
        Extract activations from training data.

        Args:
            training_data: List of training examples
            layer: Layer to extract activations from

        Returns:
            Tuple of (harmful_activations, harmless_activations)
        """
        harmful_activations = []
        harmless_activations = []

        layer_obj = Layer(index=layer, type="transformer")

        for example in training_data:
            # Extract harmful activation
            harmful_tensor = self.model.extract_activations(example["harmful_response"], layer_obj)
            harmful_activation = Activations(tensor=harmful_tensor, layer=layer_obj)
            harmful_activations.append(harmful_activation)

            # Extract harmless activation
            harmless_tensor = self.model.extract_activations(example["harmless_response"], layer_obj)
            harmless_activation = Activations(tensor=harmless_tensor, layer=layer_obj)
            harmless_activations.append(harmless_activation)

        return harmful_activations, harmless_activations

    def _train_classifier(
        self, harmful_activations: List[Activations], harmless_activations: List[Activations], config: TrainingConfig
    ) -> ActivationClassifier:
        """
        Train a classifier on the activation data.

        Args:
            harmful_activations: List of harmful activations
            harmless_activations: List of harmless activations
            config: Training configuration

        Returns:
            Trained ActivationClassifier
        """
        classifier = ActivationClassifier(
            model_type=config.classifier_type, threshold=config.threshold, device=self.model.device
        )

        classifier.train_on_activations(harmful_activations, harmless_activations)

        return classifier

    def _evaluate_classifier(
        self,
        classifier: ActivationClassifier,
        harmful_activations: List[Activations],
        harmless_activations: List[Activations],
    ) -> Dict[str, float]:
        """
        Evaluate classifier performance.

        Args:
            classifier: Trained classifier
            harmful_activations: Test harmful activations
            harmless_activations: Test harmless activations

        Returns:
            Dictionary of performance metrics
        """
        # Use a portion of data for testing
        test_size = min(10, len(harmful_activations) // 5)  # 20% or at least 10

        test_harmful = harmful_activations[-test_size:]
        test_harmless = harmless_activations[-test_size:]

        return classifier.evaluate_on_activations(test_harmful, test_harmless)

    def _save_classifier(
        self, classifier: ActivationClassifier, config: TrainingConfig, metrics: Dict[str, float]
    ) -> str:
        """
        Save classifier with metadata.

        Args:
            classifier: Trained classifier
            config: Training configuration
            metrics: Performance metrics

        Returns:
            Path where classifier was saved
        """
        # Create metadata
        metadata = create_classifier_metadata(
            model_name=config.model_name,
            task_name=config.issue_type,
            layer=config.layer,
            classifier_type=config.classifier_type,
            training_accuracy=metrics.get("accuracy", 0.0),
            training_samples=config.training_samples,
            token_aggregation="final",  # Default for our system
            detection_threshold=config.threshold,
            f1=metrics.get("f1", 0.0),
            precision=metrics.get("precision", 0.0),
            recall=metrics.get("recall", 0.0),
            auc=metrics.get("auc", 0.0),
        )

        # Save using ModelPersistence
        save_path = ModelPersistence.save_classifier(classifier.classifier, config.layer, config.save_path, metadata)

        return save_path


def create_classifier_on_demand(
    model: Model, issue_type: str, layer: int = None, save_path: str = None, optimize: bool = False
) -> TrainingResult:
    """
    Convenience function to create a classifier on demand.

    Args:
        model: Language model to use
        issue_type: Type of issue to detect
        layer: Specific layer to use (auto-optimized if None)
        save_path: Path to save the classifier
        optimize: Whether to optimize for best performance

    Returns:
        TrainingResult with the created classifier
    """
    creator = ClassifierCreator(model)

    if optimize or layer is None:
        # Optimize to find best configuration
        result = creator.optimize_classifier_for_performance(issue_type)

        # Save if path provided
        if save_path:
            result.config.save_path = save_path
            result.save_path = creator._save_classifier(
                ActivationClassifier(device=model.device), result.config, result.performance_metrics
            )

        return result
    # Use specified layer
    config = TrainingConfig(issue_type=issue_type, layer=layer, save_path=save_path, model_name=model.name)

    return creator.create_classifier_for_issue_type(issue_type, layer, config)