wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model persistence utilities for saving and loading trained classifiers and steering vectors.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import pickle
|
|
7
|
+
import json
|
|
8
|
+
from typing import Dict, Any, List, Optional
|
|
9
|
+
import torch
|
|
10
|
+
import numpy as np
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ModelPersistence:
    """Utilities for saving and loading trained models.

    Classifiers are pickled and steering vectors are written with
    ``torch.save``. Each file holds a dict containing the payload (under
    ``'classifier'`` or ``'vector'``) together with ``'layer'`` and
    ``'metadata'`` entries. File names get a ``_layer_<N>`` suffix, so one
    base path can hold one file per layer.
    """

    @staticmethod
    def _classifier_path(path: str, layer: int) -> str:
        """Return the classifier file path that ``path`` and ``layer`` map to.

        Shared by ``save_classifier`` and ``load_classifier`` so both always
        derive the same name.

        NOTE: every ``.`` in the file-name part (not the directory part) is
        replaced by ``_``, including the dot that starts an extension. For
        example, ``models/clf.pkl`` becomes ``models/clf_pkl_layer_<N>.pkl``.
        The result therefore never has an extension of its own, and ``.pkl``
        is always appended. This is kept as-is so that files written by
        earlier versions remain loadable.
        """
        directory, filename = os.path.split(path)
        safe_path = os.path.join(directory, filename.replace('.', '_'))
        return f"{safe_path}_layer_{layer}.pkl"

    @staticmethod
    def _vector_path(path: str, layer: int) -> str:
        """Return the steering-vector file path for ``path`` and ``layer``.

        The layer suffix goes before the extension of ``path``. ``.pt`` is
        used when ``path`` has no extension.
        """
        base, ext = os.path.splitext(path)
        return f"{base}_layer_{layer}{ext or '.pt'}"

    @staticmethod
    def save_classifier(classifier, layer: int, save_path: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Save a trained classifier to disk.

        Args:
            classifier: Trained classifier object (must be picklable)
            layer: Layer index this classifier was trained for
            save_path: Base path for saving (will add layer suffix; periods in
                the file name are replaced by underscores)
            metadata: Additional metadata to save with the classifier

        Returns:
            Actual path where the classifier was saved
        """
        # Create the parent directory if it doesn't exist
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        classifier_path = ModelPersistence._classifier_path(save_path, layer)

        save_data = {
            'classifier': classifier,
            'layer': layer,
            'metadata': metadata or {},
        }

        with open(classifier_path, 'wb') as f:
            pickle.dump(save_data, f)

        return classifier_path

    @staticmethod
    def load_classifier(load_path: str, layer: int) -> tuple:
        """
        Load a trained classifier from disk.

        Args:
            load_path: Base path for loading (same value that was passed to
                ``save_classifier``; the layer suffix is added here)
            layer: Layer index to load classifier for

        Returns:
            Tuple of (classifier, metadata)

        Raises:
            FileNotFoundError: If no classifier file exists for this layer.
        """
        classifier_path = ModelPersistence._classifier_path(load_path, layer)

        if not os.path.exists(classifier_path):
            raise FileNotFoundError(f"Classifier file not found: {classifier_path}")

        # SECURITY: unpickling can execute arbitrary code; only load files
        # from trusted locations.
        with open(classifier_path, 'rb') as f:
            save_data = pickle.load(f)

        return save_data['classifier'], save_data.get('metadata', {})

    @staticmethod
    def save_multi_layer_classifiers(classifiers: Dict[int, Any], save_path: str, metadata: Optional[Dict[str, Any]] = None) -> List[str]:
        """
        Save multiple classifiers for different layers.

        Args:
            classifiers: Dictionary mapping layer indices to trained classifiers
            save_path: Base path for saving
            metadata: Additional metadata to save with every classifier

        Returns:
            List of actual paths where classifiers were saved, in the
            iteration order of ``classifiers``
        """
        return [
            ModelPersistence.save_classifier(classifier, layer, save_path, metadata)
            for layer, classifier in classifiers.items()
        ]

    @staticmethod
    def load_multi_layer_classifiers(load_path: str, layers: List[int]) -> Dict[int, tuple]:
        """
        Load multiple classifiers for different layers.

        Missing layers are skipped with a printed warning instead of raising.

        Args:
            load_path: Base path for loading
            layers: List of layer indices to load classifiers for

        Returns:
            Dictionary mapping layer indices to (classifier, metadata) tuples
        """
        classifiers = {}
        for layer in layers:
            try:
                classifier, metadata = ModelPersistence.load_classifier(load_path, layer)
            except FileNotFoundError:
                print(f"⚠️ Warning: Classifier for layer {layer} not found at {load_path}")
                continue
            classifiers[layer] = (classifier, metadata)
        return classifiers

    @staticmethod
    def save_steering_vector(vector: torch.Tensor, layer: int, save_path: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Save a steering vector to disk.

        Args:
            vector: Steering vector tensor (moved to CPU before saving);
                non-tensor values are saved unchanged
            layer: Layer index this vector was computed for
            save_path: Base path for saving (will add layer suffix)
            metadata: Additional metadata to save with the vector

        Returns:
            Actual path where the vector was saved
        """
        # Create the parent directory if it doesn't exist
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        vector_path = ModelPersistence._vector_path(save_path, layer)

        save_data = {
            # Store on CPU so the file can be loaded on machines without the original device
            'vector': vector.cpu() if isinstance(vector, torch.Tensor) else vector,
            'layer': layer,
            'metadata': metadata or {},
        }

        torch.save(save_data, vector_path)

        return vector_path

    @staticmethod
    def load_steering_vector(load_path: str, layer: int, device: str = None) -> tuple:
        """
        Load a steering vector from disk.

        Args:
            load_path: Base path for loading (will add layer suffix)
            layer: Layer index to load vector for
            device: Device to load tensor to (passed as ``map_location``)

        Returns:
            Tuple of (vector, metadata)

        Raises:
            FileNotFoundError: If no vector file exists for this layer.
        """
        vector_path = ModelPersistence._vector_path(load_path, layer)

        if not os.path.exists(vector_path):
            raise FileNotFoundError(f"Steering vector file not found: {vector_path}")

        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may reject non-tensor payloads (e.g. numpy
        # arrays) that save_steering_vector allows. Confirm against the
        # supported torch version.
        save_data = torch.load(vector_path, map_location=device)

        return save_data['vector'], save_data.get('metadata', {})

    @staticmethod
    def list_available_models(model_dir: str, model_type: str = "classifier") -> Dict[str, List[int]]:
        """
        List available saved models in a directory.

        Args:
            model_dir: Directory to search (non-recursive)
            model_type: "classifier" selects ``.pkl`` files; any other value
                selects ``.pt`` steering-vector files

        Returns:
            Dictionary mapping base model names to sorted lists of available
            layers. For classifiers, the base name is the sanitized on-disk
            name (periods replaced by underscores). Returns an empty dict if
            ``model_dir`` is not a directory.
        """
        if not os.path.isdir(model_dir):
            return {}

        extension = ".pkl" if model_type == "classifier" else ".pt"
        models: Dict[str, List[int]] = {}

        for filename in os.listdir(model_dir):
            if not filename.endswith(extension):
                continue
            # Strip only the trailing extension, not every occurrence of it
            stem = filename[:-len(extension)]
            # Split on the LAST marker so base names that themselves contain
            # "_layer_" are still recognized
            base_name, sep, layer_str = stem.rpartition("_layer_")
            if not sep:
                continue
            try:
                layer = int(layer_str)
            except ValueError:
                continue
            models.setdefault(base_name, []).append(layer)

        for layers in models.values():
            layers.sort()

        return models
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def create_classifier_metadata(
    model_name: str,
    task_name: str,
    layer: int,
    classifier_type: str,
    training_accuracy: float,
    training_samples: int,
    token_aggregation: str,
    detection_threshold: float,
    **kwargs
) -> Dict[str, Any]:
    """
    Build the standard metadata dictionary stored alongside a trained classifier.

    Args:
        model_name: Name of the language model
        task_name: Name of the training task
        layer: Layer index
        classifier_type: Type of classifier (logistic, mlp, etc.)
        training_accuracy: Accuracy achieved during training
        training_samples: Number of training samples used
        token_aggregation: Token aggregation method used
        detection_threshold: Classification threshold used
        **kwargs: Extra fields appended after the standard ones; a kwarg named
            ``created_at`` or ``wisent_guard_version`` replaces that value

    Returns:
        Metadata dictionary (standard fields first, then extras)
    """
    from datetime import datetime

    standard_fields = {
        'model_name': model_name,
        'task_name': task_name,
        'layer': layer,
        'classifier_type': classifier_type,
        'training_accuracy': training_accuracy,
        'training_samples': training_samples,
        'token_aggregation': token_aggregation,
        'detection_threshold': detection_threshold,
        # Naive local timestamp in ISO-8601 form
        'created_at': datetime.now().isoformat(),
        # Static placeholder string; not read from the installed package
        'wisent_guard_version': '1.0.0',
    }

    # Later mapping wins on key collisions, mirroring dict.update semantics
    return {**standard_fields, **kwargs}
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def create_steering_vector_metadata(
    model_name: str,
    task_name: str,
    layer: int,
    vector_strength: float,
    training_samples: int,
    **kwargs
) -> Dict[str, Any]:
    """
    Build the standard metadata dictionary stored alongside a steering vector.

    Args:
        model_name: Name of the language model
        task_name: Name of the training task
        layer: Layer index
        vector_strength: Strength/magnitude of the steering vector
        training_samples: Number of training samples used
        **kwargs: Extra fields appended after the standard ones; a kwarg named
            ``created_at`` or ``wisent_guard_version`` replaces that value

    Returns:
        Metadata dictionary (standard fields first, then extras)
    """
    from datetime import datetime

    standard_fields = {
        'model_name': model_name,
        'task_name': task_name,
        'layer': layer,
        'vector_strength': vector_strength,
        'training_samples': training_samples,
        # Naive local timestamp in ISO-8601 form
        'created_at': datetime.now().isoformat(),
        # Static placeholder string; not read from the installed package
        'wisent_guard_version': '1.0.0',
    }

    # Later mapping wins on key collisions, mirroring dict.update semantics
    return {**standard_fields, **kwargs}
|
|
File without changes
|
|
File without changes
|