wisent 0.1.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of wisent might be problematic. Click here for more details.

Files changed (237) hide show
  1. wisent/__init__.py +1 -8
  2. wisent/benchmarks/__init__.py +0 -0
  3. wisent/benchmarks/coding/__init__.py +0 -0
  4. wisent/benchmarks/coding/metrics/__init__.py +0 -0
  5. wisent/benchmarks/coding/metrics/core/__init__.py +0 -0
  6. wisent/benchmarks/coding/metrics/core/atoms.py +36 -0
  7. wisent/benchmarks/coding/metrics/evaluator.py +275 -0
  8. wisent/benchmarks/coding/metrics/passk.py +66 -0
  9. wisent/benchmarks/coding/output_sanitizer/__init__.py +0 -0
  10. wisent/benchmarks/coding/output_sanitizer/core/__init__.py +0 -0
  11. wisent/benchmarks/coding/output_sanitizer/core/atoms.py +27 -0
  12. wisent/benchmarks/coding/output_sanitizer/cpp_sanitizer.py +62 -0
  13. wisent/benchmarks/coding/output_sanitizer/java_sanitizer.py +78 -0
  14. wisent/benchmarks/coding/output_sanitizer/python_sanitizer.py +94 -0
  15. wisent/benchmarks/coding/output_sanitizer/utils.py +107 -0
  16. wisent/benchmarks/coding/providers/__init__.py +18 -0
  17. wisent/benchmarks/coding/providers/core/__init__.py +0 -0
  18. wisent/benchmarks/coding/providers/core/atoms.py +31 -0
  19. wisent/benchmarks/coding/providers/livecodebench/__init__.py +0 -0
  20. wisent/benchmarks/coding/providers/livecodebench/provider.py +53 -0
  21. wisent/benchmarks/coding/safe_docker/__init__.py +0 -0
  22. wisent/benchmarks/coding/safe_docker/core/__init__.py +0 -0
  23. wisent/benchmarks/coding/safe_docker/core/atoms.py +105 -0
  24. wisent/benchmarks/coding/safe_docker/core/runtime.py +118 -0
  25. wisent/benchmarks/coding/safe_docker/entrypoint.py +123 -0
  26. wisent/benchmarks/coding/safe_docker/recipes.py +60 -0
  27. wisent/classifiers/__init__.py +0 -0
  28. wisent/classifiers/core/__init__.py +0 -0
  29. wisent/classifiers/core/atoms.py +747 -0
  30. wisent/classifiers/models/__init__.py +0 -0
  31. wisent/classifiers/models/logistic.py +29 -0
  32. wisent/classifiers/models/mlp.py +47 -0
  33. wisent/cli/__init__.py +0 -0
  34. wisent/cli/classifiers/__init__.py +0 -0
  35. wisent/cli/classifiers/classifier_rotator.py +137 -0
  36. wisent/cli/cli_logger.py +142 -0
  37. wisent/cli/data_loaders/__init__.py +0 -0
  38. wisent/cli/data_loaders/data_loader_rotator.py +96 -0
  39. wisent/cli/evaluators/__init__.py +0 -0
  40. wisent/cli/evaluators/evaluator_rotator.py +148 -0
  41. wisent/cli/steering_methods/__init__.py +0 -0
  42. wisent/cli/steering_methods/steering_rotator.py +110 -0
  43. wisent/cli/wisent_cli/__init__.py +0 -0
  44. wisent/cli/wisent_cli/commands/__init__.py +0 -0
  45. wisent/cli/wisent_cli/commands/help_cmd.py +52 -0
  46. wisent/cli/wisent_cli/commands/listing.py +154 -0
  47. wisent/cli/wisent_cli/commands/train_cmd.py +322 -0
  48. wisent/cli/wisent_cli/main.py +93 -0
  49. wisent/cli/wisent_cli/shell.py +80 -0
  50. wisent/cli/wisent_cli/ui.py +69 -0
  51. wisent/cli/wisent_cli/util/__init__.py +0 -0
  52. wisent/cli/wisent_cli/util/aggregations.py +43 -0
  53. wisent/cli/wisent_cli/util/parsing.py +126 -0
  54. wisent/cli/wisent_cli/version.py +4 -0
  55. wisent/core/__init__.py +27 -0
  56. wisent/core/activations/__init__.py +0 -0
  57. wisent/core/activations/activations_collector.py +338 -0
  58. wisent/core/activations/core/__init__.py +0 -0
  59. wisent/core/activations/core/atoms.py +216 -0
  60. wisent/core/agent/__init__.py +18 -0
  61. wisent/core/agent/budget.py +638 -0
  62. wisent/core/agent/device_benchmarks.py +685 -0
  63. wisent/core/agent/diagnose/__init__.py +55 -0
  64. wisent/core/agent/diagnose/agent_classifier_decision.py +641 -0
  65. wisent/core/agent/diagnose/classifier_marketplace.py +554 -0
  66. wisent/core/agent/diagnose/create_classifier.py +1154 -0
  67. wisent/core/agent/diagnose/response_diagnostics.py +268 -0
  68. wisent/core/agent/diagnose/select_classifiers.py +506 -0
  69. wisent/core/agent/diagnose/synthetic_classifier_option.py +754 -0
  70. wisent/core/agent/diagnose/tasks/__init__.py +33 -0
  71. wisent/core/agent/diagnose/tasks/task_manager.py +1456 -0
  72. wisent/core/agent/diagnose/tasks/task_relevance.py +94 -0
  73. wisent/core/agent/diagnose/tasks/task_selector.py +151 -0
  74. wisent/core/agent/diagnose/test_synthetic_classifier.py +71 -0
  75. wisent/core/agent/diagnose.py +242 -0
  76. wisent/core/agent/steer.py +212 -0
  77. wisent/core/agent/timeout.py +134 -0
  78. wisent/core/autonomous_agent.py +1234 -0
  79. wisent/core/bigcode_integration.py +583 -0
  80. wisent/core/contrastive_pairs/__init__.py +15 -0
  81. wisent/core/contrastive_pairs/core/__init__.py +0 -0
  82. wisent/core/contrastive_pairs/core/atoms.py +45 -0
  83. wisent/core/contrastive_pairs/core/buliders.py +59 -0
  84. wisent/core/contrastive_pairs/core/pair.py +178 -0
  85. wisent/core/contrastive_pairs/core/response.py +152 -0
  86. wisent/core/contrastive_pairs/core/serialization.py +300 -0
  87. wisent/core/contrastive_pairs/core/set.py +133 -0
  88. wisent/core/contrastive_pairs/diagnostics/__init__.py +45 -0
  89. wisent/core/contrastive_pairs/diagnostics/activations.py +53 -0
  90. wisent/core/contrastive_pairs/diagnostics/base.py +73 -0
  91. wisent/core/contrastive_pairs/diagnostics/control_vectors.py +169 -0
  92. wisent/core/contrastive_pairs/diagnostics/coverage.py +79 -0
  93. wisent/core/contrastive_pairs/diagnostics/divergence.py +98 -0
  94. wisent/core/contrastive_pairs/diagnostics/duplicates.py +116 -0
  95. wisent/core/contrastive_pairs/lm_eval_pairs/__init__.py +0 -0
  96. wisent/core/contrastive_pairs/lm_eval_pairs/atoms.py +238 -0
  97. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_manifest.py +8 -0
  98. wisent/core/contrastive_pairs/lm_eval_pairs/lm_extractor_registry.py +132 -0
  99. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/__init__.py +0 -0
  100. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_extractors/winogrande.py +115 -0
  101. wisent/core/contrastive_pairs/lm_eval_pairs/lm_task_pairs_generation.py +50 -0
  102. wisent/core/data_loaders/__init__.py +0 -0
  103. wisent/core/data_loaders/core/__init__.py +0 -0
  104. wisent/core/data_loaders/core/atoms.py +98 -0
  105. wisent/core/data_loaders/loaders/__init__.py +0 -0
  106. wisent/core/data_loaders/loaders/custom.py +120 -0
  107. wisent/core/data_loaders/loaders/lm_loader.py +218 -0
  108. wisent/core/detection_handling.py +257 -0
  109. wisent/core/download_full_benchmarks.py +1386 -0
  110. wisent/core/evaluators/__init__.py +0 -0
  111. wisent/core/evaluators/oracles/__init__.py +0 -0
  112. wisent/core/evaluators/oracles/interactive.py +73 -0
  113. wisent/core/evaluators/oracles/nlp_evaluator.py +440 -0
  114. wisent/core/evaluators/oracles/user_specified.py +67 -0
  115. wisent/core/hyperparameter_optimizer.py +429 -0
  116. wisent/core/lm_eval_harness_ground_truth.py +1396 -0
  117. wisent/core/log_likelihoods_evaluator.py +321 -0
  118. wisent/core/managed_cached_benchmarks.py +595 -0
  119. wisent/core/mixed_benchmark_sampler.py +364 -0
  120. wisent/core/model_config_manager.py +330 -0
  121. wisent/core/model_persistence.py +317 -0
  122. wisent/core/models/__init__.py +0 -0
  123. wisent/core/models/core/__init__.py +0 -0
  124. wisent/core/models/core/atoms.py +460 -0
  125. wisent/core/models/wisent_model.py +727 -0
  126. wisent/core/multi_steering.py +316 -0
  127. wisent/core/optuna/__init__.py +57 -0
  128. wisent/core/optuna/classifier/__init__.py +25 -0
  129. wisent/core/optuna/classifier/activation_generator.py +349 -0
  130. wisent/core/optuna/classifier/classifier_cache.py +509 -0
  131. wisent/core/optuna/classifier/optuna_classifier_optimizer.py +606 -0
  132. wisent/core/optuna/steering/__init__.py +0 -0
  133. wisent/core/optuna/steering/bigcode_evaluator_wrapper.py +188 -0
  134. wisent/core/optuna/steering/data_utils.py +342 -0
  135. wisent/core/optuna/steering/metrics.py +474 -0
  136. wisent/core/optuna/steering/optuna_pipeline.py +1738 -0
  137. wisent/core/optuna/steering/steering_optimization.py +1111 -0
  138. wisent/core/parser.py +1668 -0
  139. wisent/core/prompts/__init__.py +0 -0
  140. wisent/core/prompts/core/__init__.py +0 -0
  141. wisent/core/prompts/core/atom.py +57 -0
  142. wisent/core/prompts/core/prompt_formater.py +157 -0
  143. wisent/core/prompts/prompt_stratiegies/__init__.py +0 -0
  144. wisent/core/prompts/prompt_stratiegies/direct_completion.py +24 -0
  145. wisent/core/prompts/prompt_stratiegies/instruction_following.py +24 -0
  146. wisent/core/prompts/prompt_stratiegies/multiple_choice.py +29 -0
  147. wisent/core/prompts/prompt_stratiegies/role_playing.py +31 -0
  148. wisent/core/representation.py +5 -0
  149. wisent/core/sample_size_optimizer.py +648 -0
  150. wisent/core/sample_size_optimizer_v2.py +355 -0
  151. wisent/core/save_results.py +277 -0
  152. wisent/core/steering.py +652 -0
  153. wisent/core/steering_method.py +26 -0
  154. wisent/core/steering_methods/__init__.py +0 -0
  155. wisent/core/steering_methods/core/__init__.py +0 -0
  156. wisent/core/steering_methods/core/atoms.py +153 -0
  157. wisent/core/steering_methods/methods/__init__.py +0 -0
  158. wisent/core/steering_methods/methods/caa.py +44 -0
  159. wisent/core/steering_optimizer.py +1297 -0
  160. wisent/core/task_interface.py +132 -0
  161. wisent/core/task_selector.py +189 -0
  162. wisent/core/tasks/__init__.py +175 -0
  163. wisent/core/tasks/aime_task.py +141 -0
  164. wisent/core/tasks/file_task.py +211 -0
  165. wisent/core/tasks/hle_task.py +180 -0
  166. wisent/core/tasks/hmmt_task.py +119 -0
  167. wisent/core/tasks/livecodebench_task.py +201 -0
  168. wisent/core/tasks/livemathbench_task.py +158 -0
  169. wisent/core/tasks/lm_eval_task.py +455 -0
  170. wisent/core/tasks/math500_task.py +84 -0
  171. wisent/core/tasks/polymath_task.py +146 -0
  172. wisent/core/tasks/supergpqa_task.py +220 -0
  173. wisent/core/time_estimator.py +149 -0
  174. wisent/core/timing_calibration.py +174 -0
  175. wisent/core/tracking/__init__.py +54 -0
  176. wisent/core/tracking/latency.py +618 -0
  177. wisent/core/tracking/memory.py +359 -0
  178. wisent/core/trainers/__init__.py +0 -0
  179. wisent/core/trainers/core/__init__.py +11 -0
  180. wisent/core/trainers/core/atoms.py +45 -0
  181. wisent/core/trainers/steering_trainer.py +271 -0
  182. wisent/core/user_model_config.py +158 -0
  183. wisent/opti/__init__.py +0 -0
  184. wisent/opti/core/__init__.py +0 -0
  185. wisent/opti/core/atoms.py +175 -0
  186. wisent/opti/methods/__init__.py +0 -0
  187. wisent/opti/methods/opti_classificator.py +172 -0
  188. wisent/opti/methods/opti_steering.py +138 -0
  189. wisent/synthetic/__init__.py +0 -0
  190. wisent/synthetic/cleaners/__init__.py +0 -0
  191. wisent/synthetic/cleaners/core/__init__.py +0 -0
  192. wisent/synthetic/cleaners/core/atoms.py +58 -0
  193. wisent/synthetic/cleaners/deduper_cleaner.py +53 -0
  194. wisent/synthetic/cleaners/methods/__init__.py +0 -0
  195. wisent/synthetic/cleaners/methods/base_dedupers.py +320 -0
  196. wisent/synthetic/cleaners/methods/base_refusalers.py +286 -0
  197. wisent/synthetic/cleaners/methods/core/__init__.py +0 -0
  198. wisent/synthetic/cleaners/methods/core/atoms.py +47 -0
  199. wisent/synthetic/cleaners/pairs_cleaner.py +90 -0
  200. wisent/synthetic/cleaners/refusaler_cleaner.py +133 -0
  201. wisent/synthetic/db_instructions/__init__.py +0 -0
  202. wisent/synthetic/db_instructions/core/__init__.py +0 -0
  203. wisent/synthetic/db_instructions/core/atoms.py +25 -0
  204. wisent/synthetic/db_instructions/mini_dp.py +37 -0
  205. wisent/synthetic/generators/__init__.py +0 -0
  206. wisent/synthetic/generators/core/__init__.py +0 -0
  207. wisent/synthetic/generators/core/atoms.py +73 -0
  208. wisent/synthetic/generators/diversities/__init__.py +0 -0
  209. wisent/synthetic/generators/diversities/core/__init__.py +0 -0
  210. wisent/synthetic/generators/diversities/core/core.py +68 -0
  211. wisent/synthetic/generators/diversities/methods/__init__.py +0 -0
  212. wisent/synthetic/generators/diversities/methods/fast_diversity.py +249 -0
  213. wisent/synthetic/generators/pairs_generator.py +179 -0
  214. wisent-0.5.2.dist-info/METADATA +67 -0
  215. wisent-0.5.2.dist-info/RECORD +218 -0
  216. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/WHEEL +1 -1
  217. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info/licenses}/LICENSE +2 -2
  218. wisent/activations/__init__.py +0 -9
  219. wisent/activations/client.py +0 -97
  220. wisent/activations/extractor.py +0 -251
  221. wisent/activations/models.py +0 -95
  222. wisent/client.py +0 -45
  223. wisent/control_vector/__init__.py +0 -9
  224. wisent/control_vector/client.py +0 -85
  225. wisent/control_vector/manager.py +0 -168
  226. wisent/control_vector/models.py +0 -70
  227. wisent/inference/__init__.py +0 -9
  228. wisent/inference/client.py +0 -103
  229. wisent/inference/inferencer.py +0 -250
  230. wisent/inference/models.py +0 -66
  231. wisent/utils/__init__.py +0 -3
  232. wisent/utils/auth.py +0 -30
  233. wisent/utils/http.py +0 -228
  234. wisent/version.py +0 -3
  235. wisent-0.1.1.dist-info/METADATA +0 -142
  236. wisent-0.1.1.dist-info/RECORD +0 -23
  237. {wisent-0.1.1.dist-info → wisent-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,359 @@
1
+ """
2
+ Memory usage tracking for wisent-guard operations.
3
+
4
+ This module provides comprehensive memory monitoring capabilities including
5
+ GPU and CPU memory tracking, peak usage detection, and memory profiling.
6
+ """
7
+
8
+ import gc
9
+ import psutil
10
+ import time
11
+ import threading
12
+ from typing import Dict, List, Optional, Any, Callable
13
+ from dataclasses import dataclass, field
14
+ from contextlib import contextmanager
15
+ import torch
16
+
17
+ from wisent.core.utils.device import resolve_default_device
18
+
19
# Optional NVML bindings, used only to report GPU memory as a percentage of
# the device total. The PyPI distributions "nvidia-ml-py" / "nvidia-ml-py3"
# install a module named ``pynvml``; the legacy ``nvidia_ml_py3`` name is
# still tried afterwards so any environment that provides it keeps working.
try:
    import pynvml as nvml
    NVML_AVAILABLE = True
except ImportError:
    try:
        import nvidia_ml_py3 as nvml
        NVML_AVAILABLE = True
    except ImportError:
        # Keep the name bound so later references never raise NameError.
        nvml = None
        NVML_AVAILABLE = False
24
+
25
+
26
@dataclass
class MemorySnapshot:
    """
    Single reading of process memory usage.

    Only the three CPU-side fields are mandatory; every accelerator-related
    field defaults to ``None`` so a snapshot can be built on machines where
    no GPU figures are collected.
    """
    # --- always present -------------------------------------------------
    timestamp: float            # wall-clock time the reading was taken
    cpu_memory_mb: float        # process memory, in megabytes
    cpu_memory_percent: float   # process memory as a percentage
    # --- accelerator figures (None when not collected) ------------------
    gpu_memory_mb: Optional[float] = None
    gpu_memory_percent: Optional[float] = None
    allocated_tensors: Optional[int] = None
    cached_memory_mb: Optional[float] = None
    # --- free-form label for the code that was running -------------------
    operation: Optional[str] = None
37
+
38
+
39
@dataclass
class MemoryStats:
    """
    Summary of memory usage across a series of snapshots.

    The GPU figures are ``Optional`` because a run may contain no snapshot
    with GPU data. ``snapshots`` and ``operations`` default to fresh empty
    lists for every instance.
    """
    # Peak / mean / minimum usage in megabytes, CPU and GPU interleaved.
    peak_cpu_mb: float
    peak_gpu_mb: Optional[float]
    avg_cpu_mb: float
    avg_gpu_mb: Optional[float]
    min_cpu_mb: float
    min_gpu_mb: Optional[float]
    # Time span covered by the summarised snapshots.
    duration_seconds: float
    # Raw readings and the operation labels attached to them.
    snapshots: List[MemorySnapshot] = field(default_factory=list)
    operations: List[str] = field(default_factory=list)
51
+
52
+
53
class MemoryTracker:
    """
    Track CPU and accelerator (CUDA / MPS) memory usage of the current process.

    Readings can be taken on demand (``take_snapshot``), around a block of code
    (``track_operation``) or continuously from a background thread
    (``start_monitoring`` / ``stop_monitoring``). Stored readings live in
    ``self.snapshots`` and are summarised by ``get_stats``.
    """

    def __init__(
        self,
        track_gpu: bool = True,
        sampling_interval: float = 0.1,
        auto_cleanup: bool = True
    ):
        """
        Initialize memory tracker.

        Args:
            track_gpu: Whether to track accelerator memory. Only honoured when
                the resolved default device is "cuda" or "mps".
            sampling_interval: Pause between two samples of the background
                monitor, in seconds.
            auto_cleanup: Whether ``clear_cache`` (and therefore
                ``track_operation``) runs garbage collection and empties the
                device cache.
        """
        self.device_kind = resolve_default_device()
        self.track_gpu = track_gpu and self.device_kind in {"cuda", "mps"}
        self.sampling_interval = sampling_interval
        self.auto_cleanup = auto_cleanup

        self.snapshots: List[MemorySnapshot] = []
        self.is_monitoring = False
        self.monitor_thread: Optional[threading.Thread] = None
        self.start_time: Optional[float] = None

        # NVML is only used to express CUDA usage as a percentage of the
        # device total; failing to initialise it just disables that figure.
        if self.track_gpu and self.device_kind == "cuda" and NVML_AVAILABLE:
            try:
                nvml.nvmlInit()
                self.gpu_handle = nvml.nvmlDeviceGetHandleByIndex(0)
                self.gpu_available = True
            except Exception:
                self.gpu_available = False
                self.gpu_handle = None
        else:
            self.gpu_handle = None
            self.gpu_available = False

    def _measure(self, operation: Optional[str] = None) -> MemorySnapshot:
        """Build a snapshot of current usage WITHOUT storing it in ``self.snapshots``."""
        timestamp = time.time()

        # CPU memory of this process
        process = psutil.Process()
        memory_info = process.memory_info()
        cpu_memory_mb = memory_info.rss / 1024 / 1024
        cpu_memory_percent = process.memory_percent()

        # Accelerator figures stay None unless tracking is enabled below
        gpu_memory_mb = None
        gpu_memory_percent = None
        allocated_tensors = None
        cached_memory_mb = None

        if self.track_gpu:
            if self.device_kind == "cuda" and torch.cuda.is_available():
                gpu_memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
                cached_memory_mb = torch.cuda.memory_reserved() / 1024 / 1024
                # Count live CUDA tensors by walking every gc-tracked object.
                allocated_tensors = len(
                    [
                        obj
                        for obj in gc.get_objects()
                        if torch.is_tensor(obj) and getattr(obj, "is_cuda", False)
                    ]
                )

                if self.gpu_available and self.gpu_handle is not None:
                    try:
                        gpu_info = nvml.nvmlDeviceGetMemoryInfo(self.gpu_handle)
                        total_gpu_mb = gpu_info.total / 1024 / 1024
                        gpu_memory_percent = (gpu_memory_mb / total_gpu_mb) * 100
                    except Exception:
                        # Best effort: the percentage simply stays None.
                        pass
            elif self.device_kind == "mps" and hasattr(torch, "mps"):
                # Fall back when this torch build lacks the MPS memory queries.
                try:
                    allocated_bytes = torch.mps.current_allocated_memory()
                except AttributeError:
                    allocated_bytes = 0

                try:
                    cached_bytes = torch.mps.driver_allocated_memory()
                except AttributeError:
                    cached_bytes = allocated_bytes

                gpu_memory_mb = allocated_bytes / 1024 / 1024
                cached_memory_mb = cached_bytes / 1024 / 1024
                allocated_tensors = len(
                    [
                        obj
                        for obj in gc.get_objects()
                        if torch.is_tensor(obj) and getattr(getattr(obj, "device", None), "type", None) == "mps"
                    ]
                )

        return MemorySnapshot(
            timestamp=timestamp,
            cpu_memory_mb=cpu_memory_mb,
            cpu_memory_percent=cpu_memory_percent,
            gpu_memory_mb=gpu_memory_mb,
            gpu_memory_percent=gpu_memory_percent,
            allocated_tensors=allocated_tensors,
            cached_memory_mb=cached_memory_mb,
            operation=operation
        )

    def take_snapshot(self, operation: Optional[str] = None) -> MemorySnapshot:
        """
        Take a single memory snapshot and append it to ``self.snapshots``.

        Args:
            operation: Optional label stored on the snapshot.

        Returns:
            The snapshot that was just recorded.
        """
        snapshot = self._measure(operation)
        self.snapshots.append(snapshot)
        return snapshot

    def start_monitoring(self) -> None:
        """
        Start continuous memory monitoring in a background thread.

        Does nothing if monitoring is already active. Previously collected
        snapshots are discarded when monitoring starts.
        """
        if self.is_monitoring:
            return

        self.is_monitoring = True
        self.start_time = time.time()
        self.snapshots.clear()

        def monitor_loop():
            # Runs until stop_monitoring()/reset() flips the flag.
            while self.is_monitoring:
                self.take_snapshot("continuous_monitoring")
                time.sleep(self.sampling_interval)

        self.monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
        self.monitor_thread.start()

    def _halt_monitor_thread(self) -> None:
        """Signal the sampling thread to stop and wait for it to finish."""
        self.is_monitoring = False
        if self.monitor_thread:
            self.monitor_thread.join()

    def stop_monitoring(self) -> MemoryStats:
        """
        Stop continuous monitoring and return aggregated statistics.

        Raises:
            ValueError: If monitoring is not active, or (via ``get_stats``) if
                no snapshot was collected before monitoring stopped.
        """
        if not self.is_monitoring:
            raise ValueError("Monitoring is not active")

        self._halt_monitor_thread()
        return self.get_stats()

    def get_stats(self) -> MemoryStats:
        """
        Get aggregated memory statistics from collected snapshots.

        Raises:
            ValueError: If no snapshots have been collected.
        """
        # Work on a private copy so a concurrently running monitor thread
        # cannot change the list between the individual aggregations.
        snaps = list(self.snapshots)
        if not snaps:
            raise ValueError("No snapshots available")

        cpu_values = [s.cpu_memory_mb for s in snaps]
        gpu_values = [s.gpu_memory_mb for s in snaps if s.gpu_memory_mb is not None]

        duration = snaps[-1].timestamp - snaps[0].timestamp
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is deterministic (a plain set would not be).
        operations = list(dict.fromkeys(s.operation for s in snaps if s.operation))

        return MemoryStats(
            peak_cpu_mb=max(cpu_values),
            peak_gpu_mb=max(gpu_values) if gpu_values else None,
            avg_cpu_mb=sum(cpu_values) / len(cpu_values),
            avg_gpu_mb=sum(gpu_values) / len(gpu_values) if gpu_values else None,
            min_cpu_mb=min(cpu_values),
            min_gpu_mb=min(gpu_values) if gpu_values else None,
            duration_seconds=duration,
            snapshots=snaps,
            operations=operations
        )

    def clear_cache(self) -> None:
        """
        Run garbage collection and empty the device cache.

        This is a no-op unless the tracker was created with ``auto_cleanup=True``.
        """
        if self.auto_cleanup:
            gc.collect()
            if self.device_kind == "cuda" and torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif self.device_kind == "mps" and hasattr(torch, "mps"):
                try:
                    torch.mps.empty_cache()
                except AttributeError:
                    # This torch build has no torch.mps.empty_cache.
                    pass

    def reset(self) -> None:
        """
        Reset the tracker, clearing all snapshots.

        An active monitor thread is stopped first. No statistics are computed
        here, so resetting never fails just because nothing was sampled yet.
        """
        if self.is_monitoring:
            self._halt_monitor_thread()
        self.snapshots.clear()
        self.start_time = None

    @contextmanager
    def track_operation(self, operation_name: str):
        """
        Context manager to track memory usage during a specific operation.

        Records a ``"<operation_name>_start"`` snapshot on entry and a
        ``"<operation_name>_end"`` snapshot on exit (also when the body
        raises), then clears caches if ``auto_cleanup`` is enabled.

        Yields:
            This tracker.
        """
        self.take_snapshot(f"{operation_name}_start")

        try:
            yield self
        finally:
            self.take_snapshot(f"{operation_name}_end")

            if self.auto_cleanup:
                self.clear_cache()

    def get_current_usage(self) -> Dict[str, Any]:
        """
        Get current memory usage without storing a snapshot.

        Returns:
            A dict with ``cpu_memory_mb`` and ``cpu_memory_percent``; when GPU
            usage was measured it also holds ``gpu_memory_mb``,
            ``gpu_memory_percent``, ``allocated_tensors`` and
            ``cached_memory_mb``.
        """
        # Measure directly instead of append-then-pop on the shared list: with
        # the monitor thread running, pop() could remove a different snapshot.
        snapshot = self._measure("current_check")

        usage = {
            "cpu_memory_mb": snapshot.cpu_memory_mb,
            "cpu_memory_percent": snapshot.cpu_memory_percent,
        }

        if snapshot.gpu_memory_mb is not None:
            usage.update({
                "gpu_memory_mb": snapshot.gpu_memory_mb,
                "gpu_memory_percent": snapshot.gpu_memory_percent,
                "allocated_tensors": snapshot.allocated_tensors,
                "cached_memory_mb": snapshot.cached_memory_mb,
            })

        return usage

    def format_stats(self, stats: MemoryStats, detailed: bool = False) -> str:
        """
        Format memory statistics as a readable string.

        Args:
            stats: Statistics to render.
            detailed: When true (and snapshots are present) also report the
                snapshot count and the snapshot with the highest CPU usage.

        Returns:
            The report, one entry per line.
        """
        lines = [
            "Memory Usage Statistics:",
            f" Duration: {stats.duration_seconds:.2f} seconds",
            " CPU Memory:",
            f" Peak: {stats.peak_cpu_mb:.1f} MB",
            f" Average: {stats.avg_cpu_mb:.1f} MB",
            f" Minimum: {stats.min_cpu_mb:.1f} MB",
        ]

        if stats.peak_gpu_mb is not None:
            lines.extend([
                " GPU Memory:",
                f" Peak: {stats.peak_gpu_mb:.1f} MB",
                f" Average: {stats.avg_gpu_mb:.1f} MB",
                f" Minimum: {stats.min_gpu_mb:.1f} MB",
            ])

        if stats.operations:
            lines.append(f" Operations: {', '.join(stats.operations)}")

        if detailed and stats.snapshots:
            lines.append(f" Snapshots: {len(stats.snapshots)} collected")

            # Show peak usage snapshot (peak is judged by CPU memory)
            peak_snapshot = max(stats.snapshots, key=lambda s: s.cpu_memory_mb)
            lines.extend([
                " Peak Usage Snapshot:",
                f" Time: {peak_snapshot.timestamp:.2f}",
                f" CPU: {peak_snapshot.cpu_memory_mb:.1f} MB ({peak_snapshot.cpu_memory_percent:.1f}%)",
            ])

            if peak_snapshot.gpu_memory_mb is not None:
                lines.append(f" GPU: {peak_snapshot.gpu_memory_mb:.1f} MB")
                if peak_snapshot.allocated_tensors is not None:
                    lines.append(f" Tensors: {peak_snapshot.allocated_tensors}")

        return "\n".join(lines)
311
+
312
+
313
# Process-wide tracker, created lazily on first request.
_global_tracker: Optional[MemoryTracker] = None


def get_global_tracker() -> MemoryTracker:
    """Return the shared MemoryTracker, building it with defaults on first use."""
    global _global_tracker
    if _global_tracker is not None:
        return _global_tracker
    _global_tracker = MemoryTracker()
    return _global_tracker
323
+
324
+
325
def track_memory(operation_name: str):
    """
    Decorator to track memory usage of a function.

    Each call of the decorated function runs inside
    ``get_global_tracker().track_operation(operation_name)``.

    Args:
        operation_name: Label passed to ``track_operation``.

    Returns:
        A decorator that wraps the function and forwards its arguments and
        return value unchanged.
    """
    # Imported locally so the block does not depend on a new file-level import.
    import functools

    def decorator(func: Callable) -> Callable:
        # wraps() keeps the original __name__, __doc__ and signature metadata
        # visible on the wrapper (important for debugging and introspection).
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            tracker = get_global_tracker()
            with tracker.track_operation(operation_name):
                return func(*args, **kwargs)
        return wrapper
    return decorator
334
+
335
+
336
def get_memory_info() -> Dict[str, Any]:
    """Return a one-off reading of current memory usage as a dictionary."""
    # A throwaway tracker with automatic cleanup switched off is used for the
    # single reading; nothing is kept afterwards.
    return MemoryTracker(auto_cleanup=False).get_current_usage()
340
+
341
+
342
def format_memory_usage(usage: Dict[str, Any]) -> str:
    """
    Format memory usage dictionary as a readable string.

    Args:
        usage: Mapping that must contain ``cpu_memory_mb`` and
            ``cpu_memory_percent``. The optional keys ``gpu_memory_mb``,
            ``gpu_memory_percent``, ``cached_memory_mb`` and
            ``allocated_tensors`` are rendered only when present and not None.

    Returns:
        The individual parts joined with ``" | "``.
    """
    lines = [
        f"CPU Memory: {usage['cpu_memory_mb']:.1f} MB ({usage['cpu_memory_percent']:.1f}%)"
    ]

    gpu_mb = usage.get('gpu_memory_mb')
    if gpu_mb is not None:
        gpu_part = f"GPU Memory: {gpu_mb:.1f} MB"
        gpu_percent = usage.get('gpu_memory_percent')
        if gpu_percent is not None:
            gpu_part += f" ({gpu_percent:.1f}%)"
        lines.append(gpu_part)

    # A key holding None means "not measured"; formatting None with ':.1f'
    # would raise TypeError, so such entries are skipped like the GPU ones.
    cached_mb = usage.get('cached_memory_mb')
    if cached_mb is not None:
        lines.append(f"GPU Cached: {cached_mb:.1f} MB")

    tensor_count = usage.get('allocated_tensors')
    if tensor_count is not None:
        lines.append(f"GPU Tensors: {tensor_count}")

    return " | ".join(lines)
File without changes
@@ -0,0 +1,11 @@
1
+ #1 class WisentSteeringTrainer:
2
+ #2 trainer should load activation collector (see provided file).
3
+ #3 should load contrastive pair set (see provided file)
4
+ #4 should decide which type of steering training method to choose: caa, bipo, etc.
5
+ #4 should be able to pick from which layers we collect activations and then, for each layer's activations, use the steering method to obtain a steering vector.
6
+ #5 some methods use activations from many layers, some from only one. we need to be able to specify that, e.g. the user can say use layers 10, 20, 30 or use all layers from 10 to 30.
7
+ #6 after training the user needs to obtain the contrastive pair set with collected activations (see provided file) and the steered vectors, which need to be a LayerActivations class.
8
+ #7 we should save all the trained steered vectors, together with the contrastive pairs with activations, and metadata like date, model name, layers used, method used, hyperparams used etc.
9
+
10
+ # Important info: we also need to specify the activation collection strategy (see LayerActivations). All provided files have good descriptions/docstrings. Please write code with that in mind. Create two files: atoms.py
11
+ # where we define all base structures for the trainers, all abstract classes etc., and steering_trainer.py where we implement the WisentSteeringTrainer class.
@@ -0,0 +1,45 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict
4
+
5
+ from wisent.core.activations.core.atoms import LayerActivations
6
+ from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
7
+
8
# Public API of this module: the result container and the abstract trainer
# interface that concrete trainers subclass.
__all__ = [
    "TrainingResult",
    "BaseSteeringTrainer",
]
11
+
12
+ @dataclass(slots=True)
13
+ class TrainingResult:
14
+ """
15
+ Container returned by a trainer after running the full pipeline.
16
+
17
+ attributes:
18
+ steered_vectors:
19
+ Per-layer steering vectors in a LayerActivations mapping. Each value
20
+ is typically a 1D tensor of shape [H].
21
+ pair_set_with_activations:
22
+ The original ContrastivePairSet, but with per-pair, per-layer activations
23
+ collected and stored in the Positive/NegativeResponse objects.
24
+ metadata:
25
+ A JSON-serializable dictionary with run metadata
26
+ (date, model_name, layers, method, hyperparams, aggregation, etc.).
27
+ """
28
+ steered_vectors: LayerActivations
29
+ pair_set_with_activations: ContrastivePairSet
30
+ metadata: Dict[str, Any]
31
+
32
class BaseSteeringTrainer(ABC):
    """
    Contract for steering trainers.

    A concrete trainer is expected to:
        1) collect activations for a set of contrastive pairs,
        2) train one or more steering vectors with the chosen method,
        3) return a TrainingResult and, optionally, save artifacts.
    """

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> TrainingResult:
        """
        Run the whole pipeline end to end and return its TrainingResult.

        Subclasses must override this method; the accepted arguments are
        left to the concrete trainer.
        """
        ...
@@ -0,0 +1,271 @@
1
+ from __future__ import annotations
2
+ import logging
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Any, Sequence
6
+
7
+ import json
8
+ import torch
9
+ import datetime as _dt
10
+
11
+ from wisent.core.activations.core.atoms import (
12
+ LayerActivations,
13
+ ActivationAggregationStrategy,
14
+ RawActivationMap,
15
+ )
16
+ from wisent.core.models.wisent_model import WisentModel
17
+
18
+ from wisent.core.trainers.core.atoms import (
19
+ TrainingResult,
20
+ BaseSteeringTrainer
21
+ )
22
+
23
+ from wisent.core.contrastive_pairs.core.set import ContrastivePairSet
24
+ from wisent.core.activations.activations_collector import ActivationCollector
25
+ from wisent.core.steering_methods.core.atoms import BaseSteeringMethod
26
+ from wisent.core.contrastive_pairs.diagnostics import run_control_vector_diagnostics
27
+
28
# Public API of this module.
__all__ = ["WisentSteeringTrainer"]

# Module-level logger; the trainer reports control-vector diagnostics through it.
logger = logging.getLogger(__name__)
34
+
35
@dataclass(slots=True)
class WisentSteeringTrainer(BaseSteeringTrainer):
    """
    Orchestrates activation collection + steering vector training for a given model and pair set.

    Minimal usage:
        trainer = WisentSteeringTrainer(model, pair_set, steering_method)
        result = trainer.run(layers_spec=..., method_kwargs=..., aggregation=..., ...)
        # result is a TrainingResult with steered vectors, enriched pair set, and metadata
        trainer.save_result(output_dir)  # optional save

    arguments:
        model: WisentModel to use for activation collection.
        pair_set: ContrastivePairSet with pairs to use for collection and training.
            Note that run() replaces its pairs in place with the activation-enriched ones.
        steering_method: BaseSteeringMethod instance to use for training.
        store_device: Device to store collected activations on (default "cpu").
        dtype: Optional torch.dtype to cast collected activations to (default None, meaning no cast).
    """

    model: WisentModel
    pair_set: ContrastivePairSet
    steering_method: BaseSteeringMethod
    store_device: str | torch.device = "cpu"
    dtype: torch.dtype | None = None

    def __post_init__(self) -> None:
        # 'collector' and '_last_result' are not dataclass fields, so slots=True creates
        # no slot for them. The assignments work because BaseSteeringTrainer declares no
        # __slots__ and therefore gives every instance a __dict__.
        self.collector = ActivationCollector(model=self.model, store_device=self.store_device, dtype=self.dtype)
        self._last_result: TrainingResult | None = None

    def run(
        self,
        layers_spec: Sequence[str] | str | int | Sequence[int] | None,
        method_kwargs: dict[str, Any] | None = None,
        aggregation: ActivationAggregationStrategy = ActivationAggregationStrategy.CONTINUATION_TOKEN,
        return_full_sequence: bool = False,
        normalize_layers: bool = False,
        save_dir: str | Path | None = None,
    ) -> TrainingResult:
        """
        Full pipeline:
            1) Decide which layers to use (from spec or all layers if None).
            2) Collect activations for each pair at these layers.
            3) Train steering vectors using the configured steering method.
            4) Run control-vector diagnostics on the trained vectors.
            5) Return a TrainingResult with vectors, enriched pair set, and metadata.
            6) Optionally save artifacts to disk.

        arguments:
            layers_spec:
                - list like ["10","20","30"] or [10, 20, 30]
                - range string "10-30" / "10..30"
                - comma-separated string "3,7,10..12"
                - single layer, 12 or "12"
                - None → use all available layers on the model
            method_kwargs:
                Dict of hyperparameters forwarded to the steering method's train()
                (e.g., {"normalize": True, "scale": 1.0}).
            aggregation:
                ActivationAggregationStrategy to use during collection when not returning
                full sequences. Ignored if 'return_full_sequence=True'.
            return_full_sequence:
                If True, store full [T,H] sequences per layer (method then must know how
                to collapse to vectors). Default False (collect [H] vectors directly).
            normalize_layers:
                If True, L2-normalize activations layer-wise during collection.
            save_dir:
                If provided, artifacts are written there. Directory is created if missing.

        returns:
            TrainingResult

        raises:
            ValueError: if the layer spec is malformed, or if the control-vector
                diagnostics report critical issues (details are logged).
        """
        # Copy so that the recorded metadata cannot change if the caller later
        # mutates the dict they passed in.
        method_kwargs = dict(method_kwargs or {})

        # 1) Resolve layer names
        layers = self._resolve_layers(layers_spec)

        # 2) Collect activations for each pair (the pair set is updated in place)
        for i, pair in enumerate(self.pair_set.pairs):
            self.pair_set.pairs[i] = self.collector.collect_for_pair(
                pair,
                layers=layers,
                aggregation=aggregation,
                return_full_sequence=return_full_sequence,
                normalize_layers=normalize_layers,
            )

        # 3) Train using the configured method
        raw_vectors: RawActivationMap = self.steering_method.train(self.pair_set, **method_kwargs)
        steered = LayerActivations(raw_vectors)

        # 4) Diagnostics: log every issue, then abort on critical ones
        control_vector_report = run_control_vector_diagnostics(steered)
        for issue in control_vector_report.issues:
            log_method = logger.error if issue.severity == "critical" else logger.warning
            log_method(
                "[control_vector diagnostics] %s (details=%s)",
                issue.message,
                issue.details,
            )

        control_vector_summary = control_vector_report.summary.get("control_vectors", {})
        control_vector_issues = [
            {
                "metric": issue.metric,
                "severity": issue.severity,
                "message": issue.message,
                "details": issue.details,
            }
            for issue in control_vector_report.issues
        ]

        if control_vector_report.has_critical_issues:
            raise ValueError("Control vector diagnostics found critical issues; see logs for specifics.")

        # 5) Metadata
        now = _dt.datetime.now().astimezone()
        metadata: dict[str, Any] = {
            "timestamp": now.isoformat(),
            "model_name": getattr(self.model, "model_name", getattr(self.model, "name", None)),
            "layers_used": layers or "all",
            "method": self.steering_method.name,
            "method_kwargs": method_kwargs,
            "activation_aggregation_strategy": (None if return_full_sequence else aggregation),
            "return_full_sequence": bool(return_full_sequence),
            "normalize_layers": bool(normalize_layers),
            "num_pairs": len(self.pair_set.pairs),
            "hidden_size": getattr(self.model, "hidden_size", None),
            "control_vector_diagnostics": control_vector_summary,
        }

        if control_vector_issues:
            metadata["control_vector_issues"] = control_vector_issues

        result = TrainingResult(steered_vectors=steered, pair_set_with_activations=self.pair_set, metadata=metadata)
        self._last_result = result

        # 6) Optional save
        if save_dir is not None:
            self.save_result(save_dir, result)

        return result

    def save_result(self, output_dir: str | Path, result: TrainingResult | None = None) -> Path:
        """
        Persist vectors, metadata, and the pair set (with activations) to disk.

        Files written:
            - metadata.json                    (JSON)
            - steering_vectors.pt              (torch.save of dict[layer]->tensor on CPU)
            - pairs_with_activations.pt        (torch.save of the full ContrastivePairSet object)
            - steering_vectors_summary.json    (shapes/dtypes only, human-readable)

        arguments:
            output_dir: Target directory; created (with parents) if missing.
            result: Result to save. Defaults to the result of the last run().

        returns:
            Path to the created directory.

        raises:
            RuntimeError: if no result is given and run() has not produced one yet.
            TypeError: if the metadata holds a value that cannot be converted to JSON.
        """
        result = result or self._last_result
        if result is None:
            raise RuntimeError("No result to save. Run the trainer first.")

        # Vectors, moved to CPU; the aggregation-strategy marker entry is not a vector.
        raw_map: RawActivationMap = result.steered_vectors.to_dict()  # still tensors
        cpu_map = {
            k: (v.detach().to("cpu") if isinstance(v, torch.Tensor) else v)
            for k, v in raw_map.items()
            if k != "_activation_aggregation_strategy"
        }

        # Summary (json-serializable): shape/dtype for array-likes, None for missing layers.
        vec_summary: dict[Any, Any] = {}
        for k, v in cpu_map.items():
            if v is None:
                vec_summary[k] = None
            elif hasattr(v, "shape") and hasattr(v, "dtype"):
                vec_summary[k] = {"shape": tuple(v.shape), "dtype": str(v.dtype)}
            else:
                vec_summary[k] = {"type": type(v).__name__}

        # Serialize both JSON documents before touching the disk, so that a value that
        # cannot be encoded does not leave a half-written output directory behind.
        summary_json = json.dumps(vec_summary, indent=2, default=self._json_default)
        metadata_json = json.dumps(result.metadata, indent=2, default=self._json_default)

        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)

        torch.save(cpu_map, out / "steering_vectors.pt")
        (out / "steering_vectors_summary.json").write_text(summary_json, encoding="utf-8")
        (out / "metadata.json").write_text(metadata_json, encoding="utf-8")

        # Full pair set with activations (Python pickle via torch.save).
        # NOTE: being a pickle, this file must only be loaded from trusted sources.
        torch.save(result.pair_set_with_activations, out / "pairs_with_activations.pt")

        return out

    @staticmethod
    def _json_default(value: Any) -> Any:
        """
        Fallback encoder for json.dumps covering the non-JSON types run() may put in metadata.

        Enum members (e.g. the aggregation strategy) become their value, paths and
        torch dtypes/devices become strings, tensors become nested lists.

        raises:
            TypeError: for any other type, as the json 'default' protocol requires.
        """
        from enum import Enum  # local import: enum is not imported at file level

        if isinstance(value, Enum):
            return value.value
        if isinstance(value, Path):
            return str(value)
        if isinstance(value, torch.Tensor):
            return value.detach().to("cpu").tolist()
        if isinstance(value, (torch.dtype, torch.device)):
            return str(value)
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    def _resolve_layers(self, spec: Sequence[str] | str | int | Sequence[int] | None) -> list[str] | None:
        """
        Convert a user-facing spec into canonical layer names ("1","2",...).
        If None, return None (meaning: use all layers in the collector/model).

        String entries are handled the same way whether they are passed on their own
        or inside a sequence: spaces are dropped, commas separate tokens, and each
        token is a single layer or a range.

        arguments:
            spec: See 'layers_spec' argument in run().

        returns:
            De-duplicated list of layer names as strings in ascending numeric order, or None.

        raises:
            ValueError: if a token is not an integer or an integer range.

        examples:
            None -> None
            "10-12" -> ["10","11","12"]
            [5,10,15] -> ["5","10","15"]
            "3,7,10..12" -> ["3","7","10","11","12"]
            ["3, 7", "10..12"] -> ["3","7","10","11","12"]
            8 -> ["8"]
        """
        if spec is None:
            return None

        # A bare string or int is a one-item spec; any other iterable is walked item by item.
        items = [spec] if isinstance(spec, (str, int)) else list(spec)

        names: set[str] = set()
        for item in items:
            if isinstance(item, int):
                # int() canonicalizes int subclasses whose str() is not a plain number.
                names.add(str(int(item)))
                continue
            for token in str(item).replace(" ", "").split(","):
                names.update(self._parse_layer_token(token))

        # Every name is the str() of an int, so a numeric sort is always possible.
        return sorted(names, key=int)

    @staticmethod
    def _parse_layer_token(token: str) -> list[str]:
        """
        Parse a token like "5", "10-20", "10..20" into a list of names.

        Range bounds are inclusive and may be given in either order. An empty
        token yields an empty list.

        raises:
            ValueError: if the token is not an integer or a two-sided integer range.
        """
        token = token.strip()
        if not token:
            return []

        normalized = token.replace("..", "-")
        if "-" in normalized:
            start, _, end = normalized.partition("-")
            try:
                a_i, b_i = int(start), int(end)
            except ValueError:
                raise ValueError(
                    f"Invalid layer range {token!r}: expected 'A-B' or 'A..B' with integer bounds."
                ) from None
            lo, hi = min(a_i, b_i), max(a_i, b_i)
            return [str(i) for i in range(lo, hi + 1)]

        try:
            return [str(int(token))]
        except ValueError:
            raise ValueError(f"Invalid layer {token!r}: expected an integer.") from None