tpcav 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tpcav/__init__.py +39 -0
- tpcav/cavs.py +334 -0
- tpcav/concepts.py +309 -0
- tpcav/helper.py +165 -0
- tpcav/logging_utils.py +10 -0
- tpcav/tpcav_model.py +427 -0
- tpcav/utils.py +601 -0
- tpcav-0.1.0.dist-info/METADATA +89 -0
- tpcav-0.1.0.dist-info/RECORD +12 -0
- tpcav-0.1.0.dist-info/WHEEL +5 -0
- tpcav-0.1.0.dist-info/licenses/LICENSE +21 -0
- tpcav-0.1.0.dist-info/top_level.txt +1 -0
tpcav/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Lightweight, reusable TCAV utilities built from the repository scripts.
|
|
3
|
+
|
|
4
|
+
The package keeps existing scripts untouched while offering programmatic
|
|
5
|
+
access to concept construction and PCA/attribution workflows.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
# Set the logging level to INFO
|
|
11
|
+
logging.basicConfig(level=logging.INFO)
|
|
12
|
+
|
|
13
|
+
from .cavs import CavTrainer
|
|
14
|
+
from .concepts import ConceptBuilder
|
|
15
|
+
from .helper import (
|
|
16
|
+
bed_to_chrom_tracks_iter,
|
|
17
|
+
bed_to_fasta_iter,
|
|
18
|
+
dataframe_to_chrom_tracks_iter,
|
|
19
|
+
dataframe_to_fasta_iter,
|
|
20
|
+
dinuc_shuffle_sequences,
|
|
21
|
+
fasta_to_one_hot_sequences,
|
|
22
|
+
random_regions_dataframe,
|
|
23
|
+
)
|
|
24
|
+
from .logging_utils import set_verbose
|
|
25
|
+
from .tpcav_model import TPCAV
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
"ConceptBuilder",
|
|
29
|
+
"CavTrainer",
|
|
30
|
+
"TPCAV",
|
|
31
|
+
"bed_to_fasta_iter",
|
|
32
|
+
"dataframe_to_fasta_iter",
|
|
33
|
+
"bed_to_chrom_tracks_iter",
|
|
34
|
+
"dataframe_to_chrom_tracks_iter",
|
|
35
|
+
"fasta_to_one_hot_sequences",
|
|
36
|
+
"random_regions_dataframe",
|
|
37
|
+
"dinuc_shuffle_sequences",
|
|
38
|
+
"set_verbose",
|
|
39
|
+
]
|
tpcav/cavs.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
CAV training and attribution utilities built on TPCAV.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import multiprocessing
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Iterable, List, Optional, Tuple
|
|
10
|
+
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import numpy as np
|
|
13
|
+
import seaborn as sns
|
|
14
|
+
import torch
|
|
15
|
+
from sklearn.linear_model import SGDClassifier
|
|
16
|
+
from sklearn.metrics import precision_recall_fscore_support
|
|
17
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
|
18
|
+
from sklearn.model_selection import GridSearchCV
|
|
19
|
+
from torch.utils.data import DataLoader, TensorDataset, random_split
|
|
20
|
+
|
|
21
|
+
from tpcav.tpcav_model import TPCAV
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _load_all_tensors_to_numpy(dataloaders: Iterable[DataLoader]):
|
|
27
|
+
if not isinstance(dataloaders, list):
|
|
28
|
+
dataloaders = [dataloaders]
|
|
29
|
+
avs, ls = [], []
|
|
30
|
+
for dataloader in dataloaders:
|
|
31
|
+
for av, l in dataloader:
|
|
32
|
+
avs.append(av.cpu().numpy())
|
|
33
|
+
ls.append(l.cpu().numpy())
|
|
34
|
+
return np.concatenate(avs), np.concatenate(ls)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _SGDWrapper:
    """Lightweight SGD concept classifier.

    Wraps an ``SGDClassifier`` in a small grid search over the ``alpha``
    regularization strength and exposes the fitted hyperplane as a torch
    tensor via :attr:`weights`.
    """

    # Per-penalty hyperparameter grid for the alpha search.
    _ALPHA_GRIDS = {
        "l2": [1e-2, 1e-4, 1e-6],
        "l1": [1e-1, 1],
    }

    def __init__(self, penalty: str = "l2", n_jobs: int = -1):
        if penalty not in self._ALPHA_GRIDS:
            raise ValueError(f"Unexpected penalty type {penalty}")
        self.lm = SGDClassifier(
            max_iter=1000,
            early_stopping=True,
            validation_fraction=0.1,
            learning_rate="optimal",
            n_iter_no_change=10,
            n_jobs=n_jobs,
            penalty=penalty,
        )
        self.search = GridSearchCV(self.lm, {"alpha": self._ALPHA_GRIDS[penalty]})

    def fit(self, train_dl: DataLoader, val_dl: DataLoader):
        """Fit on the union of train and validation batches.

        Both loaders are merged because the SGD classifier's own
        ``early_stopping`` carves out an internal validation fraction.
        """
        features, targets = _load_all_tensors_to_numpy([train_dl, val_dl])
        self.search.fit(features, targets)
        self.lm = self.search.best_estimator_
        logger.info(
            "Best Params: %s | Iterations: %s",
            self.search.best_params_,
            self.lm.n_iter_,
        )

    @property
    def weights(self) -> torch.Tensor:
        """Coefficient matrix with one row per class.

        Binary sklearn models store a single coefficient row; it is mirrored
        so that row 0 is the direction of class 0 and row 1 of class 1.
        """
        coefs = self.lm.coef_
        if len(coefs) != 1:
            return torch.tensor(coefs)
        return torch.tensor(np.array([-1 * coefs[0], coefs[0]]))

    @property
    def classes_(self):
        """Class labels seen during fit (delegated to the estimator)."""
        return self.lm.classes_

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Predict class labels for feature matrix ``x``."""
        return self.lm.predict(x)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _train(
    concept_embeddings: torch.Tensor,
    control_embeddings: torch.Tensor,
    output_dir: str,
    penalty: str = "l2",
) -> Tuple[float, torch.Tensor]:
    """
    Train a binary CAV classifier for a concept vs cached control embeddings.

    Concept embeddings are labeled 0 and control embeddings 1. Each set is
    split 80/10/10 into train/val/test before the concept and control halves
    of each split are merged. Per-split accuracy reports and the classifier
    weight matrix are written under ``output_dir``.

    Args:
        concept_embeddings: concept-example embeddings, shape (N, D).
        control_embeddings: control-example embeddings, shape (M, D).
        output_dir: directory for the accuracy text files and
            ``classifier_weights.pt``; created if missing.
        penalty: regularizer for the SGD classifier, "l2" or "l1".

    Returns:
        Tuple of (test-set F-score, CAV vector). The CAV is ``weights[0]``,
        the hyperplane row associated with the concept class (label 0).
    """
    output_dir = Path(output_dir)

    # Concept examples -> label 0, control examples -> label 1.
    avd = TensorDataset(
        concept_embeddings, torch.full((concept_embeddings.shape[0],), 0)
    )
    cvd = TensorDataset(
        control_embeddings, torch.full((control_embeddings.shape[0],), 1)
    )
    train_ds, val_ds, test_ds = random_split(avd, [0.8, 0.1, 0.1])
    c_train, c_val, c_test = random_split(cvd, [0.8, 0.1, 0.1])

    # Dataset.__add__ builds a ConcatDataset, merging concept + control halves.
    train_dl = DataLoader(train_ds + c_train, batch_size=32, shuffle=True)
    val_dl = DataLoader(val_ds + c_val, batch_size=32)
    test_dl = DataLoader(test_ds + c_test, batch_size=32)

    clf = _SGDWrapper(penalty=penalty)
    clf.fit(train_dl, val_dl)

    def _eval(split_dl: DataLoader, name: str):
        # Evaluate the fitted classifier on one split and return the binary
        # F-score. NOTE(review): pos_label=1 makes the *control* class the
        # positive class for precision/recall/F-score — confirm intended.
        y_preds, y_trues = [], []
        for x, y in split_dl:
            y_pred = clf.predict(x.cpu().numpy())
            y_preds.append(y_pred)
            y_trues.append(y.cpu().numpy())
        y_preds = np.concatenate(y_preds)
        y_trues = np.concatenate(y_trues)
        acc = (y_preds == y_trues).sum() / len(y_trues)
        precision, recall, fscore, support = precision_recall_fscore_support(
            y_trues, y_preds, average="binary", pos_label=1
        )
        logger.info("[%s] Accuracy: %.4f", name, acc)
        # Only accuracy is persisted; precision/recall/support are computed
        # but not written out.
        (output_dir / f"classifier_perform_on_{name}.txt").write_text(
            f"Accuracy: {acc}\n"
        )
        return fscore

    output_dir.mkdir(parents=True, exist_ok=True)
    _eval(train_dl, "train")
    _eval(val_dl, "val")
    test_fscore = _eval(test_dl, "test")

    # weights has one row per class (see _SGDWrapper.weights); row 0 is the
    # concept-class direction and is returned as the CAV.
    weights = clf.weights
    assert len(weights.shape) == 2 and weights.shape[0] == 2
    torch.save(weights, output_dir / "classifier_weights.pt")

    return test_fscore, weights[0]
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class CavTrainer:
    """Train CAVs and compute attribution-driven TCAV scores.

    Workflow: call :meth:`set_control` once to cache control embeddings,
    then :meth:`train_concepts` to fit one CAV per concept; scores and
    plots are derived from the stored per-concept weights and F-scores.
    """

    def __init__(self, tpcav: TPCAV, penalty: str = "l2") -> None:
        self.tpcav = tpcav
        self.penalty = penalty
        # concept name -> test-set F-score of its CAV classifier
        self.cavs_fscores = {}
        # concept name -> CAV weight vector (concept-class direction)
        self.cav_weights = {}
        self.control_embeddings: Optional[torch.Tensor] = None
        # CAVs in training order (parallel to insertion into cav_weights)
        self.cavs_list: List[torch.Tensor] = []

    def save_state(self, output_path: str = "cav_trainer_state.pt"):
        """
        Save CavTrainer state to a file.

        Persists F-scores, CAV weights, cached control embeddings and the
        ordered CAV list; the TPCAV model itself is NOT saved.
        """
        state = {
            "cavs_fscores": self.cavs_fscores,
            "cav_weights": self.cav_weights,
            "control_embeddings": self.control_embeddings,
            "cavs_list": self.cavs_list,
        }
        torch.save(state, output_path)

    def restore_state(self, input_path: str = "cav_trainer_state.pt"):
        """
        Restore CavTrainer state from a file.

        Counterpart of :meth:`save_state`; tensors are mapped to CPU.
        NOTE(review): torch.load unpickles arbitrary objects — only restore
        files you trust.
        """
        state = torch.load(input_path, map_location="cpu")
        self.cavs_fscores = state["cavs_fscores"]
        self.cav_weights = state["cav_weights"]
        self.control_embeddings = state["control_embeddings"]
        self.cavs_list = state["cavs_list"]

    def set_control(self, control_concept, num_samples: int) -> torch.Tensor:
        """
        Set and cache control embeddings to avoid recomputation across CAV trainings.
        """
        self.control_embeddings = self.tpcav.concept_embeddings(
            control_concept, num_samples=num_samples
        )
        return self.control_embeddings

    def train_concepts(
        self,
        concept_list,
        num_samples: int,
        output_dir: str,
        num_processes: int = 1,
    ):
        """Train one CAV per concept against the cached control embeddings.

        Embeddings are always computed sequentially in this process; with
        ``num_processes > 1`` only the classifier fits run in a worker pool.
        Results populate ``cavs_fscores``, ``cav_weights`` and ``cavs_list``.

        Raises:
            ValueError: if :meth:`set_control` has not been called.
        """
        if self.control_embeddings is None:
            raise ValueError(
                "Call set_control(control_concept, num_samples=...) before training CAVs."
            )

        if num_processes == 1:
            for c in concept_list:
                concept_embeddings = self.tpcav.concept_embeddings(
                    c, num_samples=num_samples
                )
                fscore, weight = _train(
                    concept_embeddings,
                    self.control_embeddings,
                    Path(output_dir) / c.name,
                    self.penalty,
                )
                self.cavs_fscores[c.name] = fscore
                self.cav_weights[c.name] = weight
                self.cavs_list.append(weight)
        else:
            pool = multiprocessing.Pool(processes=num_processes)
            results = []
            for c in concept_list:
                # Embedding extraction stays in the parent (it may use the
                # model/GPU); only _train is dispatched to workers.
                concept_embeddings = self.tpcav.concept_embeddings(
                    c, num_samples=num_samples
                )
                res = pool.apply_async(
                    _train,
                    args=(
                        concept_embeddings,
                        self.control_embeddings,
                        Path(output_dir) / c.name,
                        self.penalty,
                    ),
                )
                logger.info("Submitted CAV training for concept %s", c.name)
                results.append((c.name, res))
            pool.close()
            pool.join()
            # .get() re-raises any worker exception here.
            results = [(name, res.get()) for name, res in results]
            for name, (fscore, weight) in results:
                self.cavs_fscores[name] = fscore
                self.cav_weights[name] = weight
                self.cavs_list.append(weight)

    def tpcav_score(
        self, concept_name: str, attributions: torch.Tensor
    ) -> torch.Tensor:
        """
        Per-sample directional attribution along the concept CAV.

        Each attribution row is flattened and projected onto the stored CAV;
        no averaging is performed — one score per sample is returned.

        Raises:
            ValueError: if no CAV has been trained for ``concept_name``.
        """
        if concept_name not in self.cav_weights:
            raise ValueError(f"No CAV weights stored for concept {concept_name}")
        weights = self.cav_weights[concept_name]
        flat_attr = attributions.flatten(start_dim=1)
        scores = torch.matmul(flat_attr, weights.to(flat_attr.device).unsqueeze(-1))

        return scores

    def tpcav_score_binary_log_ratio(
        self, concept_name: str, attributions: torch.Tensor, pseudocount: float = 1.0
    ) -> float:
        """
        Natural-log ratio of positive to negative directional attributions.

        ``pseudocount`` is added to both counts to avoid division by zero.
        NOTE(review): uses np.log (base e), not log2 as an earlier comment
        suggested — confirm which base downstream consumers expect.
        """
        scores = self.tpcav_score(concept_name, attributions)

        pos_count = (scores > 0).sum().item()
        neg_count = (scores < 0).sum().item()

        return np.log((pos_count + pseudocount) / (neg_count + pseudocount))

    # NOTE(review): method name contains a typo ("similaritiy"); kept as-is
    # because renaming would break existing callers.
    def plot_cavs_similaritiy_heatmap(
        self,
        attributions: torch.Tensor,
        concept_list: List[str] | None = None,
        fscore_thresh=0.8,
        output_path: str = "cavs_similarity_heatmap.png",
    ):
        """Clustermap of CAV cosine similarities plus a TCAV log-ratio bar plot.

        Only CAVs whose classifier F-score meets ``fscore_thresh`` are shown;
        the figure is saved to ``output_path``. Returns None (early-exits with
        a warning if no CAV passes the threshold).
        """
        if concept_list is None:
            cavs_names = list(self.cav_weights.keys())
        else:
            cavs_names = concept_list
        # Filter CAVs by classifier quality.
        cavs_pass = []
        cavs_names_pass = []
        for cname in cavs_names:
            if self.cavs_fscores[cname] >= fscore_thresh:
                cavs_pass.append(self.cav_weights[cname])
                cavs_names_pass.append(cname)
            else:
                logger.info(
                    "Skipping CAV %s with F-score %.3f below threshold %.3f",
                    cname,
                    self.cavs_fscores[cname],
                    fscore_thresh,
                )
        if len(cavs_pass) == 0:
            logger.warning(f"No CAVs passed the F-score threshold {fscore_thresh:.3f}.")
            return

        # compute similarity matrix
        matrix_similarity = cosine_similarity(cavs_pass)

        # plot
        cm = sns.clustermap(
            matrix_similarity,
            xticklabels=False,
            yticklabels=False,
            cmap="bwr",
            vmin=-1,
            vmax=1,
        )
        # Shrink the heatmap to the left half of the figure to make room for
        # the bar plot added below.
        cm.gs.update(left=0.05, right=0.5)
        cm.ax_cbar.set_position([0.01, 0.9, 0.05, 0.05])

        # Reorder names to match the clustermap's column dendrogram order.
        cavs_names_sorted = [
            cavs_names_pass[i] for i in cm.dendrogram_col.reordered_ind
        ]

        ## plot log ratio plot
        ax_log = cm.figure.add_subplot()
        heatmap_bbox = cm.ax_heatmap.get_position()
        ax_log.set_position([0.5, heatmap_bbox.y0, 0.2, heatmap_bbox.height])
        # used to leave space for motif logos
        # ax_log.tick_params(
        #     axis="y", which="major", pad=cm.figure.get_size_inches()[0] * 0.2 * 72
        # )

        log_ratios_reordered = [
            self.tpcav_score_binary_log_ratio(cname, attributions)
            for cname in cavs_names_sorted
        ]
        sns.barplot(y=cavs_names_sorted, x=log_ratios_reordered, orient="y", ax=ax_log)
        # set color of bar by value
        for idx in range(len((ax_log.containers[0]))):
            if ax_log.containers[0].datavalues[idx] > 0:
                ax_log.containers[0][idx].set_color("red")
            else:
                ax_log.containers[0][idx].set_color("blue")

        ax_log.set_xlim(left=-5, right=5)
        ax_log.yaxis.tick_right()
        ax_log.set_title("TCAV log ratio")

        plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
tpcav/concepts.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import seqchromloader as scl
|
|
8
|
+
import webdataset as wds
|
|
9
|
+
from Bio import motifs as Bio_motifs
|
|
10
|
+
from captum.concept import Concept
|
|
11
|
+
from torch.utils.data import DataLoader
|
|
12
|
+
|
|
13
|
+
from . import helper, utils
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class _PairedLoader:
|
|
19
|
+
"""Allow repeated iteration over paired dataloaders."""
|
|
20
|
+
|
|
21
|
+
def __init__(self, seq_dl: Iterable, chrom_dl: Iterable) -> None:
|
|
22
|
+
self.seq_dl = seq_dl
|
|
23
|
+
self.chrom_dl = chrom_dl
|
|
24
|
+
self.apply_func = None
|
|
25
|
+
|
|
26
|
+
def apply(self, apply_func):
|
|
27
|
+
self.apply_func = apply_func
|
|
28
|
+
|
|
29
|
+
def __iter__(self):
|
|
30
|
+
for inputs in zip(self.seq_dl, self.chrom_dl):
|
|
31
|
+
if self.apply_func:
|
|
32
|
+
inputs = self.apply_func(*inputs)
|
|
33
|
+
yield inputs
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _construct_motif_concept_dataloader_from_control(
    control_seq_df: pd.DataFrame,
    genome_fasta: str,
    motifs: Sequence,
    num_motifs: int,
    motif_mode: str,
    batch_size: int,
    num_workers: int,
) -> DataLoader:
    """Mirror the motif-based dataloader logic used in the TCAV script.

    Builds one finite ``IterateSeqDataFrame`` per motif over the control
    regions (motif instances implanted per ``motif_mode``/``num_motifs``),
    randomly mixes them, and wraps the mix in a single DataLoader.
    """
    per_motif_datasets = [
        utils.IterateSeqDataFrame(
            control_seq_df,
            genome_fasta,
            motif=single_motif,
            motif_mode=motif_mode,
            num_motifs=num_motifs,
            start_buffer=0,
            end_buffer=0,
            print_warning=False,
            infinite=False,
        )
        for single_motif in motifs
    ]

    return DataLoader(
        wds.RandomMix(per_motif_datasets),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ConceptBuilder:
    """Build and store concepts/control concepts in a reusable, programmatic way.

    Usage: call :meth:`build_control` first (it defines the background
    regions every other builder reuses), then add concepts from motif
    tables, MEME files, or BED files. :meth:`all_concepts` returns the
    combined list for downstream TCAV training.
    """

    def __init__(
        self,
        genome_fasta: str,
        genome_size_file: str,
        input_window_length: int = 1024,
        bws: Optional[List[str]] = None,
        batch_size: int = 8,
        num_workers: int = 0,
        num_motifs: int = 12,
        include_reverse_complement: bool = False,
        min_samples: int = 5000,
        rng_seed: int = 1001,
    ) -> None:
        """
        Args:
            genome_fasta: path to the reference genome FASTA.
            genome_size_file: chromosome-sizes file used to draw random regions.
            input_window_length: length of each sampled genomic window.
            bws: bigwig paths for chromatin tracks (defaults to []).
            batch_size: batch size for all generated dataloaders.
            num_workers: DataLoader worker count for motif dataloaders.
            num_motifs: motif instances implanted per sequence.
            include_reverse_complement: also implant reverse-complement motifs.
            min_samples: samples drawn per concept; concepts with fewer
                candidate regions are skipped.
            rng_seed: random_state used when subsampling BED concepts.
        """
        self.genome_fasta = genome_fasta
        self.genome_size_file = genome_size_file
        self.input_window_length = input_window_length
        self.bws = bws or []
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_motifs = num_motifs
        self.include_reverse_complement = include_reverse_complement
        self.min_samples = min_samples
        self.rng_seed = rng_seed

        # Filled by build_control; other builders reuse these regions.
        self.control_regions: pd.DataFrame | None = None
        self.control_concepts: List[Concept] = []
        self.concepts: List[Concept] = []
        self.metadata: Dict[str, object] = {}
        # Monotonic counter backing _reserve_id.
        self._next_concept_id = 0
        self._control_seq_loader: Optional[Iterable] = None
        self._control_chrom_loader: Optional[Iterable] = None

    def build_control(self, name: str = "random_regions") -> Concept:
        """Create the background/control concept.

        Draws ``min_samples`` random windows of ``input_window_length`` from
        the genome, stores them for reuse, and replaces any previously built
        control concept.
        """
        control_regions = scl.random_coords(
            gs=self.genome_size_file,
            l=self.input_window_length,
            n=self.min_samples,
        )
        control_regions["label"] = -1
        control_regions["strand"] = "+"
        self.control_regions = control_regions

        concept = Concept(
            id=self._reserve_id(is_control=True),
            name=name,
            data_iter=_PairedLoader(self._control_seq_dl(), self._control_chrom_dl()),
        )
        self.control_concepts = [concept]
        self.metadata["control_regions"] = control_regions
        return concept

    def _control_seq_dl(self):
        # Sequence iterator over the cached control regions.
        if self.control_regions is None:
            raise ValueError("Call build_control before creating control regions.")
        seq_fasta_iter = helper.DataFrame2FastaIterator(
            self.control_regions, self.genome_fasta, batch_size=self.batch_size
        )
        return seq_fasta_iter

    def _control_chrom_dl(self):
        # Chromatin-track iterator over the cached control regions.
        if self.control_regions is None:
            raise ValueError("Call build_control before creating control regions.")
        chrom_iter = helper.DataFrame2ChromTracksIterator(
            self.control_regions,
            self.bws,
            batch_size=self.batch_size,
        )
        return chrom_iter

    def add_custom_motif_concepts(
        self, motif_table: str, control_regions: Optional[pd.DataFrame] = None
    ) -> List[Concept]:
        """Add concepts from a tab-delimited motif table: motif_name<TAB>consensus.

        Rows sharing a motif_name are grouped into one concept whose
        sequences are control regions with the consensus motifs implanted.
        Returns the concepts added by this call.
        """
        if control_regions is None:
            if not self.control_concepts:
                raise ValueError("Call build_control or pass control_regions first.")
            control_regions = self.metadata.get("control_regions")
            assert control_regions is not None
        df = pd.read_table(motif_table, names=["motif_name", "consensus_seq"])
        added: List[Concept] = []
        for motif_name in np.unique(df.motif_name):
            consensus = df.loc[df.motif_name == motif_name, "consensus_seq"].tolist()
            motifs = []
            for idx, cons in enumerate(consensus):
                motif = utils.CustomMotif(f"{motif_name}_{idx}", cons)
                motifs.append(motif)
                if self.include_reverse_complement:
                    motifs.append(motif.reverse_complement())
            seq_dl = _construct_motif_concept_dataloader_from_control(
                control_regions,
                self.genome_fasta,
                motifs=motifs,
                num_motifs=self.num_motifs,
                motif_mode="consensus",
                batch_size=self.batch_size,
                num_workers=self.num_workers,
            )
            concept = Concept(
                id=self._reserve_id(),
                name=motif_name,
                data_iter=_PairedLoader(seq_dl, self._control_chrom_dl()),
            )
            self.concepts.append(concept)
            added.append(concept)
        return added

    def add_meme_motif_concepts(
        self, meme_file: str, control_regions: Optional[pd.DataFrame] = None
    ) -> List[Concept]:
        """Add concepts from a MEME minimal-format motif file.

        One concept per motif; sequences are control regions with motif
        instances sampled from the PWM implanted. Returns the concepts
        added by this call.
        """
        if control_regions is None:
            if not self.control_concepts:
                raise ValueError("Call build_control or pass control_regions first.")
            control_regions = self.metadata.get("control_regions")
            assert control_regions is not None

        added: List[Concept] = []
        with open(meme_file) as handle:
            for motif in Bio_motifs.parse(handle, fmt="MINIMAL"):
                motifs = [motif]
                if self.include_reverse_complement:
                    motifs.append(motif.reverse_complement())
                # Motif names may contain "/", which would break output paths.
                motif_name = motif.name.replace("/", "-")
                seq_dl = _construct_motif_concept_dataloader_from_control(
                    control_regions,
                    self.genome_fasta,
                    motifs=motifs,
                    num_motifs=self.num_motifs,
                    motif_mode="pwm",
                    batch_size=self.batch_size,
                    num_workers=self.num_workers,
                )
                concept = Concept(
                    id=self._reserve_id(),
                    name=motif_name,
                    data_iter=_PairedLoader(seq_dl, self._control_chrom_dl()),
                )
                self.concepts.append(concept)
                added.append(concept)
        return added

    def add_bed_sequence_concepts(self, bed_paths: Iterable[str]) -> List[Concept]:
        """Add concepts backed by BED sequences with concept_name in column 5.

        Each concept pairs its own sequences with the control chromatin
        tracks. NOTE(review): column 3 is read as "strand" and column 4 as
        "concept_name" — nonstandard BED (column 3 is usually "name");
        confirm the expected input format.
        """
        added: List[Concept] = []
        for bed in bed_paths:
            bed_df = pd.read_table(
                bed,
                header=None,
                usecols=[0, 1, 2, 3, 4],
                names=["chrom", "start", "end", "strand", "concept_name"],
            )
            for concept_name in bed_df.concept_name.unique():
                concept_df = bed_df.loc[bed_df.concept_name == concept_name]
                if len(concept_df) < self.min_samples:
                    logger.warning(
                        "Concept %s has %s samples, fewer than min_samples=%s; skipping",
                        concept_name,
                        len(concept_df),
                        self.min_samples,
                    )
                    continue
                seq_fasta_iter = helper.dataframe_to_fasta_iter(
                    concept_df.sample(n=self.min_samples, random_state=self.rng_seed),
                    self.genome_fasta,
                    batch_size=self.batch_size,
                )
                concept = Concept(
                    id=self._reserve_id(),
                    name=concept_name,
                    data_iter=_PairedLoader(seq_fasta_iter, self._control_chrom_dl()),
                )
                self.concepts.append(concept)
                added.append(concept)
        return added

    def add_bed_chrom_concepts(self, bed_paths: Iterable[str]) -> List[Concept]:
        """Add concepts backed by chromatin signal bigwigs and BED coordinates.

        Mirror image of :meth:`add_bed_sequence_concepts`: each concept pairs
        control sequences with its own chromatin tracks. The same
        nonstandard column mapping applies (see note there).
        """
        added: List[Concept] = []
        for bed in bed_paths:
            bed_df = pd.read_table(
                bed,
                header=None,
                usecols=[0, 1, 2, 3, 4],
                names=["chrom", "start", "end", "strand", "concept_name"],
            )
            for concept_name in bed_df.concept_name.unique():
                concept_df = bed_df.loc[bed_df.concept_name == concept_name]
                if len(concept_df) < self.min_samples:
                    logger.warning(
                        "Concept %s has %s samples, fewer than min_samples=%s; skipping",
                        concept_name,
                        len(concept_df),
                        self.min_samples,
                    )
                    continue
                chrom_dl = helper.dataframe_to_chrom_tracks_iter(
                    concept_df.sample(n=self.min_samples, random_state=self.rng_seed),
                    self.genome_fasta,
                    self.bws,
                    batch_size=self.batch_size,
                )
                concept = Concept(
                    id=self._reserve_id(),
                    name=concept_name,
                    data_iter=_PairedLoader(self._control_seq_dl(), chrom_dl),
                )
                self.concepts.append(concept)
                added.append(concept)
        return added

    def all_concepts(self) -> List[Concept]:
        """Return test + control concepts."""
        return [*self.concepts, *self.control_concepts]

    def summary(self) -> Dict[str, object]:
        """Lightweight run metadata."""
        return {
            "num_concepts": len(self.concepts),
            "num_control": len(self.control_concepts),
            "concept_names": [c.name for c in self.concepts],
            "control_names": [c.name for c in self.control_concepts],
            "input_window_length": self.input_window_length,
            "num_motifs": self.num_motifs,
            "bigwigs": self.bws,
        }

    def _reserve_id(self, is_control: bool = False) -> int:
        # Controls get negated ids; regular concepts get positive ones.
        # NOTE(review): on the very first reservation -0 == 0, so a control
        # created first receives id 0, not a negative id — ids stay unique
        # because the counter still advances, but confirm this is intended.
        cid = -self._next_concept_id if is_control else self._next_concept_id
        self._next_concept_id += 1
        return cid

    def apply_transform(self, transform):
        """Apply a transform function to all concepts"""
        for c in self.all_concepts():
            c.data_iter.apply(transform)
|