tpcav 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tpcav/__init__.py +39 -0
- tpcav/cavs.py +334 -0
- tpcav/concepts.py +309 -0
- tpcav/helper.py +165 -0
- tpcav/logging_utils.py +10 -0
- tpcav/tpcav_model.py +427 -0
- tpcav/utils.py +601 -0
- tpcav-0.1.0.dist-info/METADATA +89 -0
- tpcav-0.1.0.dist-info/RECORD +12 -0
- tpcav-0.1.0.dist-info/WHEEL +5 -0
- tpcav-0.1.0.dist-info/licenses/LICENSE +21 -0
- tpcav-0.1.0.dist-info/top_level.txt +1 -0
tpcav/helper.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Lightweight data loading helpers for sequences and chromatin tracks.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Iterable, List, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import seqchromloader as scl
|
|
11
|
+
import torch
|
|
12
|
+
from deeplift.dinuc_shuffle import dinuc_shuffle
|
|
13
|
+
from pyfaidx import Fasta
|
|
14
|
+
from seqchromloader.utils import dna2OneHot, extract_bw
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def load_bed_and_center(bed_file: str, window: int) -> pd.DataFrame:
    """
    Read a BED file and recenter every region to a fixed-width window.

    Each interval becomes a ``window``-long interval sharing the same
    integer midpoint (odd windows round the left half down).
    Returns a DataFrame with columns [chrom, start, end].
    """
    regions = pd.read_table(
        bed_file, usecols=[0, 1, 2], names=["chrom", "start", "end"]
    )
    midpoints = ((regions["start"] + regions["end"]) // 2).astype(int)
    half_window = window // 2
    regions["start"] = midpoints - half_window
    regions["end"] = regions["start"] + window
    return regions[["chrom", "start", "end"]]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def bed_to_fasta_iter(
    bed_file: str, genome_fasta: str, batch_size: int
) -> Iterable[List[str]]:
    """
    Stream batches of region sequences (uppercase strings) from a BED file.

    Thin generator wrapper: parses the BED into [chrom, start, end] rows
    and delegates sequence extraction and batching to
    ``dataframe_to_fasta_iter``.
    """
    regions = pd.read_table(
        bed_file, usecols=[0, 1, 2], names=["chrom", "start", "end"]
    )
    yield from dataframe_to_fasta_iter(regions, genome_fasta, batch_size)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def dataframe_to_fasta_iter(
    df: pd.DataFrame, genome_fasta: str, batch_size: int
) -> Iterable[List[str]]:
    """
    Yield lists of uppercase sequences, at most ``batch_size`` long, for
    regions given as [chrom, start, end] rows of ``df``.

    A trailing partial batch is yielded if the row count is not a multiple
    of ``batch_size``.
    """
    genome = Fasta(genome_fasta)
    batch: List[str] = []
    for region in df.itertuples(index=False):
        batch.append(str(genome[region.chrom][region.start : region.end]).upper())
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # Leftover rows that did not fill a complete batch.
        yield batch
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class DataFrame2FastaIterator:
    """
    Re-iterable wrapper around ``dataframe_to_fasta_iter``.

    Unlike a bare generator, each ``iter()`` call starts a fresh pass over
    the [chrom, start, end] rows of ``df``.
    """

    def __init__(self, df: pd.DataFrame, genome_fasta: str, batch_size: int):
        self.df = df
        self.genome_fasta = genome_fasta
        self.batch_size = batch_size

    def __iter__(self):
        return dataframe_to_fasta_iter(
            self.df, genome_fasta=self.genome_fasta, batch_size=self.batch_size
        )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def bed_to_chrom_tracks_iter(
    bed_file: str, genome_fasta: str, bigwigs: List[str]
) -> Iterable[torch.Tensor]:
    """
    Yield chromatin tracks for BED regions using bigwig files.

    Each batch is shaped [batch, num_tracks, window].

    ``genome_fasta`` is retained for interface compatibility; bigwig
    coverage extraction does not need it.

    Bug fix: the previous version called
    ``dataframe_to_chrom_tracks_iter(bed_df, genome_fasta, bigwigs)``,
    passing ``genome_fasta`` as the ``bigwigs`` argument and the bigwig
    list as ``batch_size``. Arguments are now passed by keyword so they
    land on the right parameters.
    """
    bed_df = pd.read_table(bed_file, usecols=[0, 1, 2], names=["chrom", "start", "end"])
    yield from dataframe_to_chrom_tracks_iter(bed_df, bigwigs=bigwigs)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def dataframe_to_chrom_tracks_iter(
|
|
85
|
+
df: pd.DataFrame,
|
|
86
|
+
bigwigs: List[str] | None,
|
|
87
|
+
batch_size: int = 1,
|
|
88
|
+
) -> Iterable[torch.Tensor | None]:
|
|
89
|
+
"""
|
|
90
|
+
Yield chromatin tracks for regions from a DataFrame using bigwig files.
|
|
91
|
+
"""
|
|
92
|
+
if bigwigs is not None and len(bigwigs) > 0:
|
|
93
|
+
chrom_arrs = []
|
|
94
|
+
for row in df.itertuples(index=False):
|
|
95
|
+
chrom = extract_bw(
|
|
96
|
+
row.chrom, row.start, row.end, getattr(row, "strand", "+"), bigwigs
|
|
97
|
+
)
|
|
98
|
+
chrom_arrs.append(chrom)
|
|
99
|
+
if batch_size is not None and len(chrom_arrs) == batch_size:
|
|
100
|
+
yield torch.tensor(np.stack(chrom_arrs))
|
|
101
|
+
chrom_arrs = []
|
|
102
|
+
if chrom_arrs:
|
|
103
|
+
yield torch.tensor(np.stack(chrom_arrs))
|
|
104
|
+
else:
|
|
105
|
+
while True:
|
|
106
|
+
yield torch.full((batch_size, 1), torch.nan)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class DataFrame2ChromTracksIterator:
|
|
110
|
+
"""
|
|
111
|
+
Iterator class to yield chromatin tracks from a DataFrame with columns [chrom, start, end].
|
|
112
|
+
"""
|
|
113
|
+
|
|
114
|
+
def __init__(
|
|
115
|
+
self,
|
|
116
|
+
df: pd.DataFrame,
|
|
117
|
+
bigwigs: List[str] | None,
|
|
118
|
+
batch_size: int = 1,
|
|
119
|
+
):
|
|
120
|
+
self.bigwigs = bigwigs
|
|
121
|
+
self.df = df
|
|
122
|
+
self.batch_size = batch_size
|
|
123
|
+
|
|
124
|
+
def __iter__(self):
|
|
125
|
+
return dataframe_to_chrom_tracks_iter(
|
|
126
|
+
self.df,
|
|
127
|
+
bigwigs=self.bigwigs,
|
|
128
|
+
batch_size=self.batch_size,
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def fasta_to_one_hot_sequences(seqs: List[str]) -> torch.Tensor:
    """
    One-hot encode a batch of DNA sequences into a single torch.Tensor.

    Note: a previous docstring described the return as numpy arrays; the
    function actually wraps the stacked encodings in ``torch.tensor``.
    The per-sequence encoding comes from ``dna2OneHot`` — presumably
    [4, L] per sequence, giving [batch, 4, L]; confirm against
    seqchromloader. Sequences are assumed equal-length so ``np.stack``
    succeeds.
    """
    return torch.tensor(np.stack([dna2OneHot(seq) for seq in seqs]))
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def random_regions_dataframe(
    genome_size_file: str, window: int, n: int, seed: int = 1
) -> pd.DataFrame:
    """
    Draw ``n`` random ``window``-wide regions via seqchromloader.

    Deterministic for a fixed ``seed``. Returns only the
    [chrom, start, end] columns of the sampled coordinates.
    """
    coords = scl.random_coords(gs=genome_size_file, l=window, n=n, seed=seed)
    return coords[["chrom", "start", "end"]]
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def dinuc_shuffle_sequences(
    seqs: Iterable[str], num_shuffles: int = 10, seed: int = 1
) -> List[List[str]]:
    """
    Produce ``num_shuffles`` dinucleotide-preserving shuffles per sequence.

    A single seeded RNG is shared across all sequences, so output is
    deterministic for a fixed input order. Returns one list of shuffled
    sequences per input.
    """
    rng = np.random.RandomState(seed)
    return [
        dinuc_shuffle(seq, num_shufs=num_shuffles, rng=rng)
        for seq in seqs
    ]
|
tpcav/logging_utils.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def set_verbose(level: str = "INFO") -> None:
    """
    Set the logging level on both the root and the ``tpcav`` loggers.

    The level name is case-insensitive; unknown names silently fall back
    to INFO.
    """
    resolved = getattr(logging, level.upper(), logging.INFO)
    for logger_name in (None, "tpcav"):
        logging.getLogger(logger_name).setLevel(resolved)
|
tpcav/tpcav_model.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from functools import partial
|
|
3
|
+
from typing import Dict, Iterable, List, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from captum.attr import DeepLift
|
|
7
|
+
from scipy.linalg import svd
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _abs_attribution_func(multipliers, inputs, baselines):
|
|
13
|
+
"Multiplier x abs(inputs - baselines) to avoid double-sign effects."
|
|
14
|
+
# print(f"inputs: {inputs[1][:5]}")
|
|
15
|
+
# print(f"baselines: {baselines[1][:5]}")
|
|
16
|
+
# print(f"multipliers: {multipliers[0][:5]}")
|
|
17
|
+
# print(f"multipliers: {multipliers[1][:5]}")
|
|
18
|
+
return tuple(
|
|
19
|
+
(input_ - baseline).abs() * multiplier
|
|
20
|
+
for input_, baseline, multiplier in zip(inputs, baselines, multipliers)
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TPCAV(torch.nn.Module):
|
|
25
|
+
"""End-to-end PCA fitting, projection, and attribution utilities."""
|
|
26
|
+
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
model,
|
|
30
|
+
device: Optional[str] = None,
|
|
31
|
+
layer_name: Optional[str] = None,
|
|
32
|
+
) -> None:
|
|
33
|
+
"""
|
|
34
|
+
layer_name: optional module name to intercept activations via forward hook
|
|
35
|
+
(useful if forward_until_select_layer is not implemented).
|
|
36
|
+
"""
|
|
37
|
+
super().__init__()
|
|
38
|
+
self.model = model
|
|
39
|
+
self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
40
|
+
self.model.to(self.device)
|
|
41
|
+
self.model.eval()
|
|
42
|
+
self.fitted = False
|
|
43
|
+
self.layer_name = layer_name
|
|
44
|
+
|
|
45
|
+
def list_module_names(self) -> List[str]:
|
|
46
|
+
"""List all module names in the model for layer selection."""
|
|
47
|
+
return [name for name, _ in self.model.named_modules()]
|
|
48
|
+
|
|
49
|
+
def tpcav_state_dict(self) -> Dict:
|
|
50
|
+
"""Export TPCAV buffers."""
|
|
51
|
+
return {
|
|
52
|
+
"layer_name": self.layer_name,
|
|
53
|
+
"zscore_mean": getattr(self, "zscore_mean", None),
|
|
54
|
+
"zscore_std": getattr(self, "zscore_std", None),
|
|
55
|
+
"pca_inv": getattr(self, "pca_inv", None),
|
|
56
|
+
"orig_shape": getattr(self, "orig_shape", None),
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
def restore_tpcav_state(self, tpcav_state_dict: Dict) -> None:
|
|
60
|
+
"""Load PCA buffers from disk."""
|
|
61
|
+
self.layer_name = tpcav_state_dict["layer_name"]
|
|
62
|
+
self._set_buffer("zscore_mean", tpcav_state_dict["zscore_mean"])
|
|
63
|
+
self._set_buffer("zscore_std", tpcav_state_dict["zscore_std"])
|
|
64
|
+
self._set_buffer("pca_inv", tpcav_state_dict["pca_inv"])
|
|
65
|
+
self._set_buffer("orig_shape", tpcav_state_dict["orig_shape"])
|
|
66
|
+
self.fitted = True
|
|
67
|
+
logger.warning(
|
|
68
|
+
"Restored TPCAV state, please set model attribute!\n\n Example: self.model = Model_class()",
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
def fit_pca(
|
|
72
|
+
self,
|
|
73
|
+
concepts: Iterable,
|
|
74
|
+
num_samples_per_concept: int = 10,
|
|
75
|
+
num_pc: Optional[int] | str = None,
|
|
76
|
+
) -> Dict[str, torch.Tensor]:
|
|
77
|
+
"""Sample activations, compute PCA, and attach buffers to the model."""
|
|
78
|
+
sampled_avs = []
|
|
79
|
+
for concept in concepts:
|
|
80
|
+
avs = self._sample_concept(concept, num_samples=num_samples_per_concept)
|
|
81
|
+
logger.info(
|
|
82
|
+
"Sampled %s activations from concept %s", avs.shape[0], concept.name
|
|
83
|
+
)
|
|
84
|
+
sampled_avs.append(avs)
|
|
85
|
+
all_avs = torch.cat(sampled_avs)
|
|
86
|
+
orig_shape = all_avs.shape
|
|
87
|
+
flat = all_avs.flatten(start_dim=1)
|
|
88
|
+
|
|
89
|
+
mean = flat.mean(dim=0)
|
|
90
|
+
std = flat.std(dim=0)
|
|
91
|
+
std[std == 0] = -1
|
|
92
|
+
standardized = (flat - mean) / std
|
|
93
|
+
|
|
94
|
+
v_inverse = None
|
|
95
|
+
if num_pc is None or num_pc == "full":
|
|
96
|
+
_, _, v = svd(standardized, lapack_driver="gesvd", full_matrices=False)
|
|
97
|
+
v_inverse = torch.tensor(v)
|
|
98
|
+
elif int(num_pc) == 0:
|
|
99
|
+
v_inverse = None
|
|
100
|
+
else:
|
|
101
|
+
_, _, v = svd(standardized, lapack_driver="gesvd", full_matrices=False)
|
|
102
|
+
v_inverse = torch.tensor(v[: int(num_pc)])
|
|
103
|
+
|
|
104
|
+
self._set_buffer("zscore_mean", mean.to(self.device))
|
|
105
|
+
self._set_buffer("zscore_std", std.to(self.device))
|
|
106
|
+
self._set_buffer(
|
|
107
|
+
"pca_inv", v_inverse.to(self.device) if v_inverse is not None else None
|
|
108
|
+
)
|
|
109
|
+
self._set_buffer("orig_shape", torch.tensor(orig_shape).to(self.device))
|
|
110
|
+
self.fitted = True
|
|
111
|
+
|
|
112
|
+
return {
|
|
113
|
+
"zscore_mean": mean,
|
|
114
|
+
"zscore_std": std,
|
|
115
|
+
"pca_inv": v_inverse,
|
|
116
|
+
"orig_shape": torch.tensor(orig_shape),
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
def project_activations(
|
|
120
|
+
self, activations: torch.Tensor
|
|
121
|
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
122
|
+
"""Project flattened activations into PCA space and residual."""
|
|
123
|
+
if not self.fitted:
|
|
124
|
+
raise RuntimeError("Call fit_pca before projecting activations.")
|
|
125
|
+
|
|
126
|
+
y = activations.flatten(start_dim=1).to(self.device)
|
|
127
|
+
if self.pca_inv is not None:
|
|
128
|
+
V = self.pca_inv.T
|
|
129
|
+
zscore_mean = getattr(self, "zscore_mean", 0.0)
|
|
130
|
+
zscore_std = getattr(self, "zscore_std", 1.0)
|
|
131
|
+
y_standardized = (y - zscore_mean) / zscore_std
|
|
132
|
+
y_projected = torch.matmul(y_standardized, V)
|
|
133
|
+
y_residual = y_standardized - torch.matmul(y_projected, self.pca_inv)
|
|
134
|
+
return y_residual, y_projected
|
|
135
|
+
else:
|
|
136
|
+
return y, None
|
|
137
|
+
|
|
138
|
+
def concept_embeddings(self, concept, num_samples: int) -> torch.Tensor:
|
|
139
|
+
"""Return concatenated projected + residual activations for a concept."""
|
|
140
|
+
avs = self._sample_concept(concept, num_samples=num_samples)
|
|
141
|
+
residual, projected = self.project_activations(avs)
|
|
142
|
+
if projected is not None:
|
|
143
|
+
return torch.cat((projected, residual), dim=1)
|
|
144
|
+
return residual
|
|
145
|
+
|
|
146
|
+
def forward_from_embeddings_at_layer(
|
|
147
|
+
self,
|
|
148
|
+
avs_residual: torch.Tensor,
|
|
149
|
+
avs_projected: Optional[torch.Tensor] = None,
|
|
150
|
+
model_inputs: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
|
|
151
|
+
) -> torch.Tensor:
|
|
152
|
+
"""
|
|
153
|
+
Resume model forward by injecting activations at a named layer.
|
|
154
|
+
|
|
155
|
+
If layer_name is not set, falls back to forward_with_embeddings which
|
|
156
|
+
expects model.forward_from_projected_and_residual to exist.
|
|
157
|
+
"""
|
|
158
|
+
name = self.layer_name
|
|
159
|
+
if name is None:
|
|
160
|
+
raise ValueError("layer name must be defined")
|
|
161
|
+
if model_inputs is None:
|
|
162
|
+
raise ValueError(
|
|
163
|
+
"model_inputs (seq, chrom) must be provided to run forward."
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
y_hat = self.embedding_to_layer_activation(avs_residual, avs_projected)
|
|
167
|
+
return self.forward_patched(
|
|
168
|
+
model_inputs,
|
|
169
|
+
layer_activation=y_hat,
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
    def layer_attributions(
        self,
        target_batches: Iterable,
        baseline_batches: Iterable,
        multiply_by_inputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Compute DeepLift attributions on PCA embedding space.

        target_batches and baseline_batches should yield (seq, chrom) pairs of matching length.

        Per batch: activations are captured at the configured layer,
        projected into (residual, projected) PCA coordinates, and DeepLift
        attributes the output w.r.t. those embeddings via the patched
        forward. Returned attributions are concatenated as
        [projected | residual] along dim 1 when a PCA basis exists,
        otherwise just the residual part.
        """
        if not self.fitted:
            raise RuntimeError("Call fit_pca before attributing.")
        # DeepLift drives this module through self.forward, so point it at
        # the embedding-injection forward for the duration of this call.
        self.forward = self.forward_from_embeddings_at_layer
        deeplift = DeepLift(self, multiply_by_inputs=multiply_by_inputs)

        attributions = []
        for inputs, binputs in zip(target_batches, baseline_batches):
            # Target and baseline activations at the hooked layer.
            avs = self._layer_output(*[i.to(self.device) for i in inputs])
            avs_residual, avs_projected = self.project_activations(avs)

            bavs = self._layer_output(*[bi.to(self.device) for bi in binputs])
            bavs_residual, bavs_projected = self.project_activations(bavs)

            # Detach the projected tensors: they are still connected to the
            # original input graph; detaching isolates the attribution graph
            # to the embedding inputs DeepLift perturbs.
            if avs_projected is not None:
                avs_projected = avs_projected.detach()
                bavs_projected = bavs_projected.detach()
                attribution = deeplift.attribute(
                    (avs_residual.to(self.device), avs_projected.to(self.device)),
                    baselines=(
                        bavs_residual.to(self.device),
                        bavs_projected.to(self.device),
                    ),
                    # Extra positional args appended after the attributed
                    # inputs when DeepLift calls self.forward.
                    additional_forward_args=(inputs,),
                    custom_attribution_func=(
                        None if not multiply_by_inputs else _abs_attribution_func
                    ),
                )
                attr_residual, attr_projected = attribution
                attribution = torch.cat((attr_projected, attr_residual), dim=1)
            else:
                # No PCA basis: only the residual embedding is attributed;
                # None fills the avs_projected slot of the forward signature.
                attribution = deeplift.attribute(
                    (avs_residual.to(self.device),),
                    baselines=(bavs_residual.to(self.device),),
                    additional_forward_args=(
                        None,
                        inputs,
                    ),
                    custom_attribution_func=(
                        None if not multiply_by_inputs else _abs_attribution_func
                    ),
                )[0]

            attributions.append(attribution.detach().cpu())

            # Drop per-batch intermediates eagerly to bound GPU memory.
            with torch.no_grad():
                del (
                    avs,
                    avs_projected,
                    avs_residual,
                    bavs,
                    bavs_projected,
                    bavs_residual,
                )
                torch.cuda.empty_cache()

        return {
            "attributions": torch.cat(attributions),
        }
|
|
242
|
+
|
|
243
|
+
def input_attributions(
|
|
244
|
+
self,
|
|
245
|
+
target_batches: Iterable,
|
|
246
|
+
baseline_batches: Iterable,
|
|
247
|
+
multiply_by_inputs: bool = True,
|
|
248
|
+
cavs_list: List[torch.Tensor] | None = None,
|
|
249
|
+
) -> List[torch.Tensor]:
|
|
250
|
+
"""Compute DeepLift attributions on PCA embedding space.
|
|
251
|
+
|
|
252
|
+
target_batches and baseline_batches should yield (seq, chrom) pairs of matching length.
|
|
253
|
+
"""
|
|
254
|
+
if not self.fitted:
|
|
255
|
+
raise RuntimeError("Call fit_pca before attributing.")
|
|
256
|
+
self.forward = partial(
|
|
257
|
+
self.forward_patched_tensor,
|
|
258
|
+
layer_activation=None,
|
|
259
|
+
cavs_list=cavs_list,
|
|
260
|
+
mute_x_avs=False,
|
|
261
|
+
mute_remainder=True,
|
|
262
|
+
)
|
|
263
|
+
deeplift = DeepLift(self, multiply_by_inputs=multiply_by_inputs)
|
|
264
|
+
|
|
265
|
+
attributions = []
|
|
266
|
+
for inputs, binputs in zip(target_batches, baseline_batches):
|
|
267
|
+
attribution = deeplift.attribute(
|
|
268
|
+
tuple([i.to(self.device) for i in inputs]),
|
|
269
|
+
baselines=tuple([bi.to(self.device) for bi in binputs]),
|
|
270
|
+
)
|
|
271
|
+
attributions.append(
|
|
272
|
+
[a.detach().cpu() for a in attribution]
|
|
273
|
+
if isinstance(attribution, tuple)
|
|
274
|
+
else attribution.detach().cpu()
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
return [torch.cat(z) for z in zip(*attributions)]
|
|
278
|
+
|
|
279
|
+
def _sample_concept(self, concept, num_samples: int) -> torch.Tensor:
|
|
280
|
+
avs: List[torch.Tensor] = []
|
|
281
|
+
num = 0
|
|
282
|
+
for inputs in concept.data_iter:
|
|
283
|
+
av = self._layer_output(*[i.to(self.device) for i in inputs])
|
|
284
|
+
avs.append(av.detach().cpu())
|
|
285
|
+
num += av.shape[0]
|
|
286
|
+
if num >= num_samples:
|
|
287
|
+
break
|
|
288
|
+
if not avs:
|
|
289
|
+
raise ValueError(f"No activations gathered for concept {concept.name}")
|
|
290
|
+
return torch.cat(avs)[:num_samples]
|
|
291
|
+
|
|
292
|
+
def _layer_output(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
|
|
293
|
+
"""Return activations from the configured layer or model hook."""
|
|
294
|
+
if self.layer_name is None:
|
|
295
|
+
# No layer configured; return model output directly.
|
|
296
|
+
self.model(*inputs)
|
|
297
|
+
layer = self._resolve_layer(self.layer_name)
|
|
298
|
+
cache: List[torch.Tensor] = []
|
|
299
|
+
|
|
300
|
+
def hook_fn(_module, _inputs, output):
|
|
301
|
+
cache.append(output)
|
|
302
|
+
|
|
303
|
+
handle = layer.register_forward_hook(hook_fn)
|
|
304
|
+
try:
|
|
305
|
+
inputs = [inp.to(self.device) if inp is not None else inp for inp in inputs]
|
|
306
|
+
_ = self.model(*inputs)
|
|
307
|
+
finally:
|
|
308
|
+
handle.remove()
|
|
309
|
+
|
|
310
|
+
if not cache:
|
|
311
|
+
raise RuntimeError(f"No activation captured for layer {self.layer_name}")
|
|
312
|
+
return cache[0]
|
|
313
|
+
|
|
314
|
+
def _resolve_layer(self, name: str):
|
|
315
|
+
for module_name, module in self.model.named_modules():
|
|
316
|
+
if module_name == name:
|
|
317
|
+
return module
|
|
318
|
+
raise ValueError(f"Layer {name} not found in model.")
|
|
319
|
+
|
|
320
|
+
def embedding_to_layer_activation(
|
|
321
|
+
self, avs_residual: torch.Tensor, avs_projected: Optional[torch.Tensor]
|
|
322
|
+
) -> torch.Tensor:
|
|
323
|
+
"""
|
|
324
|
+
Combine projected/residual embeddings into the layer activation space,
|
|
325
|
+
mirroring scripts/models.py merge logic.
|
|
326
|
+
"""
|
|
327
|
+
y_hat = torch.matmul(avs_projected, self.pca_inv) + avs_residual
|
|
328
|
+
y_hat = y_hat * self.zscore_std + self.zscore_mean
|
|
329
|
+
|
|
330
|
+
return y_hat.reshape((-1, *self.orig_shape[1:]))
|
|
331
|
+
|
|
332
|
+
def forward_patched_tensor(
|
|
333
|
+
self,
|
|
334
|
+
*model_inputs: torch.Tensor,
|
|
335
|
+
layer_activation: Optional[torch.Tensor] = None,
|
|
336
|
+
cavs_list: Optional[List[torch.Tensor]] = None,
|
|
337
|
+
mute_x_avs: bool = False,
|
|
338
|
+
mute_remainder: bool = True,
|
|
339
|
+
) -> torch.Tensor:
|
|
340
|
+
return self.forward_patched(
|
|
341
|
+
model_inputs, layer_activation, cavs_list, mute_x_avs, mute_remainder
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
def forward_patched(
|
|
345
|
+
self,
|
|
346
|
+
model_inputs: Tuple[torch.Tensor, Optional[torch.Tensor]],
|
|
347
|
+
layer_activation: Optional[torch.Tensor] = None,
|
|
348
|
+
cavs_list: Optional[List[torch.Tensor]] = None,
|
|
349
|
+
mute_x_avs: bool = False,
|
|
350
|
+
mute_remainder: bool = True,
|
|
351
|
+
) -> torch.Tensor:
|
|
352
|
+
"""
|
|
353
|
+
Full forward pass with optional activation replacement and/or CAV-based gradient muting.
|
|
354
|
+
"""
|
|
355
|
+
name = self.layer_name
|
|
356
|
+
if name is None:
|
|
357
|
+
raise ValueError("layer_name must be set on TPCAV to use forward_patched.")
|
|
358
|
+
layer = self._resolve_layer(name)
|
|
359
|
+
|
|
360
|
+
def hook_fn(_module, _inputs, output):
|
|
361
|
+
y = layer_activation if layer_activation is not None else output
|
|
362
|
+
if cavs_list is None or len(cavs_list) == 0:
|
|
363
|
+
return y
|
|
364
|
+
return self._disentangle_with_cavs(
|
|
365
|
+
y, cavs_list, mute_x_avs=mute_x_avs, mute_remainder=mute_remainder
|
|
366
|
+
)
|
|
367
|
+
|
|
368
|
+
handle = layer.register_forward_hook(hook_fn)
|
|
369
|
+
try:
|
|
370
|
+
output = self.model(*[i.to(self.device) for i in model_inputs])
|
|
371
|
+
finally:
|
|
372
|
+
handle.remove()
|
|
373
|
+
return output
|
|
374
|
+
|
|
375
|
+
    def _disentangle_with_cavs(
        self,
        layer_output: torch.Tensor,
        cavs_list: List[torch.Tensor],
        mute_x_avs: bool = False,
        mute_remainder: bool = True,
    ) -> torch.Tensor:
        """
        Project activations into CAV subspace and optionally zero gradients for
        subspaces (default mutes orthogonal/remainder gradients).

        The forward value is unchanged (x_avs + remainder reconstructs the
        full embedding); only the backward pass is altered via gradient
        hooks on the two components.

        NOTE(review): ``torch.cat([y_projected, y_residual], ...)`` assumes
        a PCA basis is attached; with ``pca_inv is None`` the projection is
        None and this would fail — confirm callers never hit that case.
        """
        y = layer_output.flatten(start_dim=1)
        y_residual, y_projected = self.project_activations(y)
        # Full embedding = [projected | residual] in PCA coordinates.
        y_pca_all = torch.cat([y_projected, y_residual], dim=1)

        cavs_matrix = torch.stack(cavs_list, dim=1).to(y.device)  # [dims, #cavs]

        if cavs_matrix.shape[1] > cavs_matrix.shape[0]:
            logger.warning(
                "CAV matrix has more columns than rows; remainder should be near zero."
            )

        # Orthonormalize the CAVs (QR), keeping only rank-many columns so
        # the projector below is well defined even for dependent CAVs.
        mrank = torch.linalg.matrix_rank(cavs_matrix)
        cavs_ortho = torch.linalg.qr(cavs_matrix, mode="reduced").Q[:, :mrank].detach()
        if not torch.allclose(
            cavs_ortho.T @ cavs_ortho,
            torch.eye(mrank, device=cavs_ortho.device),
            atol=1e-3,
            rtol=1e-3,
        ):
            logger.warning("Q^TQ not identity; check CAV matrix conditioning.")

        # Orthogonal decomposition: component inside the CAV span + the rest.
        y_pca_x_avs = y_pca_all @ cavs_ortho @ cavs_ortho.T
        y_pca_remainder = y_pca_all - y_pca_x_avs

        # Gradient muting: zero the backward signal through selected parts.
        if mute_x_avs:
            y_pca_x_avs.register_hook(lambda grad: torch.zeros_like(grad))
        if mute_remainder:
            y_pca_remainder.register_hook(lambda grad: torch.zeros_like(grad))

        y_pca_hat = y_pca_x_avs + y_pca_remainder

        dim_projected = y_projected.shape[1] if y_projected is not None else 0

        # Split back into (residual, projected) and rebuild the activation.
        return self.embedding_to_layer_activation(
            y_pca_hat[:, dim_projected:], y_pca_hat[:, :dim_projected]
        )
|
|
422
|
+
|
|
423
|
+
def _set_buffer(self, name: str, value: Optional[torch.Tensor]) -> None:
|
|
424
|
+
if hasattr(self.model, "_buffers") and name in self.model._buffers:
|
|
425
|
+
self._buffers[name] = value # type: ignore[index]
|
|
426
|
+
else:
|
|
427
|
+
self.register_buffer(name, value)
|