wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of wisent might be problematic. Click here for more details.
- wisent/__init__.py +1 -8
- wisent/benchmarks/__init__.py +0 -0
- wisent/benchmarks/coding/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
- wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
- wisent/benchmarks/coding/metrics/evaluator.py +275 -0
- wisent/benchmarks/coding/metrics/passk.py +66 -0
- wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
- wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
- wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
- wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
- wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
- wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
- wisent/benchmarks/coding/providers/__init__.py +18 -0
- wisent/benchmarks/coding/providers/core/__init__.py +0 -0
- wisent/benchmarks/coding/providers/core/atoms.py +31 -0
- wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
- wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
- wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
- wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
- wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
- wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
- wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
- wisent/classifiers/__init__.py +0 -0
- wisent/classifiers/core/__init__.py +0 -0
- wisent/classifiers/core/atoms.py +747 -0
- wisent/classifiers/models/__init__.py +0 -0
- wisent/classifiers/models/logistic.py +29 -0
- wisent/classifiers/models/mlp.py +47 -0
- wisent/cli/__init__.py +0 -0
- wisent/cli/classifiers/__init__.py +0 -0
- wisent/cli/classifiers/classifier_rotator.py +137 -0
- wisent/cli/cli_logger.py +142 -0
- wisent/cli/data_loaders/__init__.py +0 -0
- wisent/cli/data_loaders/data_loader_rotator.py +96 -0
- wisent/cli/evaluators/__init__.py +0 -0
- wisent/cli/evaluators/evaluator_rotator.py +148 -0
- wisent/cli/steering_methods/__init__.py +0 -0
- wisent/cli/steering_methods/steering_rotator.py +110 -0
- wisent/cli/wisent_cli/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/__init__.py +0 -0
- wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
- wisent/cli/wisent_cli/commands/listing.py +154 -0
- wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
- wisent/cli/wisent_cli/main.py +93 -0
- wisent/cli/wisent_cli/shell.py +80 -0
- wisent/cli/wisent_cli/ui.py +69 -0
- wisent/cli/wisent_cli/util/__init__.py +0 -0
- wisent/cli/wisent_cli/util/aggregations.py +43 -0
- wisent/cli/wisent_cli/util/parsing.py +126 -0
- wisent/cli/wisent_cli/version.py +4 -0
- wisent/core/__init__.py +27 -0
- wisent/core/activations/__init__.py +0 -0
- wisent/core/activations/activations_collector.py +338 -0
- wisent/core/activations/core/__init__.py +0 -0
- wisent/core/activations/core/atoms.py +216 -0
- wisent/core/agent/__init__.py +18 -0
- wisent/core/agent/budget.py +638 -0
- wisent/core/agent/device_benchmarks.py +685 -0
- wisent/core/agent/diagnose/__init__.py +55 -0
- wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
- wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
- wisent/core/agent/diagnose/create_classifier.py +1154 -0
- wisent/core/agent/diagnose/response_diagnostics.py +268 -0
- wisent/core/agent/diagnose/select_classifiers.py +506 -0
- wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
- wisent/core/agent/diagnose/tasks/__init__.py +33 -0
- wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
- wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
- wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
- wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
- wisent/core/agent/diagnose.py +242 -0
- wisent/core/agent/steer.py +212 -0
- wisent/core/agent/timeout.py +134 -0
- wisent/core/autonomous_agent.py +1234 -0
- wisent/core/bigcode_integration.py +583 -0
- wisent/core/contrastive_pairs/__init__.py +15 -0
- wisent/core/contrastive_pairs/core/__init__.py +0 -0
- wisent/core/contrastive_pairs/core/atoms.py +45 -0
- wisent/core/contrastive_pairs/core/buliders.py +59 -0
- wisent/core/contrastive_pairs/core/pair.py +178 -0
- wisent/core/contrastive_pairs/core/response.py +152 -0
- wisent/core/contrastive_pairs/core/serialization.py +300 -0
- wisent/core/contrastive_pairs/core/set.py +133 -0
- wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
- wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
- wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
- wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
- wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
- wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
- wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
- wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
- wisent/core/data_loaders/__init__.py +0 -0
- wisent/core/data_loaders/core/__init__.py +0 -0
- wisent/core/data_loaders/core/atoms.py +98 -0
- wisent/core/data_loaders/loaders/__init__.py +0 -0
- wisent/core/data_loaders/loaders/custom.py +120 -0
- wisent/core/data_loaders/loaders/lm_loader.py +218 -0
- wisent/core/detection_handling.py +257 -0
- wisent/core/download_full_benchmarks.py +1386 -0
- wisent/core/evaluators/__init__.py +0 -0
- wisent/core/evaluators/oracles/__init__.py +0 -0
- wisent/core/evaluators/oracles/interactive.py +73 -0
- wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
- wisent/core/evaluators/oracles/user_specified.py +67 -0
- wisent/core/hyperparameter_optimizer.py +429 -0
- wisent/core/lm_eval_harness_ground_truth.py +1396 -0
- wisent/core/log_likelihoods_evaluator.py +321 -0
- wisent/core/managed_cached_benchmarks.py +595 -0
- wisent/core/mixed_benchmark_sampler.py +364 -0
- wisent/core/model_config_manager.py +330 -0
- wisent/core/model_persistence.py +317 -0
- wisent/core/models/__init__.py +0 -0
- wisent/core/models/core/__init__.py +0 -0
- wisent/core/models/core/atoms.py +460 -0
- wisent/core/models/wisent_model.py +727 -0
- wisent/core/multi_steering.py +316 -0
- wisent/core/optuna/__init__.py +57 -0
- wisent/core/optuna/classifier/__init__.py +25 -0
- wisent/core/optuna/classifier/activation_generator.py +349 -0
- wisent/core/optuna/classifier/classifier_cache.py +509 -0
- wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
- wisent/core/optuna/steering/__init__.py +0 -0
- wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
- wisent/core/optuna/steering/data_utils.py +342 -0
- wisent/core/optuna/steering/metrics.py +474 -0
- wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
- wisent/core/optuna/steering/steering_optimization.py +1111 -0
- wisent/core/parser.py +1668 -0
- wisent/core/prompts/__init__.py +0 -0
- wisent/core/prompts/core/__init__.py +0 -0
- wisent/core/prompts/core/atom.py +57 -0
- wisent/core/prompts/core/prompt_formater.py +157 -0
- wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
- wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
- wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
- wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
- wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
- wisent/core/representation.py +5 -0
- wisent/core/sample_size_optimizer.py +648 -0
- wisent/core/sample_size_optimizer_v2.py +355 -0
- wisent/core/save_results.py +277 -0
- wisent/core/steering.py +652 -0
- wisent/core/steering_method.py +26 -0
- wisent/core/steering_methods/__init__.py +0 -0
- wisent/core/steering_methods/core/__init__.py +0 -0
- wisent/core/steering_methods/core/atoms.py +153 -0
- wisent/core/steering_methods/methods/__init__.py +0 -0
- wisent/core/steering_methods/methods/caa.py +44 -0
- wisent/core/steering_optimizer.py +1297 -0
- wisent/core/task_interface.py +132 -0
- wisent/core/task_selector.py +189 -0
- wisent/core/tasks/__init__.py +175 -0
- wisent/core/tasks/aime_task.py +141 -0
- wisent/core/tasks/file_task.py +211 -0
- wisent/core/tasks/hle_task.py +180 -0
- wisent/core/tasks/hmmt_task.py +119 -0
- wisent/core/tasks/livecodebench_task.py +201 -0
- wisent/core/tasks/livemathbench_task.py +158 -0
- wisent/core/tasks/lm_eval_task.py +455 -0
- wisent/core/tasks/math500_task.py +84 -0
- wisent/core/tasks/polymath_task.py +146 -0
- wisent/core/tasks/supergpqa_task.py +220 -0
- wisent/core/time_estimator.py +149 -0
- wisent/core/timing_calibration.py +174 -0
- wisent/core/tracking/__init__.py +54 -0
- wisent/core/tracking/latency.py +618 -0
- wisent/core/tracking/memory.py +359 -0
- wisent/core/trainers/__init__.py +0 -0
- wisent/core/trainers/core/__init__.py +11 -0
- wisent/core/trainers/core/atoms.py +45 -0
- wisent/core/trainers/steering_trainer.py +271 -0
- wisent/core/user_model_config.py +158 -0
- wisent/opti/__init__.py +0 -0
- wisent/opti/core/__init__.py +0 -0
- wisent/opti/core/atoms.py +175 -0
- wisent/opti/methods/__init__.py +0 -0
- wisent/opti/methods/opti_classificator.py +172 -0
- wisent/opti/methods/opti_steering.py +138 -0
- wisent/synthetic/__init__.py +0 -0
- wisent/synthetic/cleaners/__init__.py +0 -0
- wisent/synthetic/cleaners/core/__init__.py +0 -0
- wisent/synthetic/cleaners/core/atoms.py +58 -0
- wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
- wisent/synthetic/cleaners/methods/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
- wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
- wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
- wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
- wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
- wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
- wisent/synthetic/db_instructions/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/__init__.py +0 -0
- wisent/synthetic/db_instructions/core/atoms.py +25 -0
- wisent/synthetic/db_instructions/mini_dp.py +37 -0
- wisent/synthetic/generators/__init__.py +0 -0
- wisent/synthetic/generators/core/__init__.py +0 -0
- wisent/synthetic/generators/core/atoms.py +73 -0
- wisent/synthetic/generators/diversities/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/__init__.py +0 -0
- wisent/synthetic/generators/diversities/core/core.py +68 -0
- wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
- wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
- wisent/synthetic/generators/pairs_generator.py +179 -0
- wisent-0.5.2.dist-info/METADATA +67 -0
- wisent-0.5.2.dist-info/RECORD +218 -0
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
- wisent/activations/__init__.py +0 -9
- wisent/activations/client.py +0 -97
- wisent/activations/extractor.py +0 -251
- wisent/activations/models.py +0 -95
- wisent/client.py +0 -45
- wisent/control_vector/__init__.py +0 -9
- wisent/control_vector/client.py +0 -85
- wisent/control_vector/manager.py +0 -168
- wisent/control_vector/models.py +0 -70
- wisent/inference/__init__.py +0 -9
- wisent/inference/client.py +0 -103
- wisent/inference/inferencer.py +0 -250
- wisent/inference/models.py +0 -66
- wisent/utils/__init__.py +0 -3
- wisent/utils/auth.py +0 -30
- wisent/utils/http.py +0 -228
- wisent/version.py +0 -3
- wisent-0.1.1.dist-info/METADATA +0 -142
- wisent-0.1.1.dist-info/RECORD +0 -23
- {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,359 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Memory usage tracking for wisent-guard operations.
|
|
3
|
+
|
|
4
|
+
This module provides comprehensive memory monitoring capabilities including
|
|
5
|
+
GPU and CPU memory tracking, peak usage detection, and memory profiling.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import gc
|
|
9
|
+
import psutil
|
|
10
|
+
import time
|
|
11
|
+
import threading
|
|
12
|
+
from typing import Dict, List, Optional, Any, Callable
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from contextlib import contextmanager
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
from wisent.core.utils.device import resolve_default_device
|
|
18
|
+
|
|
19
|
+
# NVML (NVIDIA Management Library bindings) is optional: it is only used to read
# total device memory so a CUDA usage percentage can be reported. The bindings
# are importable as ``pynvml`` (the module shipped by the ``nvidia-ml-py`` /
# ``nvidia-ml-py3`` distributions); ``nvidia_ml_py3`` is kept as a fallback
# because that is the name this module previously tried.
try:
    import pynvml as nvml
except ImportError:
    try:
        import nvidia_ml_py3 as nvml
    except ImportError:
        nvml = None
NVML_AVAILABLE = nvml is not None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class MemorySnapshot:
    """Memory reading captured at a single moment.

    Only the three CPU-side fields are required. Every accelerator-related
    field defaults to ``None`` and stays that way unless GPU tracking
    produced a value for it.
    """

    # Wall-clock time of the reading (seconds since the epoch).
    timestamp: float
    # Resident set size of the process, in MiB.
    cpu_memory_mb: float
    # Process memory as a percentage of total system memory.
    cpu_memory_percent: float
    # Accelerator memory currently allocated through torch, in MiB.
    gpu_memory_mb: Optional[float] = field(default=None)
    # Allocated accelerator memory as a percentage of the device total, if known.
    gpu_memory_percent: Optional[float] = field(default=None)
    # Number of live tensors found on the tracked accelerator.
    allocated_tensors: Optional[int] = field(default=None)
    # Accelerator memory reserved by the allocator / driver, in MiB.
    cached_memory_mb: Optional[float] = field(default=None)
    # Free-form label describing what was happening when the reading was taken.
    operation: Optional[str] = field(default=None)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class MemoryStats:
    """Summary of the memory readings collected over one tracking period.

    The GPU aggregates are ``None`` when none of the snapshots carried a GPU
    reading.
    """

    # Highest process RSS seen, in MiB, and the GPU counterpart.
    peak_cpu_mb: float
    peak_gpu_mb: Optional[float]
    # Mean process RSS, in MiB, and the GPU counterpart.
    avg_cpu_mb: float
    avg_gpu_mb: Optional[float]
    # Lowest process RSS seen, in MiB, and the GPU counterpart.
    min_cpu_mb: float
    min_gpu_mb: Optional[float]
    # Time elapsed between the first and the last snapshot.
    duration_seconds: float
    # Snapshots the statistics were computed from.
    snapshots: List[MemorySnapshot] = field(default_factory=list)
    # Distinct operation labels found in those snapshots.
    operations: List[str] = field(default_factory=list)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class MemoryTracker:
    """
    Comprehensive memory usage tracker for wisent-guard operations.

    Tracks process CPU memory (RSS via psutil) and, when the resolved default
    device is "cuda" or "mps", accelerator memory reported by torch. Snapshots
    can be taken on demand (``take_snapshot`` / ``track_operation``) or
    continuously by a daemon background thread (``start_monitoring`` /
    ``stop_monitoring``).
    """

    def __init__(
        self,
        track_gpu: bool = True,
        sampling_interval: float = 0.1,
        auto_cleanup: bool = True
    ):
        """
        Initialize memory tracker.

        Args:
            track_gpu: Whether to track accelerator memory. Only honoured when
                the default device resolves to "cuda" or "mps".
            sampling_interval: Seconds to sleep between samples while
                continuous monitoring is active.
            auto_cleanup: Whether ``clear_cache`` (and therefore the end of
                ``track_operation``) runs garbage collection and empties the
                accelerator cache.
        """
        self.device_kind = resolve_default_device()
        self.track_gpu = track_gpu and self.device_kind in {"cuda", "mps"}
        self.sampling_interval = sampling_interval
        self.auto_cleanup = auto_cleanup

        self.snapshots: List[MemorySnapshot] = []
        self.is_monitoring = False
        self.monitor_thread: Optional[threading.Thread] = None
        self.start_time: Optional[float] = None

        # NVML is only used to learn the total device memory so that a CUDA
        # usage percentage can be reported; any failure just disables it.
        self.gpu_handle = None
        self.gpu_available = False
        if self.track_gpu and self.device_kind == "cuda" and NVML_AVAILABLE:
            try:
                nvml.nvmlInit()
                # NOTE(review): assumes NVML index 0 is the device torch
                # allocates on; may not hold with CUDA_VISIBLE_DEVICES.
                self.gpu_handle = nvml.nvmlDeviceGetHandleByIndex(0)
                self.gpu_available = True
            except Exception:
                self.gpu_available = False
                self.gpu_handle = None

    @staticmethod
    def _count_tensors(on_device: Callable[[Any], Any]) -> int:
        """Count live gc-tracked tensors for which ``on_device(tensor)`` is truthy."""
        return sum(
            1 for obj in gc.get_objects() if torch.is_tensor(obj) and on_device(obj)
        )

    def _capture(self, operation: Optional[str] = None) -> MemorySnapshot:
        """Measure current memory usage and return it WITHOUT recording it."""
        timestamp = time.time()

        # CPU memory: resident set size of this process.
        process = psutil.Process()
        cpu_memory_mb = process.memory_info().rss / 1024 / 1024
        cpu_memory_percent = process.memory_percent()

        gpu_memory_mb = None
        gpu_memory_percent = None
        allocated_tensors = None
        cached_memory_mb = None

        if self.track_gpu:
            if self.device_kind == "cuda" and torch.cuda.is_available():
                gpu_memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
                cached_memory_mb = torch.cuda.memory_reserved() / 1024 / 1024
                allocated_tensors = self._count_tensors(
                    lambda t: getattr(t, "is_cuda", False)
                )

                if self.gpu_available and self.gpu_handle is not None:
                    try:
                        gpu_info = nvml.nvmlDeviceGetMemoryInfo(self.gpu_handle)
                        total_gpu_mb = gpu_info.total / 1024 / 1024
                        gpu_memory_percent = (gpu_memory_mb / total_gpu_mb) * 100
                    except Exception:
                        # Best effort: the percentage is optional information.
                        gpu_memory_percent = None
            elif self.device_kind == "mps" and hasattr(torch, "mps"):
                # Older torch builds may lack these MPS helpers.
                try:
                    allocated_bytes = torch.mps.current_allocated_memory()
                except AttributeError:
                    allocated_bytes = 0

                try:
                    cached_bytes = torch.mps.driver_allocated_memory()
                except AttributeError:
                    cached_bytes = allocated_bytes

                gpu_memory_mb = allocated_bytes / 1024 / 1024
                cached_memory_mb = cached_bytes / 1024 / 1024
                allocated_tensors = self._count_tensors(
                    lambda t: getattr(getattr(t, "device", None), "type", None) == "mps"
                )

        return MemorySnapshot(
            timestamp=timestamp,
            cpu_memory_mb=cpu_memory_mb,
            cpu_memory_percent=cpu_memory_percent,
            gpu_memory_mb=gpu_memory_mb,
            gpu_memory_percent=gpu_memory_percent,
            allocated_tensors=allocated_tensors,
            cached_memory_mb=cached_memory_mb,
            operation=operation,
        )

    def take_snapshot(self, operation: Optional[str] = None) -> MemorySnapshot:
        """Take a memory snapshot, append it to ``self.snapshots`` and return it."""
        snapshot = self._capture(operation)
        self.snapshots.append(snapshot)
        return snapshot

    def start_monitoring(self) -> None:
        """Start continuous sampling in a daemon thread.

        Clears previously collected snapshots. Does nothing if monitoring is
        already active.
        """
        if self.is_monitoring:
            return

        self.is_monitoring = True
        self.start_time = time.time()
        self.snapshots.clear()

        def monitor_loop():
            while self.is_monitoring:
                self.take_snapshot("continuous_monitoring")
                time.sleep(self.sampling_interval)

        self.monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
        self.monitor_thread.start()

    def stop_monitoring(self) -> MemoryStats:
        """Stop continuous monitoring, wait for the thread, and return statistics.

        Raises:
            ValueError: If monitoring is not active.
        """
        if not self.is_monitoring:
            raise ValueError("Monitoring is not active")

        self.is_monitoring = False
        if self.monitor_thread:
            self.monitor_thread.join()

        return self.get_stats()

    def get_stats(self) -> MemoryStats:
        """Aggregate the collected snapshots into a ``MemoryStats``.

        Raises:
            ValueError: If no snapshots have been collected.
        """
        # Work on a copy so a concurrently running monitor thread cannot append
        # between the reads below and leave the aggregates inconsistent.
        snapshots = list(self.snapshots)
        if not snapshots:
            raise ValueError("No snapshots available")

        cpu_values = [s.cpu_memory_mb for s in snapshots]
        gpu_values = [s.gpu_memory_mb for s in snapshots if s.gpu_memory_mb is not None]

        duration = snapshots[-1].timestamp - snapshots[0].timestamp
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # reported operation list is deterministic (a set's order is not).
        operations = list(dict.fromkeys(s.operation for s in snapshots if s.operation))

        return MemoryStats(
            peak_cpu_mb=max(cpu_values),
            peak_gpu_mb=max(gpu_values) if gpu_values else None,
            avg_cpu_mb=sum(cpu_values) / len(cpu_values),
            avg_gpu_mb=sum(gpu_values) / len(gpu_values) if gpu_values else None,
            min_cpu_mb=min(cpu_values),
            min_gpu_mb=min(gpu_values) if gpu_values else None,
            duration_seconds=duration,
            snapshots=snapshots,
            operations=operations,
        )

    def clear_cache(self) -> None:
        """Run ``gc.collect()`` and empty the accelerator cache.

        Does nothing unless ``auto_cleanup`` is enabled.
        """
        if not self.auto_cleanup:
            return
        gc.collect()
        if self.device_kind == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif self.device_kind == "mps" and hasattr(torch, "mps"):
            try:
                torch.mps.empty_cache()
            except AttributeError:
                pass

    def reset(self) -> None:
        """Stop monitoring if active and discard all collected snapshots."""
        if self.is_monitoring:
            self.stop_monitoring()
        self.snapshots.clear()
        self.start_time = None

    @contextmanager
    def track_operation(self, operation_name: str):
        """Record ``<operation_name>_start`` / ``_end`` snapshots around a block.

        The end snapshot is taken even if the block raises. When
        ``auto_cleanup`` is enabled, ``clear_cache`` runs afterwards.

        Yields:
            This tracker.
        """
        self.take_snapshot(f"{operation_name}_start")
        try:
            yield self
        finally:
            self.take_snapshot(f"{operation_name}_end")
            if self.auto_cleanup:
                self.clear_cache()

    def get_current_usage(self) -> Dict[str, Any]:
        """Return current memory usage as a dict without storing a snapshot.

        The CPU keys are always present. The GPU keys are added only when an
        accelerator reading was taken.
        """
        # Measure directly instead of take_snapshot() + pop(): while monitoring
        # is active the background thread may append in between, and pop()
        # would then discard its snapshot and leave this one in the history.
        snapshot = self._capture("current_check")

        usage = {
            "cpu_memory_mb": snapshot.cpu_memory_mb,
            "cpu_memory_percent": snapshot.cpu_memory_percent,
        }

        if snapshot.gpu_memory_mb is not None:
            usage.update({
                "gpu_memory_mb": snapshot.gpu_memory_mb,
                "gpu_memory_percent": snapshot.gpu_memory_percent,
                "allocated_tensors": snapshot.allocated_tensors,
                "cached_memory_mb": snapshot.cached_memory_mb,
            })

        return usage

    def format_stats(self, stats: MemoryStats, detailed: bool = False) -> str:
        """Render ``stats`` as a multi-line, human-readable report.

        Args:
            stats: Statistics, e.g. from ``get_stats`` or ``stop_monitoring``.
            detailed: When True (and snapshots are present), also report the
                snapshot count and the snapshot with the highest CPU usage.
        """
        lines = [
            "Memory Usage Statistics:",
            f" Duration: {stats.duration_seconds:.2f} seconds",
            " CPU Memory:",
            f" Peak: {stats.peak_cpu_mb:.1f} MB",
            f" Average: {stats.avg_cpu_mb:.1f} MB",
            f" Minimum: {stats.min_cpu_mb:.1f} MB",
        ]

        if stats.peak_gpu_mb is not None:
            lines.extend([
                " GPU Memory:",
                f" Peak: {stats.peak_gpu_mb:.1f} MB",
                f" Average: {stats.avg_gpu_mb:.1f} MB",
                f" Minimum: {stats.min_gpu_mb:.1f} MB",
            ])

        if stats.operations:
            lines.append(f" Operations: {', '.join(stats.operations)}")

        if detailed and stats.snapshots:
            lines.append(f" Snapshots: {len(stats.snapshots)} collected")

            # The "peak" snapshot is the one with the highest CPU usage.
            peak_snapshot = max(stats.snapshots, key=lambda s: s.cpu_memory_mb)
            lines.extend([
                " Peak Usage Snapshot:",
                f" Time: {peak_snapshot.timestamp:.2f}",
                f" CPU: {peak_snapshot.cpu_memory_mb:.1f} MB ({peak_snapshot.cpu_memory_percent:.1f}%)",
            ])

            if peak_snapshot.gpu_memory_mb is not None:
                lines.append(f" GPU: {peak_snapshot.gpu_memory_mb:.1f} MB")
            if peak_snapshot.allocated_tensors is not None:
                lines.append(f" Tensors: {peak_snapshot.allocated_tensors}")

        return "\n".join(lines)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
# Process-wide tracker, created lazily by get_global_tracker() and shared by
# the track_memory decorator.
_global_tracker: Optional[MemoryTracker] = None


def get_global_tracker() -> MemoryTracker:
    """Return the shared module-level tracker, constructing it on first use."""
    global _global_tracker
    if _global_tracker is not None:
        return _global_tracker
    _global_tracker = MemoryTracker()
    return _global_tracker
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def track_memory(operation_name: str):
    """Decorator factory that records memory usage around each call of a function.

    Every call runs inside ``get_global_tracker().track_operation(operation_name)``,
    so ``<operation_name>_start`` and ``<operation_name>_end`` snapshots are
    appended to the shared global tracker.

    Args:
        operation_name: Label used for the start/end snapshots.

    Returns:
        A decorator. It wraps a function while preserving the function's name,
        docstring and other metadata via ``functools.wraps``.
    """
    from functools import wraps

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            tracker = get_global_tracker()
            with tracker.track_operation(operation_name):
                return func(*args, **kwargs)
        return wrapper
    return decorator
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def get_memory_info() -> Dict[str, Any]:
    """Return a one-off reading of current memory usage.

    A throwaway ``MemoryTracker`` with ``auto_cleanup=False`` is built for this
    call, so the shared global tracker's snapshot history is not touched.
    """
    return MemoryTracker(auto_cleanup=False).get_current_usage()
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def format_memory_usage(usage: Dict[str, Any]) -> str:
    """Format a memory usage dictionary as a single ``" | "``-separated line.

    Args:
        usage: Mapping shaped like the result of
            ``MemoryTracker.get_current_usage``. It must contain
            ``cpu_memory_mb`` and ``cpu_memory_percent``. The GPU keys
            (``gpu_memory_mb``, ``gpu_memory_percent``, ``cached_memory_mb``,
            ``allocated_tensors``) are optional and are skipped when absent
            or ``None``.

    Returns:
        For example ``"CPU Memory: 512.0 MB (3.2%) | GPU Memory: 100.0 MB"``.

    Raises:
        KeyError: If either CPU key is missing.
    """
    parts = [
        f"CPU Memory: {usage['cpu_memory_mb']:.1f} MB ({usage['cpu_memory_percent']:.1f}%)"
    ]

    gpu_mb = usage.get("gpu_memory_mb")
    if gpu_mb is not None:
        gpu_line = f"GPU Memory: {gpu_mb:.1f} MB"
        gpu_percent = usage.get("gpu_memory_percent")
        if gpu_percent is not None:
            gpu_line += f" ({gpu_percent:.1f}%)"
        parts.append(gpu_line)

    # Check for None (not just key presence) so an explicit None does not crash
    # the ``:.1f`` format spec or print a literal "None".
    cached_mb = usage.get("cached_memory_mb")
    if cached_mb is not None:
        parts.append(f"GPU Cached: {cached_mb:.1f} MB")

    tensors = usage.get("allocated_tensors")
    if tensors is not None:
        parts.append(f"GPU Tensors: {tensors}")

    return " | ".join(parts)
|
|
File without changes
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
#1 class WisentSteeringTrainer:
|
|
2
|
+
#2 The trainer should load the activation collector (see the provided file).
|
|
3
|
+
#3 It should load the contrastive pair set (see the provided file).
|
|
4
|
+
#4 It should decide which steering training method to use: caa, bipo, etc.
|
|
5
|
+
#4 It should be possible to specify which layers activations are collected from, and then to apply the steering method to each layer's activations to obtain a steering vector.
|
|
6
|
+
#5 Some methods use activations from many layers, others from only one, so this must be configurable: e.g. the user can say "use layers 10, 20, 30" or "use all layers from 10 to 30".
|
|
7
|
+
#6 After training, the user needs to get back the contrastive pair set with the collected activations (see the provided file) and the steering vectors, which must be a LayerActivations instance.
|
|
8
|
+
#7 We should save all trained steering vectors, the contrastive pairs with their activations, and metadata such as date, model name, layers used, method used, hyperparameters used, etc.
|
|
9
|
+
|
|
10
|
+
# Important: we also need to specify the activation collection strategy (see LayerActivations). All provided files have good descriptions/docstrings; please write the code with that in mind. Create two files: atoms.py
|
|
11
|
+
# where we define all base structures for the trainers (all abstract classes, etc.), and steering_trainer.py, where we implement the WisentSteeringTrainer class.
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
from wisent.core.activations.core.atoms import LayerActivations
|
|
6
|
+
from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
|
|
7
|
+
|
|
8
|
+
# Public API of this module: the result container and the abstract trainer
# interface that concrete trainers subclass.
__all__ = [
    "TrainingResult",
    "BaseSteeringTrainer",
]
|
|
11
|
+
|
|
12
|
+
@dataclass(slots=True)
class TrainingResult:
    """
    Bundle of artifacts produced by one complete trainer run.

    Attributes:
        steered_vectors:
            Steering vectors keyed per layer, held in a ``LayerActivations``
            mapping; each entry is typically a 1-D tensor of shape [H].
        pair_set_with_activations:
            The input ``ContrastivePairSet`` after activation collection, with
            per-pair, per-layer activations stored on its positive/negative
            responses.
        metadata:
            JSON-serializable run information such as date, model name,
            layers, method, hyperparameters and aggregation strategy.
    """

    # Per-layer steering vectors produced by the steering method.
    steered_vectors: LayerActivations
    # Pair set enriched with the activations that were used for training.
    pair_set_with_activations: ContrastivePairSet
    # Run metadata; keep it JSON-serializable.
    metadata: Dict[str, Any]
|
|
31
|
+
|
|
32
|
+
class BaseSteeringTrainer(ABC):
    """
    Interface implemented by every steering trainer.

    A concrete trainer is expected to:
      1) collect activations for a set of contrastive pairs,
      2) train one or more steering vectors with a chosen method, and
      3) return a ``TrainingResult``, optionally saving artifacts as well.
    """

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> TrainingResult:
        """Run the whole pipeline end to end and return its ``TrainingResult``.

        Subclasses define the concrete arguments they accept.
        """
        ...
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import logging
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Sequence
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import torch
|
|
9
|
+
import datetime as _dt
|
|
10
|
+
|
|
11
|
+
from wisent.core.activations.core.atoms import (
|
|
12
|
+
LayerActivations,
|
|
13
|
+
ActivationAggregationStrategy,
|
|
14
|
+
RawActivationMap,
|
|
15
|
+
)
|
|
16
|
+
from wisent.core.models.wisent_model import WisentModel
|
|
17
|
+
|
|
18
|
+
from wisent.core.trainers.core.atoms import (
|
|
19
|
+
TrainingResult,
|
|
20
|
+
BaseSteeringTrainer
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
|
|
24
|
+
from wisent.core.activations.activations_collector import ActivationCollector
|
|
25
|
+
from wisent.core.steering_methods.core.atoms import BaseSteeringMethod
|
|
26
|
+
from wisent.core.contrastive_pairs.diagnostics import run_control_vector_diagnostics
|
|
27
|
+
|
|
28
|
+
__all__ = ["WisentSteeringTrainer"]

# Module logger; its name is this module's dotted import path.
logger = logging.getLogger(name=__name__)
|
|
34
|
+
|
|
35
|
+
@dataclass(slots=True)
|
|
36
|
+
class WisentSteeringTrainer(BaseSteeringTrainer):
|
|
37
|
+
"""
|
|
38
|
+
Orchestrates activation collection + steering vector training for a given model and pair set.
|
|
39
|
+
|
|
40
|
+
Minimal usage:
|
|
41
|
+
trainer = WisentSteeringTrainer(model, pair_set, steering_method)
|
|
42
|
+
result = trainer.run(layers_spec=..., method_kwargs=..., aggregation=..., ...)
|
|
43
|
+
# result is a TrainingResult with steered vectors, enriched pair set, and metadata
|
|
44
|
+
trainer.save_result(output_dir) # optional save
|
|
45
|
+
|
|
46
|
+
arguments:
|
|
47
|
+
model: WisentModel to use for activation collection.
|
|
48
|
+
pair_set: ContrastivePairSet with pairs to use for collection and training.
|
|
49
|
+
steering_method: BaseSteeringMethod instance to use for training.
|
|
50
|
+
store_device: Device to store collected activations on (default "cpu").
|
|
51
|
+
dtype: Optional torch.dtype to cast collected activations to (default None, meaning no cast).
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
model: WisentModel
|
|
55
|
+
pair_set: ContrastivePairSet
|
|
56
|
+
steering_method: BaseSteeringMethod
|
|
57
|
+
store_device: str | torch.device = "cpu"
|
|
58
|
+
dtype: torch.dtype | None = None
|
|
59
|
+
|
|
60
|
+
def __post_init__(self) -> None:
|
|
61
|
+
self.collector = ActivationCollector(model=self.model, store_device=self.store_device, dtype=self.dtype)
|
|
62
|
+
self._last_result: TrainingResult | None = None
|
|
63
|
+
|
|
64
|
+
def run(
|
|
65
|
+
self,
|
|
66
|
+
layers_spec: Sequence[str] | str | int | Sequence[int] | None,
|
|
67
|
+
method_kwargs: dict[str, Any] | None = None,
|
|
68
|
+
aggregation: ActivationAggregationStrategy = ActivationAggregationStrategy.CONTINUATION_TOKEN,
|
|
69
|
+
return_full_sequence: bool = False,
|
|
70
|
+
normalize_layers: bool = False,
|
|
71
|
+
save_dir: str | Path | None = None,
|
|
72
|
+
) -> TrainingResult:
|
|
73
|
+
"""
|
|
74
|
+
Full pipeline:
|
|
75
|
+
1) Decide which layers to use (from spec or all layers if None).
|
|
76
|
+
2) Collect activations for each pair at these layers.
|
|
77
|
+
3) Train steering vectors using the selected method.
|
|
78
|
+
4) Return a TrainingResult with vectors, enriched pair set, and metadata.
|
|
79
|
+
5) Optionally save artifacts to disk.
|
|
80
|
+
|
|
81
|
+
arguments:
|
|
82
|
+
layers_spec:
|
|
83
|
+
- list like ["10","20","30"] or [10, 20, 30]
|
|
84
|
+
- range string "10-30" / "10..30"
|
|
85
|
+
- single int "12"
|
|
86
|
+
- None → use all available layers on the model
|
|
87
|
+
method:
|
|
88
|
+
Name of steering method ("caa", "bipo", ...).
|
|
89
|
+
method_kwargs:
|
|
90
|
+
Dict of hyperparameters for the method (e.g., {"normalize": True, "scale": 1.0}).
|
|
91
|
+
aggregation:
|
|
92
|
+
ActivationAggregationStrategy to use during collection when not returning
|
|
93
|
+
full sequences. Ignored if 'return_full_sequence=True'.
|
|
94
|
+
return_full_sequence:
|
|
95
|
+
If True, store full [T,H] sequences per layer (method then must know how
|
|
96
|
+
to collapse to vectors). Default False (collect [H] vectors directly).
|
|
97
|
+
normalize_layers:
|
|
98
|
+
If True, L2-normalize activations layer-wise during collection.
|
|
99
|
+
save_dir:
|
|
100
|
+
If provided, artifacts are written there. Directory is created if missing.
|
|
101
|
+
|
|
102
|
+
returns:
|
|
103
|
+
TrainingResult
|
|
104
|
+
"""
|
|
105
|
+
method_kwargs = method_kwargs or {}
|
|
106
|
+
|
|
107
|
+
# 1) Resolve layer names
|
|
108
|
+
layers = self._resolve_layers(layers_spec)
|
|
109
|
+
|
|
110
|
+
# 2) Collect activations for each pair
|
|
111
|
+
for i, pair in enumerate(self.pair_set.pairs):
|
|
112
|
+
updated = self.collector.collect_for_pair(
|
|
113
|
+
pair,
|
|
114
|
+
layers=layers,
|
|
115
|
+
aggregation=aggregation,
|
|
116
|
+
return_full_sequence=return_full_sequence,
|
|
117
|
+
normalize_layers=normalize_layers,
|
|
118
|
+
)
|
|
119
|
+
self.pair_set.pairs[i] = updated
|
|
120
|
+
|
|
121
|
+
# 3) Train using selected method
|
|
122
|
+
raw_vectors: RawActivationMap = self.steering_method.train(self.pair_set, **(method_kwargs or {}))
|
|
123
|
+
|
|
124
|
+
steered = LayerActivations(raw_vectors)
|
|
125
|
+
|
|
126
|
+
control_vector_report = run_control_vector_diagnostics(steered)
|
|
127
|
+
for issue in control_vector_report.issues:
|
|
128
|
+
log_method = logger.error if issue.severity == "critical" else logger.warning
|
|
129
|
+
log_method(
|
|
130
|
+
"[control_vector diagnostics] %s (details=%s)",
|
|
131
|
+
issue.message,
|
|
132
|
+
issue.details,
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
control_vector_summary = control_vector_report.summary.get("control_vectors", {})
|
|
136
|
+
control_vector_issues = [
|
|
137
|
+
{
|
|
138
|
+
"metric": issue.metric,
|
|
139
|
+
"severity": issue.severity,
|
|
140
|
+
"message": issue.message,
|
|
141
|
+
"details": issue.details,
|
|
142
|
+
}
|
|
143
|
+
for issue in control_vector_report.issues
|
|
144
|
+
]
|
|
145
|
+
|
|
146
|
+
if control_vector_report.has_critical_issues:
|
|
147
|
+
raise ValueError("Control vector diagnostics found critical issues; see logs for specifics.")
|
|
148
|
+
|
|
149
|
+
# 4) Metadata
|
|
150
|
+
now = _dt.datetime.now().astimezone()
|
|
151
|
+
metadata: dict[str, Any] = {
|
|
152
|
+
"timestamp": now.isoformat(),
|
|
153
|
+
"model_name": getattr(self.model, "model_name", getattr(self.model, "name", None)),
|
|
154
|
+
"layers_used": layers or "all",
|
|
155
|
+
"method": self.steering_method.name,
|
|
156
|
+
"method_kwargs": method_kwargs,
|
|
157
|
+
"activation_aggregation_strategy": (None if return_full_sequence else aggregation),
|
|
158
|
+
"return_full_sequence": bool(return_full_sequence),
|
|
159
|
+
"normalize_layers": bool(normalize_layers),
|
|
160
|
+
"num_pairs": len(self.pair_set.pairs),
|
|
161
|
+
"hidden_size": getattr(self.model, "hidden_size", None),
|
|
162
|
+
"control_vector_diagnostics": control_vector_summary,
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
if control_vector_issues:
|
|
166
|
+
metadata["control_vector_issues"] = control_vector_issues
|
|
167
|
+
|
|
168
|
+
result = TrainingResult(steered_vectors=steered, pair_set_with_activations=self.pair_set, metadata=metadata)
|
|
169
|
+
self._last_result = result
|
|
170
|
+
|
|
171
|
+
# 5) Optional save
|
|
172
|
+
if save_dir is not None:
|
|
173
|
+
self.save_result(save_dir, result)
|
|
174
|
+
|
|
175
|
+
return result
|
|
176
|
+
|
|
177
|
+
def save_result(self, output_dir: str | Path, result: TrainingResult | None = None) -> Path:
|
|
178
|
+
"""
|
|
179
|
+
Persist vectors, metadata, and the pair set (with activations) to disk.
|
|
180
|
+
|
|
181
|
+
Files written:
|
|
182
|
+
- metadata.json (JSON)
|
|
183
|
+
- steering_vectors.pt (torch.save of dict[layer]->tensor on CPU)
|
|
184
|
+
- pairs_with_activations.pt (torch.save of the full ContrastivePairSet object)
|
|
185
|
+
- steering_vectors_summary.json (shapes/dtypes only, human-readable)
|
|
186
|
+
|
|
187
|
+
returns:
|
|
188
|
+
Path to the created directory.
|
|
189
|
+
"""
|
|
190
|
+
result = result or self._last_result
|
|
191
|
+
if result is None:
|
|
192
|
+
raise RuntimeError("No result to save. Run the trainer first.")
|
|
193
|
+
|
|
194
|
+
out = Path(output_dir)
|
|
195
|
+
out.mkdir(parents=True, exist_ok=True)
|
|
196
|
+
|
|
197
|
+
# Vectors
|
|
198
|
+
raw_map: RawActivationMap = result.steered_vectors.to_dict() # still tensors
|
|
199
|
+
cpu_map = {k: (v.detach().to("cpu") if isinstance(v, torch.Tensor) else v) for k, v in raw_map.items() if k != "_activation_aggregation_strategy"}
|
|
200
|
+
torch.save(cpu_map, out / "steering_vectors.pt")
|
|
201
|
+
|
|
202
|
+
# Summary (json-serializable)
|
|
203
|
+
vec_summary = {
|
|
204
|
+
k: None if v is None else {
|
|
205
|
+
"shape": tuple(v.shape),
|
|
206
|
+
"dtype": str(v.dtype),
|
|
207
|
+
}
|
|
208
|
+
for k, v in cpu_map.items()
|
|
209
|
+
}
|
|
210
|
+
(out / "steering_vectors_summary.json").write_text(json.dumps(vec_summary, indent=2))
|
|
211
|
+
|
|
212
|
+
# Metadata
|
|
213
|
+
(out / "metadata.json").write_text(json.dumps(result.metadata, indent=2))
|
|
214
|
+
|
|
215
|
+
# Full pair set with activations (Python pickle via torch.save)
|
|
216
|
+
torch.save(result.pair_set_with_activations, out / "pairs_with_activations.pt")
|
|
217
|
+
|
|
218
|
+
return out
|
|
219
|
+
|
|
220
|
+
def _resolve_layers(self, spec: Sequence[str] | str | int | Sequence[int] | None) -> list[str] | None:
|
|
221
|
+
"""
|
|
222
|
+
Convert a user-facing spec into canonical layer names ("1","2",...).
|
|
223
|
+
If None, return None (meaning: use all layers in the collector/model).
|
|
224
|
+
|
|
225
|
+
arguments:
|
|
226
|
+
spec: See 'layers_spec' argument in run().
|
|
227
|
+
|
|
228
|
+
returns:
|
|
229
|
+
Sorted list of layer names as strings, or None.
|
|
230
|
+
|
|
231
|
+
examples:
|
|
232
|
+
None -> None
|
|
233
|
+
"10-12" -> ["10","11","12"]
|
|
234
|
+
[5,10,15] -> ["5","10","15"]
|
|
235
|
+
"3,7,10..12" -> ["3","7","10","11","12"]
|
|
236
|
+
8 -> ["8"]
|
|
237
|
+
"""
|
|
238
|
+
if spec is None:
|
|
239
|
+
return None
|
|
240
|
+
|
|
241
|
+
if isinstance(spec, (list, tuple)):
|
|
242
|
+
names: list[str] = []
|
|
243
|
+
for item in spec:
|
|
244
|
+
if isinstance(item, int):
|
|
245
|
+
names.append(str(item))
|
|
246
|
+
else:
|
|
247
|
+
names.extend(self._parse_layer_token(item))
|
|
248
|
+
return sorted(set(names), key=lambda s: (len(s), s))
|
|
249
|
+
|
|
250
|
+
if isinstance(spec, int):
|
|
251
|
+
return [str(spec)]
|
|
252
|
+
|
|
253
|
+
names: list[str] = []
|
|
254
|
+
for token in str(spec).replace(" ", "").split(","):
|
|
255
|
+
names.extend(self._parse_layer_token(token))
|
|
256
|
+
return sorted(set(names), key=lambda s: (len(s), s))
|
|
257
|
+
|
|
258
|
+
@staticmethod
|
|
259
|
+
def _parse_layer_token(token: str) -> list[str]:
|
|
260
|
+
"""
|
|
261
|
+
Parse a token like "5", "10-20", "10..20" into a list of names.
|
|
262
|
+
"""
|
|
263
|
+
if not token:
|
|
264
|
+
return []
|
|
265
|
+
if "-" in token or ".." in token:
|
|
266
|
+
a, b = token.replace("..", "-").split("-")
|
|
267
|
+
a_i, b_i = int(a), int(b)
|
|
268
|
+
lo, hi = (a_i, b_i) if a_i <= b_i else (b_i, a_i)
|
|
269
|
+
return [str(i) for i in range(lo, hi + 1)]
|
|
270
|
+
else:
|
|
271
|
+
return [str(int(token))]
|