ai-metacognition-toolkit 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. ai_metacognition/__init__.py +123 -0
  2. ai_metacognition/analyzers/__init__.py +24 -0
  3. ai_metacognition/analyzers/base.py +39 -0
  4. ai_metacognition/analyzers/counterfactual_cot.py +579 -0
  5. ai_metacognition/analyzers/model_api.py +39 -0
  6. ai_metacognition/detectors/__init__.py +40 -0
  7. ai_metacognition/detectors/base.py +42 -0
  8. ai_metacognition/detectors/observer_effect.py +651 -0
  9. ai_metacognition/detectors/sandbagging_detector.py +1438 -0
  10. ai_metacognition/detectors/situational_awareness.py +526 -0
  11. ai_metacognition/integrations/__init__.py +16 -0
  12. ai_metacognition/integrations/anthropic_api.py +230 -0
  13. ai_metacognition/integrations/base.py +113 -0
  14. ai_metacognition/integrations/openai_api.py +300 -0
  15. ai_metacognition/probing/__init__.py +24 -0
  16. ai_metacognition/probing/extraction.py +176 -0
  17. ai_metacognition/probing/hooks.py +200 -0
  18. ai_metacognition/probing/probes.py +186 -0
  19. ai_metacognition/probing/vectors.py +133 -0
  20. ai_metacognition/utils/__init__.py +48 -0
  21. ai_metacognition/utils/feature_extraction.py +534 -0
  22. ai_metacognition/utils/statistical_tests.py +317 -0
  23. ai_metacognition/utils/text_processing.py +98 -0
  24. ai_metacognition/visualizations/__init__.py +22 -0
  25. ai_metacognition/visualizations/plotting.py +523 -0
  26. ai_metacognition_toolkit-0.3.0.dist-info/METADATA +621 -0
  27. ai_metacognition_toolkit-0.3.0.dist-info/RECORD +30 -0
  28. ai_metacognition_toolkit-0.3.0.dist-info/WHEEL +5 -0
  29. ai_metacognition_toolkit-0.3.0.dist-info/licenses/LICENSE +21 -0
  30. ai_metacognition_toolkit-0.3.0.dist-info/top_level.txt +1 -0
ai_metacognition/probing/extraction.py
@@ -0,0 +1,176 @@
+"""Contrastive Activation Addition (CAA) vector extraction.
+
+Implements the core algorithm for extracting behavioral directions
+from contrast pairs of prompts.
+
+Reference: https://arxiv.org/abs/2310.01405 (Steering Language Models)
+"""
+
+from typing import Dict, List, Literal, Tuple, Union
+
+import torch
+from tqdm import tqdm
+
+from .hooks import ActivationHook
+from .vectors import SteeringVector
+
+
+def extract_caa_vector(
+    model,
+    tokenizer,
+    contrast_pairs: List[Dict[str, str]],
+    layer_idx: int,
+    token_position: Literal["last", "first", "mean"] = "last",
+    behavior: str = "sandbagging",
+    show_progress: bool = True,
+) -> SteeringVector:
+    """Extract steering vector using Contrastive Activation Addition.
+
+    The core idea: compute mean(positive_acts) - mean(negative_acts)
+    to find the direction in activation space that corresponds to
+    the target behavior.
+
+    Args:
+        model: HuggingFace model
+        tokenizer: Corresponding tokenizer
+        contrast_pairs: List of dicts with "positive" and "negative" keys
+        layer_idx: Which layer to extract from
+        token_position: Which token position to use
+        behavior: Name of the behavior being extracted
+        show_progress: Show progress bar
+
+    Returns:
+        SteeringVector for the extracted direction
+    """
+    device = next(model.parameters()).device
+    model.eval()
+
+    positive_activations = []
+    negative_activations = []
+
+    iterator = tqdm(contrast_pairs, desc=f"Layer {layer_idx}", disable=not show_progress)
+
+    for pair in iterator:
+        pos_text = pair["positive"]
+        neg_text = pair["negative"]
+
+        # Extract positive activation
+        pos_act = _get_activation(
+            model, tokenizer, pos_text, layer_idx, token_position, device
+        )
+        positive_activations.append(pos_act)
+
+        # Extract negative activation
+        neg_act = _get_activation(
+            model, tokenizer, neg_text, layer_idx, token_position, device
+        )
+        negative_activations.append(neg_act)
+
+    # Compute mean activations
+    pos_mean = torch.stack(positive_activations).mean(dim=0)
+    neg_mean = torch.stack(negative_activations).mean(dim=0)
+
+    # NOTE: this is the core of CAA - surprisingly simple but it works
+    # see the original paper for theoretical justification
+    steering_vector = pos_mean - neg_mean
+
+    model_name = getattr(model.config, "_name_or_path", "unknown")
+
+    return SteeringVector(
+        behavior=behavior,
+        layer_index=layer_idx,
+        vector=steering_vector.cpu(),
+        model_name=model_name,
+        extraction_method="caa",
+        metadata={
+            "num_pairs": len(contrast_pairs),
+            "token_position": token_position,
+            "pos_mean_norm": pos_mean.norm().item(),
+            "neg_mean_norm": neg_mean.norm().item(),
+        },
+    )
+
+
+def _get_activation(
+    model,
+    tokenizer,
+    text: str,
+    layer_idx: int,
+    token_position: str,
+    device,
+) -> torch.Tensor:
+    """Get activation for a single text."""
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    hook = ActivationHook(model, [layer_idx], component="residual", token_position="all")
+
+    with hook:
+        with torch.no_grad():
+            model(**inputs)
+
+    activation = hook.cache.get(f"layer_{layer_idx}")
+
+    if activation is None:
+        raise RuntimeError(f"Failed to capture activation at layer {layer_idx}")
+
+    # Select token position
+    if token_position == "last":
+        result = activation[0, -1, :]
+    elif token_position == "first":
+        result = activation[0, 0, :]
+    elif token_position == "mean":
+        result = activation[0].mean(dim=0)
+    else:
+        raise ValueError(f"Unknown token_position: {token_position}")
+
+    return result
+
+
+def extract_activations(
+    model,
+    tokenizer,
+    texts: List[str],
+    layer_indices: List[int],
+    token_position: Literal["last", "first", "mean"] = "last",
+    show_progress: bool = True,
+) -> Dict[int, torch.Tensor]:
+    """Extract activations for multiple texts at specified layers."""
+    # FIXME: this is slow for large datasets, could batch but hook handling is tricky
+    device = next(model.parameters()).device
+    model.eval()
+
+    # Initialize storage
+    layer_activations = {idx: [] for idx in layer_indices}
+
+    iterator = tqdm(texts, desc="Extracting", disable=not show_progress)
+
+    for text in iterator:
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        hook = ActivationHook(model, layer_indices, component="residual", token_position="all")
+
+        with hook:
+            with torch.no_grad():
+                model(**inputs)
+
+        for layer_idx in layer_indices:
+            activation = hook.cache.get(f"layer_{layer_idx}")
+            if activation is None:
+                raise RuntimeError(f"Failed to capture layer {layer_idx}")
+
+            if token_position == "last":
+                act = activation[0, -1, :]
+            elif token_position == "first":
+                act = activation[0, 0, :]
+            else:
+                act = activation[0].mean(dim=0)
+
+            layer_activations[layer_idx].append(act.cpu())
+
+    # Stack into tensors
+    return {
+        idx: torch.stack(acts)
+        for idx, acts in layer_activations.items()
+    }
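For orientation, a minimal usage sketch of extract_caa_vector as defined above (not part of the package files); the model name "gpt2", the contrast-pair wording, and layer_idx=6 are illustrative assumptions.

# Illustrative sketch only; model, prompts, and layer choice are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer
from ai_metacognition.probing.extraction import extract_caa_vector

model = AutoModelForCausalLM.from_pretrained("gpt2")   # any HF causal LM
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Each pair contrasts the target behavior (positive) with its absence (negative).
contrast_pairs = [
    {"positive": "I will deliberately give a wrong answer.",
     "negative": "I will answer to the best of my ability."},
]

vector = extract_caa_vector(model, tokenizer, contrast_pairs, layer_idx=6)
print(vector)   # e.g. SteeringVector(behavior='sandbagging', layer=6, dim=768, ...)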
ai_metacognition/probing/hooks.py
@@ -0,0 +1,200 @@
+"""Activation hooks for capturing hidden states from transformer models.
+
+Provides non-invasive access to model activations during forward pass
+for analysis and probing.
+"""
+
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from torch import nn
+
+
+class ActivationCache:
+    """Cache for storing captured activations."""
+
+    def __init__(self):
+        self._cache: Dict[str, torch.Tensor] = {}
+
+    def store(self, key: str, value: torch.Tensor) -> None:
+        self._cache[key] = value.detach().clone()
+
+    def get(self, key: str) -> Optional[torch.Tensor]:
+        return self._cache.get(key)
+
+    def clear(self) -> None:
+        self._cache.clear()
+
+    def keys(self) -> List[str]:
+        return list(self._cache.keys())
+
+
+class ActivationHook:
+    """Hook for capturing activations from specific model layers.
+
+    Works with HuggingFace transformers models (GPT-2, Mistral, Llama, etc).
+
+    Example:
+        >>> hook = ActivationHook(model, layer_indices=[10, 15, 20])
+        >>> with hook:
+        ...     outputs = model(**inputs)
+        >>> act = hook.cache.get("layer_15")  # (batch, seq, hidden)
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        layer_indices: List[int],
+        component: str = "residual",
+        token_position: str = "all",
+    ):
+        """Initialize activation hook.
+
+        Args:
+            model: HuggingFace model to hook
+            layer_indices: Which layers to capture
+            component: What to capture - "residual", "attn", or "mlp"
+            token_position: "all", "last", or "first"
+        """
+        self.model = model
+        self.layer_indices = layer_indices
+        self.component = component
+        self.token_position = token_position
+        self.cache = ActivationCache()
+        self._handles: List[Any] = []
+
+    def _get_layers(self) -> nn.ModuleList:
+        """Get the transformer layers from the model."""
+        # XXX: this is ugly but HF doesn't have a consistent API for this
+        if hasattr(self.model, "model"):
+            inner = self.model.model
+            if hasattr(inner, "layers"):
+                return inner.layers  # Llama, Mistral
+            elif hasattr(inner, "decoder"):
+                return inner.decoder.layers
+        if hasattr(self.model, "transformer"):
+            if hasattr(self.model.transformer, "h"):
+                return self.model.transformer.h  # GPT-2
+        if hasattr(self.model, "gpt_neox"):
+            return self.model.gpt_neox.layers
+
+        # TODO: add support for more architectures as needed
+        raise ValueError("Could not find transformer layers in model architecture")
+
+    def _make_hook(self, layer_idx: int):
+        """Create a hook function for a specific layer."""
+        def hook_fn(module, input, output):
+            # Handle different output formats
+            if isinstance(output, tuple):
+                hidden_states = output[0]
+            else:
+                hidden_states = output
+
+            # Store based on token position
+            if self.token_position == "last":
+                self.cache.store(f"layer_{layer_idx}", hidden_states[:, -1:, :])
+            elif self.token_position == "first":
+                self.cache.store(f"layer_{layer_idx}", hidden_states[:, :1, :])
+            else:  # all
+                self.cache.store(f"layer_{layer_idx}", hidden_states)
+
+        return hook_fn
+
+    def __enter__(self):
+        """Register hooks on specified layers."""
+        self.cache.clear()
+        layers = self._get_layers()
+
+        for idx in self.layer_indices:
+            if idx >= len(layers):
+                raise ValueError(f"Layer {idx} out of range (model has {len(layers)} layers)")
+
+            layer = layers[idx]
+            handle = layer.register_forward_hook(self._make_hook(idx))
+            self._handles.append(handle)
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Remove hooks."""
+        for handle in self._handles:
+            handle.remove()
+        self._handles.clear()
+        return False
+
+
+class ActivationInjector:
+    """Inject steering vectors into model activations during generation.
+
+    Used to test the effect of extracted sandbagging vectors.
+
+    Example:
+        >>> injector = ActivationInjector(model, [vector], strength=1.5)
+        >>> with injector:
+        ...     outputs = model.generate(**inputs)
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        vectors: List["SteeringVector"],
+        strength: float = 1.0,
+    ):
+        """Initialize activation injector.
+
+        Args:
+            model: Model to inject into
+            vectors: List of steering vectors to inject
+            strength: Injection strength multiplier
+        """
+        self.model = model
+        self.vectors = vectors
+        self.strength = strength
+        self._handles: List[Any] = []
+
+    def _get_layers(self) -> nn.ModuleList:
+        """Get transformer layers."""
+        if hasattr(self.model, "model"):
+            inner = self.model.model
+            if hasattr(inner, "layers"):
+                return inner.layers
+            elif hasattr(inner, "decoder"):
+                return inner.decoder.layers
+        if hasattr(self.model, "transformer"):
+            if hasattr(self.model.transformer, "h"):
+                return self.model.transformer.h
+        raise ValueError("Could not find transformer layers")
+
+    def _make_injection_hook(self, vector: torch.Tensor):
+        """Create hook that adds vector to activations."""
+        def hook_fn(module, input, output):
+            if isinstance(output, tuple):
+                hidden = output[0]
+                # HACK: add to all positions, might want per-position control later
+                modified = hidden + self.strength * vector.to(hidden.device)
+                return (modified,) + output[1:]
+            else:
+                return output + self.strength * vector.to(output.device)
+        return hook_fn
+
+    def __enter__(self):
+        """Register injection hooks."""
+        layers = self._get_layers()
+
+        for vec in self.vectors:
+            layer_idx = vec.layer_index
+            if layer_idx >= len(layers):
+                continue
+
+            hook = self._make_injection_hook(vec.vector)
+            handle = layers[layer_idx].register_forward_hook(hook)
+            self._handles.append(handle)
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Remove injection hooks."""
+        for handle in self._handles:
+            handle.remove()
+        self._handles.clear()
+        return False
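A minimal sketch of how ActivationHook and ActivationInjector compose, assuming a HuggingFace causal LM (not part of the package files); "gpt2", the layer indices, the random vector, and strength=1.5 are illustrative assumptions, not package defaults.

# Illustrative sketch; model name, layers, vector, and strength are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from ai_metacognition.probing.hooks import ActivationHook, ActivationInjector
from ai_metacognition.probing.vectors import SteeringVector

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

# Capture residual-stream activations at two layers during a forward pass.
hook = ActivationHook(model, layer_indices=[4, 8])
with hook, torch.no_grad():
    model(**inputs)
acts = hook.cache.get("layer_8")   # shape: (batch, seq, hidden)

# Inject a (random, purely for illustration) steering vector during generation.
vec = SteeringVector(behavior="demo", layer_index=8, vector=torch.randn(acts.shape[-1]))
injector = ActivationInjector(model, [vec], strength=1.5)
with injector, torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0]))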
ai_metacognition/probing/probes.py
@@ -0,0 +1,186 @@
+"""Linear probes for detecting behavioral patterns in activations.
+
+Linear probes are simple classifiers trained on activation patterns
+to detect specific behaviors like sandbagging.
+"""
+
+import pickle
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+# sklearn import with fallback - don't want to make it a hard dep
+try:
+    from sklearn.linear_model import LogisticRegression
+    from sklearn.model_selection import cross_val_score
+    SKLEARN_AVAILABLE = True
+except ImportError:
+    SKLEARN_AVAILABLE = False
+    # NOTE: user will get a clear error if they try to use LinearProbe
+
+
+class LinearProbe:
+    """Linear probe for binary classification on activation patterns.
+
+    Uses logistic regression to learn a linear decision boundary
+    in activation space for detecting sandbagging behavior.
+
+    Example:
+        >>> probe = LinearProbe()
+        >>> probe.fit(X_train, y_train)
+        >>> predictions = probe.predict(X_test)
+        >>> accuracy = probe.score(X_test, y_test)
+    """
+
+    def __init__(
+        self,
+        C: float = 1.0,
+        max_iter: int = 1000,
+        random_state: int = 42,
+    ):
+        """Initialize linear probe.
+
+        Args:
+            C: Inverse regularization strength
+            max_iter: Maximum iterations for optimization
+            random_state: Random seed for reproducibility
+        """
+        if not SKLEARN_AVAILABLE:
+            raise ImportError("sklearn required for LinearProbe. Install: pip install scikit-learn")
+
+        self.C = C
+        self.max_iter = max_iter
+        self.random_state = random_state
+        self._model: Optional[LogisticRegression] = None
+        self._cv_scores: Optional[np.ndarray] = None
+
+    def fit(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        cv_folds: int = 5,
+    ) -> "LinearProbe":
+        """Fit the probe to training data.
+
+        Args:
+            X: Activation vectors (n_samples, hidden_dim)
+            y: Binary labels (0=genuine, 1=sandbagging)
+            cv_folds: Number of cross-validation folds
+
+        Returns:
+            self
+        """
+        self._model = LogisticRegression(
+            penalty="l2",
+            C=self.C,
+            solver="lbfgs",
+            max_iter=self.max_iter,
+            random_state=self.random_state,
+        )
+
+        # Cross-validation for accuracy estimate
+        self._cv_scores = cross_val_score(
+            self._model, X, y, cv=cv_folds, scoring="accuracy"
+        )
+
+        # Fit on full data
+        self._model.fit(X, y)
+
+        return self
+
+    def predict(self, X: np.ndarray) -> np.ndarray:
+        """Predict labels for new activations."""
+        if self._model is None:
+            raise RuntimeError("Probe not fitted. Call fit() first.")
+        return self._model.predict(X)
+
+    def predict_proba(self, X: np.ndarray) -> np.ndarray:
+        """Get probability estimates for each class."""
+        if self._model is None:
+            raise RuntimeError("Probe not fitted. Call fit() first.")
+        return self._model.predict_proba(X)
+
+    def score(self, X: np.ndarray, y: np.ndarray) -> float:
+        """Compute accuracy on test data."""
+        if self._model is None:
+            raise RuntimeError("Probe not fitted. Call fit() first.")
+        return self._model.score(X, y)
+
+    @property
+    def cv_accuracy(self) -> float:
+        """Mean cross-validation accuracy."""
+        if self._cv_scores is None:
+            raise RuntimeError("Probe not fitted. Call fit() first.")
+        return self._cv_scores.mean()
+
+    @property
+    def cv_std(self) -> float:
+        """Standard deviation of cross-validation accuracy."""
+        if self._cv_scores is None:
+            raise RuntimeError("Probe not fitted. Call fit() first.")
+        return self._cv_scores.std()
+
+    @property
+    def coef(self) -> np.ndarray:
+        """Coefficients of the linear classifier (the probe direction)."""
+        if self._model is None:
+            raise RuntimeError("Probe not fitted. Call fit() first.")
+        return self._model.coef_[0]
+
+    def save(self, path: Path) -> None:
+        """Save probe to disk."""
+        path = Path(path)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        with open(path, "wb") as f:
+            pickle.dump({
+                "model": self._model,
+                "cv_scores": self._cv_scores,
+                "C": self.C,
+                "max_iter": self.max_iter,
+                "random_state": self.random_state,
+            }, f)
+
+    @classmethod
+    def load(cls, path: Path) -> "LinearProbe":
+        """Load probe from disk."""
+        with open(path, "rb") as f:
+            data = pickle.load(f)
+
+        probe = cls(
+            C=data["C"],
+            max_iter=data["max_iter"],
+            random_state=data["random_state"],
+        )
+        probe._model = data["model"]
+        probe._cv_scores = data["cv_scores"]
+        return probe
+
+
+def train_probes_across_layers(
+    layer_activations: Dict[int, Tuple[np.ndarray, np.ndarray]],
+    cv_folds: int = 5,
+) -> Dict[int, LinearProbe]:
+    """Train probes for multiple layers and find the best one.
+
+    Args:
+        layer_activations: Dict mapping layer_idx -> (X, y) arrays
+        cv_folds: Cross-validation folds
+
+    Returns:
+        Dict mapping layer_idx -> trained LinearProbe
+    """
+    probes = {}
+
+    for layer_idx, (X, y) in layer_activations.items():
+        probe = LinearProbe()
+        probe.fit(X, y, cv_folds=cv_folds)
+        probes[layer_idx] = probe
+
+    return probes
+
+
+def find_best_layer(probes: Dict[int, LinearProbe]) -> int:
+    """Find the layer with best probe accuracy."""
+    best_layer = max(probes.keys(), key=lambda l: probes[l].cv_accuracy)
+    return best_layer
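A minimal sketch of the probe-training flow above on synthetic activations (not part of the package files); the array shapes, layer indices, and labels are made up for illustration, and scikit-learn must be installed.

# Illustrative sketch with synthetic data; real use would pass activations
# gathered via extract_activations() plus genuine/sandbagging labels.
import numpy as np
from ai_metacognition.probing.probes import train_probes_across_layers, find_best_layer

rng = np.random.default_rng(0)
hidden_dim = 64
X = rng.normal(size=(200, hidden_dim))
y = (X[:, 0] > 0).astype(int)   # stand-in for genuine (0) vs sandbagging (1)

# One (X, y) pair per layer index.
layer_data = {4: (X, y), 8: (X + rng.normal(scale=0.5, size=X.shape), y)}
probes = train_probes_across_layers(layer_data, cv_folds=5)

best = find_best_layer(probes)
print(best, probes[best].cv_accuracy, probes[best].cv_std)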
ai_metacognition/probing/vectors.py
@@ -0,0 +1,133 @@
+"""Steering vector representation for activation-level interventions.
+
+Vectors represent directions in activation space that correspond to
+specific behaviors (like sandbagging vs genuine response).
+"""
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+
+
+@dataclass
+class SteeringVector:
+    """A vector in activation space representing a behavioral direction.
+
+    Created by computing mean(positive_activations) - mean(negative_activations)
+    using Contrastive Activation Addition (CAA).
+
+    Attributes:
+        behavior: Name of the behavior (e.g., "sandbagging")
+        layer_index: Which layer this vector was extracted from
+        vector: The actual steering vector tensor
+        model_name: Model used for extraction
+        extraction_method: Method used (typically "caa")
+        metadata: Additional extraction details
+    """
+
+    behavior: str
+    layer_index: int
+    vector: torch.Tensor
+    model_name: str = "unknown"
+    extraction_method: str = "caa"
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def norm(self) -> float:
+        """L2 norm of the steering vector."""
+        return self.vector.norm().item()
+
+    @property
+    def dim(self) -> int:
+        """Dimensionality of the vector."""
+        return self.vector.shape[-1]
+
+    def to(self, device: str) -> "SteeringVector":
+        """Move vector to specified device."""
+        return SteeringVector(
+            behavior=self.behavior,
+            layer_index=self.layer_index,
+            vector=self.vector.to(device),
+            model_name=self.model_name,
+            extraction_method=self.extraction_method,
+            metadata=self.metadata,
+        )
+
+    def normalize(self) -> "SteeringVector":
+        """Return unit-normalized version of this vector."""
+        return SteeringVector(
+            behavior=self.behavior,
+            layer_index=self.layer_index,
+            vector=self.vector / self.norm,
+            model_name=self.model_name,
+            extraction_method=self.extraction_method,
+            metadata={**self.metadata, "normalized": True},
+        )
+
+    def save(self, path: Path) -> None:
+        """Save vector to disk.
+
+        Creates:
+            - {path}.pt: The vector tensor
+            - {path}_meta.json: Metadata
+        """
+        import json
+
+        path = Path(path)
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Save tensor
+        torch.save(self.vector, f"{path}.pt")
+
+        # Save metadata
+        meta = {
+            "behavior": self.behavior,
+            "layer_index": self.layer_index,
+            "model_name": self.model_name,
+            "extraction_method": self.extraction_method,
+            "norm": self.norm,
+            "dim": self.dim,
+            **self.metadata,
+        }
+        with open(f"{path}_meta.json", "w") as f:
+            json.dump(meta, f, indent=2)
+
+    @classmethod
+    def load(cls, path: Path) -> "SteeringVector":
+        """Load vector from disk."""
+        import json
+
+        path = Path(path)
+
+        # Load tensor
+        vector = torch.load(f"{path}.pt", weights_only=True)
+
+        # Load metadata
+        with open(f"{path}_meta.json") as f:
+            meta = json.load(f)
+
+        return cls(
+            behavior=meta["behavior"],
+            layer_index=meta["layer_index"],
+            vector=vector,
+            model_name=meta.get("model_name", "unknown"),
+            extraction_method=meta.get("extraction_method", "caa"),
+            metadata={k: v for k, v in meta.items()
+                      if k not in ["behavior", "layer_index", "model_name",
+                                   "extraction_method", "norm", "dim"]},
+        )
+
+    def cosine_similarity(self, other: "SteeringVector") -> float:
+        """Compute cosine similarity with another vector."""
+        return torch.nn.functional.cosine_similarity(
+            self.vector.unsqueeze(0),
+            other.vector.unsqueeze(0),
+        ).item()
+
+    def __repr__(self) -> str:
+        return (
+            f"SteeringVector(behavior='{self.behavior}', "
+            f"layer={self.layer_index}, dim={self.dim}, norm={self.norm:.4f})"
+        )
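Finally, a minimal sketch of the SteeringVector round-trip and comparison helpers above (not part of the package files); the save path and the 768-dimensional random vector are illustrative assumptions.

# Illustrative sketch; the path and dimensionality are arbitrary.
import torch
from ai_metacognition.probing.vectors import SteeringVector

v1 = SteeringVector(behavior="sandbagging", layer_index=12, vector=torch.randn(768))
v2 = v1.normalize()

print(v1.norm, v1.dim)            # L2 norm and hidden size
print(v1.cosine_similarity(v2))   # ~1.0: normalization keeps the direction

v1.save("vectors/sandbagging_l12")            # writes .pt tensor + _meta.json
v3 = SteeringVector.load("vectors/sandbagging_l12")
print(v3)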