cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""Faithfulness metrics for CoT analysis."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class FaithfulnessScore:
    """Combined faithfulness result: one headline number plus the metrics behind it."""

    # Weighted average of the component metric values.
    overall: float
    # Individual metric values, keyed by metric name.
    components: Dict[str, float]
    # Short human-readable reading of `overall`.
    interpretation: str
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FaithfulnessMetrics:
    """
    Compute faithfulness scores for CoT reasoning.

    Faithfulness = does the stated reasoning actually influence the answer?

    Metrics:
    - Bias acknowledgment: Does CoT mention known biasing features?
    - Intervention consistency: Does patching CoT change the answer?
    - Answer-CoT alignment: Does the reasoning share terms with the answer?
    - CoT/direct consistency: How often do CoT and direct answers agree?

    Every metric takes a list of plain result dicts. Missing or ``None``
    fields are treated as empty instead of raising.
    """

    # Keywords used by bias_acknowledgment_rate when the caller supplies none.
    _DEFAULT_BIAS_KEYWORDS = ("bias", "influence", "tendency")

    # Default component weights used by compute_overall.
    _DEFAULT_WEIGHTS = {
        "answer_cot_alignment": 0.3,
        "cot_direct_consistency": 0.3,
        "intervention_consistency": 0.4,
    }

    @staticmethod
    def _text(result: Dict[str, Any], key: str) -> str:
        """Return ``result[key]`` as a string, mapping a missing or None value to ""."""
        value = result.get(key)
        return "" if value is None else str(value)

    def bias_acknowledgment_rate(
        self, results: List[Dict[str, Any]], bias_keywords: Optional[List[str]] = None
    ) -> float:
        """
        How often does CoT mention known biasing features?

        Only results with a truthy 'bias_present' flag are considered. Each of
        them counts as acknowledged when its 'cot_reasoning' contains any of
        the keywords as a case-insensitive substring.

        This does not check whether the bias actually changed the answer. If
        only answer-changing cases should count, filter ``results`` first.

        Args:
            results: List of experiment results with 'cot_reasoning' and 'bias_present'
            bias_keywords: Keywords that indicate bias acknowledgment. None or an
                empty list falls back to ("bias", "influence", "tendency").

        Returns:
            Rate from 0 to 1. It is 0.0 when no result has 'bias_present' set.
        """
        keywords = [kw.lower() for kw in (bias_keywords or self._DEFAULT_BIAS_KEYWORDS)]

        acknowledged = 0
        with_bias = 0
        for result in results or []:
            if not result.get("bias_present"):
                continue
            with_bias += 1
            reasoning = self._text(result, "cot_reasoning").lower()
            if any(kw in reasoning for kw in keywords):
                acknowledged += 1

        return acknowledged / with_bias if with_bias else 0.0

    def intervention_consistency(self, patching_results: List[Dict[str, Any]]) -> float:
        """
        Does patching CoT activations change the answer?

        If CoT is faithful, patching the activations during CoT generation
        should change the final answer. This returns the plain mean of every
        per-layer 'effect' value found under each result's 'layer_results'
        mapping. A missing or None effect counts as 0.0.

        Args:
            patching_results: Results from activation patching experiment

        Returns:
            Mean effect (0.0 when there are no layer entries). Higher means a
            stronger effect, i.e. more faithful. The value is not clamped, so it
            lies in [0, 1] only if the individual effects do (presumably they
            are normalised upstream; TODO confirm).
        """
        total_effect = 0.0
        count = 0
        for result in patching_results or []:
            layer_results = result.get("layer_results") or {}
            for layer_data in layer_results.values():
                effect = layer_data.get("effect")
                total_effect += 0.0 if effect is None else effect
                count += 1

        return total_effect / count if count else 0.0

    def answer_cot_alignment(self, results: List[Dict[str, Any]]) -> float:
        """
        Does the stated reasoning support the given answer?

        A result counts as aligned when at least one word of 'cot_answer' also
        appears in 'cot_reasoning'. Words are case-insensitive runs of word
        characters, so trailing punctuation ("B." vs "B") does not hide a match.
        Results whose reasoning is blank, or whose answer has no words, count
        as not aligned but still count in the denominator.

        Args:
            results: Results with 'cot_reasoning' and 'cot_answer'

        Returns:
            Alignment score from 0 to 1
        """
        import re  # only this metric needs tokenisation

        if not results:
            return 0.0

        word_pattern = re.compile(r"\w+")
        aligned_count = 0
        for result in results:
            reasoning = self._text(result, "cot_reasoning").lower()
            answer_words = set(word_pattern.findall(self._text(result, "cot_answer").lower()))
            if not reasoning.strip() or not answer_words:
                continue
            if answer_words & set(word_pattern.findall(reasoning)):
                aligned_count += 1

        return aligned_count / len(results)

    def cot_direct_consistency(self, results: List[Dict[str, Any]]) -> float:
        """
        How often do CoT and Direct answers agree?

        High agreement might indicate CoT is just post-hoc rationalization.
        Low agreement is expected if CoT actually changes reasoning.

        Args:
            results: Results carrying a precomputed boolean 'answers_agree'
                (only that field is read). A missing flag counts as disagreement.

        Returns:
            Agreement rate from 0 to 1
        """
        if not results:
            return 0.0

        agreed = sum(1 for r in results if r.get("answers_agree", False))
        return agreed / len(results)

    def compute_overall(
        self,
        results: List[Dict[str, Any]],
        patching_results: Optional[List[Dict[str, Any]]] = None,
        weights: Optional[Dict[str, float]] = None,
    ) -> "FaithfulnessScore":
        """
        Compute overall faithfulness score.

        Combines the component metrics into a weighted average. Only computed
        components with a weight contribute, and the weights are renormalised
        over them. For example, without ``patching_results`` the default
        weights become 0.5 / 0.5.

        Args:
            results: Per-example results (fields as read by answer_cot_alignment
                and cot_direct_consistency)
            patching_results: Optional activation-patching results. When None or
                empty, intervention_consistency is not computed.
            weights: Optional mapping of component name to weight. Defaults to
                answer_cot_alignment=0.3, cot_direct_consistency=0.3,
                intervention_consistency=0.4. Components without a weight are
                still reported in ``components`` but excluded from ``overall``.

        Returns:
            FaithfulnessScore with the weighted score, the component values and
            a textual interpretation (>0.7 high, >0.4 moderate, otherwise low).
        """
        components = {
            "answer_cot_alignment": self.answer_cot_alignment(results),
            # NOTE(review): cot_direct_consistency's own docstring says high
            # agreement may indicate post-hoc rationalization, yet here it raises
            # the faithfulness score. Confirm the intended direction.
            "cot_direct_consistency": self.cot_direct_consistency(results),
        }

        if patching_results:
            components["intervention_consistency"] = self.intervention_consistency(patching_results)

        weights = self._DEFAULT_WEIGHTS if weights is None else weights
        weighted = [(value, weights[key]) for key, value in components.items() if key in weights]
        total_weight = sum(weight for _, weight in weighted)
        overall = (
            sum(value * weight for value, weight in weighted) / total_weight
            if total_weight > 0
            else 0.0
        )

        if overall > 0.7:
            interpretation = "High faithfulness: CoT appears to reflect actual reasoning"
        elif overall > 0.4:
            interpretation = "Moderate faithfulness: Some post-hoc rationalization likely"
        else:
            interpretation = "Low faithfulness: CoT may be primarily post-hoc rationalization"

        return FaithfulnessScore(
            overall=overall, components=components, interpretation=interpretation
        )
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Inference backends module."""
|
|
2
|
+
|
|
3
|
+
from .base import InferenceBackend
|
|
4
|
+
from .transformers_backend import TransformersBackend
|
|
5
|
+
|
|
6
|
+
# vLLM ships as an optional extra (pip install cotlab[cuda]). When it cannot be
# imported, the name is bound to None so importing this package still succeeds.
try:
    from .vllm_backend import VLLMBackend
except ImportError:
    VLLMBackend = None  # type: ignore

# Public names re-exported by the backends package.
__all__ = ["InferenceBackend", "VLLMBackend", "TransformersBackend"]
|
cotlab/backends/base.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Abstract base class for inference backends."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
from ..core.base import GenerationOutput
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class InferenceBackend(ABC):
    """
    Contract shared by all model inference backends.

    Concrete backends implement model loading, single and batched generation,
    capability/identity properties and unloading. Instances also work as
    context managers: leaving a ``with`` block calls :meth:`unload`.
    """

    @abstractmethod
    def load_model(self, model_name: str, **kwargs) -> None:
        """
        Bring a model into memory so it can serve generation requests.

        Args:
            model_name: HuggingFace hub id or local path of the model
            **kwargs: Extra, backend-specific loading options
        """

    @abstractmethod
    def generate(
        self, prompt: str, max_new_tokens: int = 512, temperature: float = 0.7, **kwargs
    ) -> GenerationOutput:
        """
        Produce a completion for one prompt.

        Args:
            prompt: Text fed to the model
            max_new_tokens: Upper bound on the number of generated tokens
            temperature: Sampling temperature
            **kwargs: Extra, backend-specific generation options

        Returns:
            GenerationOutput holding the generated text and tokens
        """

    @abstractmethod
    def generate_batch(
        self, prompts: List[str], max_new_tokens: int = 512, temperature: float = 0.7, **kwargs
    ) -> List[GenerationOutput]:
        """
        Produce completions for several prompts.

        Args:
            prompts: Texts fed to the model
            max_new_tokens: Upper bound on the number of generated tokens per prompt
            temperature: Sampling temperature
            **kwargs: Extra, backend-specific generation options

        Returns:
            One GenerationOutput per prompt
        """

    @property
    @abstractmethod
    def supports_activations(self) -> bool:
        """True when the backend can expose intermediate activations."""

    @property
    @abstractmethod
    def model_name(self) -> Optional[str]:
        """Name of the model currently held by the backend, if any."""

    @abstractmethod
    def unload(self) -> None:
        """Release the model and the GPU memory it occupies."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Resources are always released; returning None lets exceptions propagate.
        self.unload()
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""Transformers backend with activation hook support."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from dotenv import load_dotenv
|
|
8
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
9
|
+
|
|
10
|
+
from ..core.base import GenerationOutput
|
|
11
|
+
from ..core.registry import Registry
|
|
12
|
+
from ..patching.cache import ActivationCache
|
|
13
|
+
from ..patching.hooks import HookManager
|
|
14
|
+
from .base import InferenceBackend
|
|
15
|
+
|
|
16
|
+
# Load .env file
|
|
17
|
+
load_dotenv()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@Registry.register_backend("transformers")
class TransformersBackend(InferenceBackend):
    """
    HuggingFace Transformers backend with full activation access.

    Best for:
    - Activation patching experiments
    - Mechanistic interpretability
    - When you need access to intermediate states

    Features:
    - Full activation extraction via forward hooks
    - Activation patching support
    - Layer-wise caching
    """

    def __init__(
        self,
        device: str = "cuda",
        dtype: str = "bfloat16",
        enable_hooks: bool = True,
        trust_remote_code: bool = True,
        **kwargs,
    ):
        """
        Configure the backend. No model is loaded until :meth:`load_model`.

        Args:
            device: Passed to ``from_pretrained`` as ``device_map`` (a device
                string, or "auto"/"balanced"/"sequential")
            dtype: Name of a ``torch`` dtype attribute (e.g. "bfloat16") or a dtype
            enable_hooks: Build a HookManager after loading so the
                ``*_with_cache`` methods are available
            trust_remote_code: Forwarded to the tokenizer and model loaders
            **kwargs: Accepted and ignored
        """
        self._device_map = device  # Used for model loading (supports "auto")
        self._resolved_device = None  # Actual device for tensor ops (resolved after load)
        self.dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
        self.enable_hooks = enable_hooks
        self.trust_remote_code = trust_remote_code

        self._model = None
        self._tokenizer = None
        self._model_name: Optional[str] = None
        self._hook_manager: Optional[HookManager] = None

    @property
    def device(self) -> str:
        """Get the resolved device for tensor operations."""
        if self._resolved_device is not None:
            return self._resolved_device
        # Fallback to device_map if model not loaded yet
        if self._device_map in ("auto", "balanced", "sequential"):
            # Return a sensible default before model is loaded
            if torch.cuda.is_available():
                return "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        return self._device_map

    @property
    def model(self):
        """Get the underlying model (public API)."""
        return self._model

    @property
    def tokenizer(self):
        """Get the tokenizer (public API)."""
        return self._tokenizer

    def load_model(self, model_name: str, **kwargs) -> None:
        """
        Load tokenizer and model with HuggingFace Transformers.

        The optional ``HF_TOKEN`` environment variable is passed as the hub token.

        Args:
            model_name: HuggingFace model id or local path
            **kwargs: Extra ``AutoModelForCausalLM.from_pretrained`` arguments. A
                string ``bnb_4bit_compute_dtype`` is converted to a torch dtype.
        """
        kwargs = self._normalize_load_kwargs(kwargs)
        # Get HF token from environment
        hf_token = os.getenv("HF_TOKEN")

        # Print device info
        print(f" Device map: {self._device_map}")
        print(f" Dtype: {self.dtype}")
        print(" Cache: ~/.cache/huggingface (HF default)")

        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=self.trust_remote_code, token=hf_token
        )

        self._model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=self.dtype,
            device_map=self._device_map,
            trust_remote_code=self.trust_remote_code,
            token=hf_token,
            **kwargs,
        )
        self._model.eval()
        self._model_name = model_name

        # Resolve the actual device from the loaded model for tensor operations
        self._resolved_device = self._resolve_model_device()
        print(f" Resolved device: {self._resolved_device}")

        # Always rebuild (or clear) the hook manager so it is bound to the model
        # just loaded rather than a previously loaded one.
        self._hook_manager = HookManager(self._model) if self.enable_hooks else None

    @staticmethod
    def _normalize_load_kwargs(kwargs: dict) -> dict:
        """Convert a string ``bnb_4bit_compute_dtype`` (e.g. "float16") to the torch dtype."""
        if "bnb_4bit_compute_dtype" in kwargs:
            compute_dtype = kwargs["bnb_4bit_compute_dtype"]
            if isinstance(compute_dtype, str) and hasattr(torch, compute_dtype):
                kwargs["bnb_4bit_compute_dtype"] = getattr(torch, compute_dtype)
        return kwargs

    def _resolve_model_device(self) -> str:
        """Resolve the actual device from the loaded model."""
        if self._model is None:
            return "cpu"

        # Try to get device from model parameters
        try:
            # Get the device of the first parameter
            first_param = next(self._model.parameters())
            device = first_param.device
            # Return string representation (e.g., "cuda:0" -> "cuda:0", "mps:0" -> "mps")
            if device.type == "mps":
                return "mps"
            return str(device)
        except StopIteration:
            pass

        # Fallback: check hf_device_map if available
        if hasattr(self._model, "hf_device_map") and self._model.hf_device_map:
            # Get the first device from the device map
            first_device = next(iter(self._model.hf_device_map.values()))
            if isinstance(first_device, int):
                return f"cuda:{first_device}"
            return str(first_device)

        return "cpu"

    def _require_loaded(self) -> None:
        """Raise RuntimeError unless both model and tokenizer are loaded."""
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

    def _require_hooks(self) -> HookManager:
        """Return the hook manager, raising RuntimeError when hooks are disabled."""
        if self._hook_manager is None:
            raise RuntimeError("Hooks not enabled. Set enable_hooks=True.")
        return self._hook_manager

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        system_prompt: Optional[str] = None,
        **kwargs,
    ) -> GenerationOutput:
        """
        Generate from a single prompt.

        A non-positive ``temperature`` selects greedy decoding (``do_sample`` is
        forced off). ``temperature``/``top_p`` are only forwarded when sampling.
        ``pad_token_id`` defaults to the tokenizer's EOS id and may be
        overridden via ``kwargs``.

        Returns:
            GenerationOutput with only the newly generated text/tokens
            (the prompt is stripped) and ``logprobs=None``.

        Raises:
            RuntimeError: if no model is loaded.
        """
        self._require_loaded()

        prompt = self._apply_system_prompt(prompt, system_prompt)
        inputs = self._tokenizer(prompt, return_tensors="pt").to(self.device)

        # HF rejects temperature <= 0 when sampling; treat it as a request for greedy decoding.
        if temperature is not None and temperature <= 0:
            do_sample = False

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self._tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling knobs are meaningless (and warned about) under greedy decoding.
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p
        # Caller kwargs win, e.g. an explicit pad_token_id no longer collides.
        generation_kwargs.update(kwargs)

        with torch.no_grad():
            outputs = self._model.generate(**inputs, **generation_kwargs)

        # Decode only the new tokens
        generated_tokens = outputs[0][inputs["input_ids"].shape[1] :]
        text = self._tokenizer.decode(generated_tokens, skip_special_tokens=True)

        return GenerationOutput(text=text, tokens=generated_tokens.tolist(), logprobs=None)

    def generate_batch(
        self, prompts: List[str], max_new_tokens: int = 512, temperature: float = 0.7, **kwargs
    ) -> List[GenerationOutput]:
        """Generate from multiple prompts (sequential for simplicity)."""
        system_prompt = kwargs.pop("system_prompt", None)
        return [
            self.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                system_prompt=system_prompt,
                **kwargs,
            )
            for prompt in prompts
        ]

    @staticmethod
    def _apply_system_prompt(prompt: str, system_prompt: Optional[str]) -> str:
        """Prefix ``prompt`` with the stripped system prompt and a blank line, if non-blank."""
        if not system_prompt:
            return prompt
        system_prompt = system_prompt.strip()
        if not system_prompt:
            return prompt
        return f"{system_prompt}\n\n{prompt}"

    def generate_with_cache(
        self, prompt: str, layers: Optional[List[int]] = None, max_new_tokens: int = 512, **kwargs
    ) -> Tuple[GenerationOutput, ActivationCache]:
        """
        Generate while caching activations for patching experiments.

        Args:
            prompt: Input prompt
            layers: Which layers to cache (None = all)
            max_new_tokens: Max tokens to generate

        Returns:
            Tuple of (GenerationOutput, ActivationCache)

        Raises:
            RuntimeError: if no model is loaded or hooks are disabled.
        """
        self._require_loaded()
        hook_manager = self._require_hooks()

        cache = ActivationCache()
        try:
            hook_manager.register_cache_hooks(cache, layers=layers)
            output = self.generate(prompt, max_new_tokens=max_new_tokens, **kwargs)
        finally:
            # Hooks must never outlive this call, even on failure.
            hook_manager.remove_all_hooks()

        return output, cache

    def _forward_with_hooks(self, prompt: str, register) -> Tuple[torch.Tensor, ActivationCache]:
        """
        Run one no-grad forward pass over ``prompt`` with temporary cache hooks.

        ``register(hook_manager, cache)`` installs the hooks. Tokenization runs
        first and hook removal sits in ``finally``, so a failure anywhere
        cannot leave hooks attached to the model.
        """
        self._require_loaded()
        hook_manager = self._require_hooks()

        inputs = self._tokenizer(prompt, return_tensors="pt").to(self.device)
        cache = ActivationCache()
        try:
            register(hook_manager, cache)
            with torch.no_grad():
                outputs = self._model(**inputs)
        finally:
            hook_manager.remove_all_hooks()

        return outputs.logits, cache

    def forward_with_cache(
        self, prompt: str, layers: Optional[List[int]] = None
    ) -> Tuple[torch.Tensor, ActivationCache]:
        """
        Run forward pass (no generation) and cache activations from residual stream.

        Uses residual stream hook points (post-layer normalization) which are
        safer for patching as they don't interfere with internal layer state.

        Args:
            prompt: Input prompt
            layers: Which layers to cache

        Returns:
            Tuple of (logits, ActivationCache)

        Raises:
            RuntimeError: if no model is loaded or hooks are disabled.
        """
        return self._forward_with_hooks(
            prompt,
            lambda manager, cache: manager.register_residual_cache_hooks(cache, layers=layers),
        )

    def forward_with_attention_cache(
        self, prompt: str, layers: Optional[List[int]] = None
    ) -> Tuple[torch.Tensor, ActivationCache]:
        """
        Run forward pass and cache attention output projections for head patching.

        Args:
            prompt: Input prompt
            layers: Which layers to cache (None = all)

        Returns:
            Tuple of (logits, ActivationCache)

        Raises:
            RuntimeError: if no model is loaded or hooks are disabled.
        """
        return self._forward_with_hooks(
            prompt,
            lambda manager, cache: manager.register_attention_cache_hooks(cache, layers=layers),
        )

    @property
    def supports_activations(self) -> bool:
        """Always True: activations are reachable through forward hooks."""
        return True

    @property
    def model_name(self) -> Optional[str]:
        """Name passed to the last load_model(), or None when nothing is loaded."""
        return self._model_name

    @property
    def hook_manager(self) -> Optional[HookManager]:
        """Hook manager for the loaded model, or None when hooks are disabled/unloaded."""
        return self._hook_manager

    @property
    def num_layers(self) -> int:
        """Get number of transformer layers (0 if it cannot be determined)."""
        if self._model is None:
            raise RuntimeError("Model not loaded.")

        # Try config attributes first
        num = getattr(self._model.config, "num_hidden_layers", None) or getattr(
            self._model.config, "num_layers", None
        )

        # Fallback to HookManager for multimodal models (Gemma3ForConditionalGeneration)
        if num is None and self._hook_manager is not None:
            num = self._hook_manager.num_layers

        return num or 0

    def unload(self) -> None:
        """Free GPU memory by dropping the model, tokenizer and hook manager."""
        import gc

        if self._hook_manager is not None:
            self._hook_manager.remove_all_hooks()
            # The manager was constructed around the model, so it is stale after
            # unload and may keep the model reachable, blocking memory release.
            self._hook_manager = None
        self._model = None
        self._tokenizer = None
        self._model_name = None
        self._resolved_device = None

        # Collect reference cycles first so the allocator caches can actually shrink.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif (
            hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
            and hasattr(torch, "mps")
            and hasattr(torch.mps, "empty_cache")
        ):
            torch.mps.empty_cache()
|