tpcav 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tpcav/__init__.py +39 -0
- tpcav/cavs.py +334 -0
- tpcav/concepts.py +309 -0
- tpcav/helper.py +165 -0
- tpcav/logging_utils.py +10 -0
- tpcav/tpcav_model.py +427 -0
- tpcav/utils.py +601 -0
- tpcav-0.1.0.dist-info/METADATA +89 -0
- tpcav-0.1.0.dist-info/RECORD +12 -0
- tpcav-0.1.0.dist-info/WHEEL +5 -0
- tpcav-0.1.0.dist-info/licenses/LICENSE +21 -0
- tpcav-0.1.0.dist-info/top_level.txt +1 -0
tpcav/utils.py
ADDED
|
@@ -0,0 +1,601 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Copied utility helpers from scripts/utils.py so the package can run standalone.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
|
|
9
|
+
import Bio
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import pyfaidx
|
|
13
|
+
import seqchromloader as scl
|
|
14
|
+
import torch
|
|
15
|
+
from Bio import SeqIO
|
|
16
|
+
from pyfaidx import Fasta
|
|
17
|
+
from torch.utils.data import default_collate, get_worker_info
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def center_windows(df, window_len=1024):
    """Recenter each interval in *df* to a fixed-size window.

    Each (start, end) pair is replaced by a window of ``window_len`` bp
    centered on the interval midpoint. The strand column is preserved
    when present.
    """
    half = int(window_len / 2)
    midpoints = ((df["start"] + df["end"]) / 2).astype(int)
    windows = df.assign(mid=midpoints)
    windows = windows.assign(
        start=windows["mid"] - half,
        end=windows["mid"] + half,
    )
    keep = ["chrom", "start", "end"]
    if "strand" in windows.columns:
        keep.append("strand")
    return windows[keep]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def collate_seq(batch):
    """Collate (seq, chrom, target, label) tuples and return only the sequences."""
    sequences, _chromatin, _targets, _labels = default_collate(batch)
    return sequences
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def collate_chrom(batch):
    """Collate (seq, chrom, target, label) tuples and return only the chromatin tensor."""
    _sequences, chromatin, _targets, _labels = default_collate(batch)
    return chromatin
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def collate_seqchrom(batch):
    """Collate (seq, chrom, target, label) tuples and return the (seq, chrom) pair."""
    sequences, chromatin, _targets, _labels = default_collate(batch)
    return sequences, chromatin
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def seq_dataloader_from_bed(
    seq_bed, genome_fasta, window_len=1024, batch_size=8, num_workers=0
):
    """Build a sequence-only dataloader from a BED file of regions.

    Reads the first three columns of ``seq_bed`` (chrom, start, end) and
    delegates to :func:`seq_dataloader_from_dataframe`.
    """
    regions = pd.read_table(
        seq_bed,
        header=None,
        usecols=[0, 1, 2],
        names=["chrom", "start", "end"],
    )
    return seq_dataloader_from_dataframe(
        regions, genome_fasta, window_len, batch_size, num_workers
    )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def seq_dataloader_from_dataframe(
    seq_df, genome_fasta, window_len=1024, batch_size=8, num_workers=0
):
    """Build a sequence-only dataloader from a coordinate DataFrame.

    Windows are recentered to ``window_len`` bp, given a dummy label, and
    filtered to chromosomes present in the genome FASTA. The returned
    loader yields only the one-hot sequence tensor (via ``collate_seq``).
    """
    windows = center_windows(seq_df, window_len=window_len)
    # seqchromloader expects a label column; -1 marks "unlabeled"
    windows["label"] = -1
    if "strand" not in windows.columns:
        windows["strand"] = "+"
    windows = scl.filter_chromosomes(windows, to_keep=Fasta(genome_fasta).keys())
    loader_opts = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "collate_fn": collate_seq,
        "drop_last": True,
    }
    return scl.SeqChromDatasetByDataFrame(
        windows,
        genome_fasta=genome_fasta,
        dataloader_kws=loader_opts,
    )
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def seq_dataloader_from_fa(seq_fa, input_window_length=1024, batch_size=8):
    """Yield batches of one-hot encoded sequences parsed from a FASTA file.

    Parameters
    ----------
    seq_fa : str
        Path to a FASTA file whose records are all exactly
        ``input_window_length`` bp long.
    input_window_length : int, default 1024
        Expected length of every record.
    batch_size : int, default 8
        Number of sequences per yielded tensor; the final batch may be
        smaller.

    Yields
    ------
    torch.Tensor
        Stacked one-hot tensors of shape (batch, 4, input_window_length).

    Raises
    ------
    ValueError
        If any record's length differs from ``input_window_length``.
    """
    with open(seq_fa) as handle:
        dnaSeqs = []
        for record in SeqIO.parse(handle, "fasta"):
            if len(record.seq) != input_window_length:
                # ValueError is more specific than the bare Exception raised
                # previously; callers catching Exception still catch it.
                raise ValueError(
                    f"Sequence length {len(record.seq)} != input_window_length {input_window_length}"
                )
            dnaSeqs.append(torch.tensor(scl.dna2OneHot(record.seq)))
            if len(dnaSeqs) >= batch_size:
                yield torch.stack(dnaSeqs)
                dnaSeqs = []

        # flush the final, possibly short, batch
        if len(dnaSeqs) > 0:
            yield torch.stack(dnaSeqs)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def chrom_dataloader_from_bed(
    chrom_bed, genome_fasta, input_window_length=1024, bigwigs=None, batch_size=8
):
    """Build a chromatin-only dataloader from a BED file of regions.

    Reads the first three columns of ``chrom_bed`` (chrom, start, end) and
    delegates to :func:`chrom_dataloader_from_dataframe`.
    """
    regions = pd.read_table(
        chrom_bed,
        header=None,
        usecols=[0, 1, 2],
        names=["chrom", "start", "end"],
    )
    return chrom_dataloader_from_dataframe(
        regions, genome_fasta, input_window_length, bigwigs or [], batch_size
    )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def chrom_dataloader_from_dataframe(
    chrom_df,
    genome_fasta,
    input_window_length=1024,
    bigwigs=None,
    batch_size=8,
    num_workers=0,
):
    """Build a chromatin-only dataloader from a coordinate DataFrame.

    Windows are recentered to ``input_window_length`` bp, given a dummy
    label, and filtered to chromosomes present in the genome FASTA. The
    returned loader yields only the bigwig signal tensor (via
    ``collate_chrom``).
    """
    if bigwigs is None:
        bigwigs = []
    windows = center_windows(chrom_df, window_len=input_window_length)
    # seqchromloader expects a label column; -1 marks "unlabeled"
    windows["label"] = -1
    if "strand" not in windows.columns:
        windows["strand"] = "+"
    windows = scl.filter_chromosomes(windows, to_keep=Fasta(genome_fasta).keys())
    loader_opts = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "collate_fn": collate_chrom,
        "drop_last": True,
    }
    return scl.SeqChromDatasetByDataFrame(
        windows,
        genome_fasta=genome_fasta,
        bigwig_filelist=bigwigs,
        dataloader_kws=loader_opts,
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class IterateSeqDataFrame(torch.utils.data.IterableDataset):
    """Iterable dataset that yields raw genome sequences for a coordinate DataFrame.

    Each worker opens its own pyfaidx handle in __iter__ and streams
    sequences via iterate_seq_df_chunk, optionally spiking in motifs or
    inserted regions. With infinite=True the DataFrame is reshuffled and
    replayed forever; otherwise the DataFrame is split evenly across
    DataLoader workers and iterated once.
    """

    def __init__(
        self,
        seq_df,
        genome_fasta,
        motif=None,
        motif_mode="pwm",
        num_motifs=128,
        start_buffer=0,
        end_buffer=0,
        regions_insert=None,
        return_region=False,
        print_warning=True,
        infinite=False,
    ):
        # coordinate DataFrame with at least chrom/start/end columns
        self.seq_df = seq_df
        # path to the genome FASTA; opened lazily per worker in __iter__
        self.genome_fasta = genome_fasta
        # optional motif to insert into every yielded sequence
        self.motif = motif
        # "pwm" (sample from PWM) or "consensus" (insert consensus string)
        self.motif_mode = motif_mode
        self.num_motifs = num_motifs
        # exclusion zones (bp) at each end where motifs are never inserted
        self.start_buffer = start_buffer
        self.end_buffer = end_buffer
        # optional region set to splice into sequences instead of a motif
        self.regions_insert = regions_insert
        # if True, yield ("chrom:start-end", seq) tuples instead of bare seqs
        self.return_region = return_region
        self.print_warning = print_warning
        # if True, reshuffle and replay the DataFrame forever
        self.infinite = infinite

    def __iter__(self):
        worker_info = get_worker_info()
        # one FASTA handle per worker process; pyfaidx handles are not
        # safely shareable across forked workers
        self.genome = pyfaidx.Fasta(self.genome_fasta)
        # seed per worker id so workers draw different random insertions
        rng = np.random.default_rng(worker_info.id if worker_info is not None else 1)

        if self.infinite:
            while True:
                # fresh shuffle each epoch over the full DataFrame
                chunk = self.seq_df.sample(frac=1.0).reset_index(drop=True)
                try:
                    yield from iterate_seq_df_chunk(
                        chunk,
                        genome=self.genome,
                        motif=self.motif,
                        motif_mode=self.motif_mode,
                        num_motifs=self.num_motifs,
                        start_buffer=self.start_buffer,
                        end_buffer=self.end_buffer,
                        regions_insert=self.regions_insert,
                        batch_size=None,
                        return_region=self.return_region,
                        print_warning=self.print_warning,
                        rng=rng,
                    )
                # NOTE(review): `yield from` over an exhausted generator simply
                # ends rather than raising StopIteration here (and PEP 479
                # converts internal StopIteration to RuntimeError), so this
                # handler likely never fires — confirm it is intentional.
                except StopIteration:
                    continue
        else:
            if worker_info is None:
                # single-process loading: iterate the whole DataFrame
                chunk = self.seq_df
            else:
                # shard rows evenly across DataLoader workers by worker id
                chunk = np.array_split(self.seq_df, worker_info.num_workers)[
                    worker_info.id
                ]
            yield from iterate_seq_df_chunk(
                chunk,
                genome=self.genome,
                motif=self.motif,
                motif_mode=self.motif_mode,
                num_motifs=self.num_motifs,
                start_buffer=self.start_buffer,
                end_buffer=self.end_buffer,
                regions_insert=self.regions_insert,
                batch_size=None,
                return_region=self.return_region,
                print_warning=self.print_warning,
                rng=rng,
            )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def iterate_seq_df_chunk(
    chunk,
    genome,
    motif=None,
    motif_mode="pwm",
    num_motifs=128,
    start_buffer=0,
    end_buffer=0,
    regions_insert=None,
    return_region=False,
    batch_size=None,
    print_warning=True,
    rng=None,
):
    """Yield uppercase sequences for each region of a coordinate DataFrame.

    Parameters
    ----------
    chunk : pandas.DataFrame
        Rows with ``chrom``/``start``/``end`` attributes (via itertuples).
    genome : mapping
        Indexable genome (e.g. ``pyfaidx.Fasta``) supporting
        ``genome[chrom][start:end]``.
    motif : optional
        If given, insert this motif into every sequence via
        ``insert_motif_into_seq``.
    motif_mode : {"pwm", "consensus"}
        How motif copies are generated.
    regions_insert : optional
        If given (and ``motif`` is None), splice random regions into each
        sequence via ``insert_region_into_seq``.
    return_region : bool
        Also yield the "chrom:start-end" string for each sequence.
    batch_size : int or None
        If None, yield one sequence at a time; otherwise yield lists of up
        to ``batch_size`` sequences (final partial batch included).
    print_warning : bool
        Print a message when a region is skipped.
    rng : numpy.random.Generator or None
        Randomness source for insertions; None creates a fresh
        ``default_rng(1)`` per call.

    Yields
    ------
    str, (str, str), list[str], or (list[str], list[str]) depending on
    ``batch_size`` and ``return_region``.
    """
    if rng is None:
        # Avoid a mutable default argument: the previous default
        # (np.random.default_rng(1) evaluated once at import) was shared
        # across every call, so calls were not independent. A fresh
        # seeded generator keeps the default deterministic per call.
        rng = np.random.default_rng(1)

    seqs = []
    regions = []
    for item in chunk.itertuples():
        try:
            seq = str(genome[item.chrom][item.start : item.end]).upper()
        except KeyError:
            # chromosome missing from the genome index
            if print_warning:
                print(f"catch KeyError in region {item.chrom}:{item.start}-{item.end}")
            continue
        except pyfaidx.FetchError:
            if print_warning:
                print(
                    f"catch FetchError in region {item.chrom}:{item.start}-{item.end}, probably start coordinate negative"
                )
            continue

        if "N" in seq:
            if print_warning:
                print(f"Skip {item.chrom}:{item.start}-{item.end} due to containing N")
            continue
        elif len(seq) == 0:
            # emptiness check directly on the string; the previous
            # np.unique(list(seq)) scan was only used for this test
            if print_warning:
                print(
                    f"Skip region {item.chrom}:{item.start}-{item.end} due to no sequences available"
                )
            continue

        if motif is not None:
            seq = insert_motif_into_seq(
                seq,
                motif,
                num_motifs=num_motifs,
                start_buffer=start_buffer,
                end_buffer=end_buffer,
                mode=motif_mode,
                rng=rng,
            )
        elif regions_insert is not None:
            seq = insert_region_into_seq(seq, regions_insert, genome, rng=rng)

        if batch_size is None:
            if return_region:
                yield f"{item.chrom}:{item.start}-{item.end}", seq
            else:
                yield seq
        else:
            seqs.append(seq)
            regions.append(f"{item.chrom}:{item.start}-{item.end}")
            if len(seqs) >= batch_size:
                if return_region:
                    yield regions, seqs
                else:
                    yield seqs
                seqs = []
                regions = []

    # flush the final, possibly short, batch
    if (len(seqs) > 0) and batch_size is not None:
        if return_region:
            yield regions, seqs
        else:
            yield seqs
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def iterate_seq_df(
    seq_df,
    genome_fasta,
    motif=None,
    num_motifs=128,
    motif_mode="pwm",
    start_buffer=0,
    end_buffer=0,
    regions_insert=None,
    batch_size=32,
    return_region=False,
    print_warning=True,
    rng=None,
):
    """Open ``genome_fasta`` and yield sequences for every region in ``seq_df``.

    Thin wrapper around :func:`iterate_seq_df_chunk`; see that function for
    the meaning of all parameters and the yielded values.

    Parameters
    ----------
    seq_df : pandas.DataFrame
        Regions with ``chrom``/``start``/``end`` columns.
    genome_fasta : str
        Path to a FASTA file readable by ``pyfaidx.Fasta``.
    rng : numpy.random.Generator or None
        None creates a fresh ``default_rng(1)`` per call; previously a
        single generator instance was shared across all calls via a
        mutable default argument.
    """
    if rng is None:
        rng = np.random.default_rng(1)
    genome = Fasta(genome_fasta)
    yield from iterate_seq_df_chunk(
        chunk=seq_df,
        genome=genome,
        motif=motif,
        num_motifs=num_motifs,
        motif_mode=motif_mode,
        start_buffer=start_buffer,
        end_buffer=end_buffer,
        regions_insert=regions_insert,
        return_region=return_region,
        batch_size=batch_size,
        print_warning=print_warning,
        rng=rng,
    )
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def iterate_seq_bed(
    seq_bed,
    genome_fasta,
    motif=None,
    num_motifs=128,
    motif_mode="pwm",
    start_buffer=0,
    end_buffer=0,
    regions_insert=None,
    batch_size=32,
    print_warning=True,
    rng=None,
):
    """Yield sequences for every region in a BED file.

    Reads the first three columns of ``seq_bed`` and delegates to
    :func:`iterate_seq_df`; see that function for parameter semantics.

    Parameters
    ----------
    rng : numpy.random.Generator or None
        None creates a fresh ``default_rng(1)`` per call; previously a
        single generator instance was shared across all calls via a
        mutable default argument.
    """
    if rng is None:
        rng = np.random.default_rng(1)
    # header=None made explicit and usecols spelled out to match the other
    # BED readers in this module (behavior is unchanged: pandas already
    # implies header=None when names are supplied)
    seq_df = pd.read_table(
        seq_bed,
        header=None,
        usecols=[0, 1, 2],
        names=["chrom", "start", "end"],
    )
    yield from iterate_seq_df(
        seq_df,
        genome_fasta,
        motif=motif,
        motif_mode=motif_mode,
        num_motifs=num_motifs,
        start_buffer=start_buffer,
        end_buffer=end_buffer,
        regions_insert=regions_insert,
        batch_size=batch_size,
        print_warning=print_warning,
        rng=rng,
    )
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
class SeqChromConcept:
    "Sequence + Chromatin concept given the bed files of sequences and chromatin regions."

    def __init__(
        self,
        seq_bed,
        seq_fa,
        chrom_bed,
        genome_fasta,
        bws: list,
        transforms=None,
        window_len=1024,
        batch_size=8,
    ):
        # sequence regions as a BED path (takes precedence in seq_dataloader)
        self.seq_bed = seq_bed
        # alternative: pre-extracted sequences as a FASTA path (used when
        # seq_bed is None)
        self.seq_fa = seq_fa
        # chromatin regions as a BED path
        self.chrom_bed = chrom_bed
        # cached generators so dataloader() reuses partially consumed streams;
        # NOTE(review): this means dataloader() can only be fully consumed
        # once per instance — confirm that is intended.
        self.seq_dl = None
        self.chrom_dl = None

        self.genome_fasta = genome_fasta
        # list of bigwig paths providing the chromatin signal tracks
        self.bws = bws
        # optional seqchromloader transforms applied to the chromatin data
        self.transforms = transforms
        self.window_len = window_len
        self.batch_size = batch_size

    def seq_dataloader(self):
        """Yield batches of one-hot sequence tensors.

        Uses seq_bed (recentred windows through seqchromloader) when set,
        otherwise falls back to reading fixed-length records from seq_fa.
        """
        if self.seq_bed is not None:
            seq_df = pd.read_table(
                self.seq_bed,
                header=None,
                usecols=[0, 1, 2],
                names=["chrom", "start", "end"],
            )
            seq_df = center_windows(seq_df, window_len=self.window_len)
            # seqchromloader expects a label column; -1 marks "unlabeled"
            seq_df["label"] = -1
            if "strand" not in seq_df.columns:
                seq_df["strand"] = "+"
            seq_df = scl.filter_chromosomes(
                seq_df, to_keep=Fasta(self.genome_fasta).keys()
            )
            dl = scl.SeqChromDatasetByDataFrame(
                seq_df,
                genome_fasta=self.genome_fasta,
                dataloader_kws={
                    "batch_size": self.batch_size,
                    "num_workers": 0,
                    "collate_fn": collate_seq,
                    "drop_last": True,
                },
            )
            for seq in dl:
                # some loader configurations wrap the tensor in a
                # single-element list; unwrap before yielding
                if isinstance(seq, list):
                    assert len(seq) == 1
                    seq = seq[0]
                yield seq
        else:
            with open(self.seq_fa) as handle:
                dnaSeqs = []
                for record in SeqIO.parse(handle, "fasta"):
                    dnaSeqs.append(torch.tensor(scl.dna2OneHot(record.seq)))
                    if len(dnaSeqs) >= self.batch_size:
                        yield torch.stack(dnaSeqs)
                        dnaSeqs = []
                # flush the final, possibly short, batch
                if len(dnaSeqs) > 0:
                    yield torch.stack(dnaSeqs)

    def chrom_dataloader(self):
        """Yield batches of chromatin signal tensors read from the bigwigs."""
        chrom_df = pd.read_table(
            self.chrom_bed,
            header=None,
            usecols=[0, 1, 2],
            names=["chrom", "start", "end"],
        )
        chrom_df = center_windows(chrom_df, window_len=self.window_len)
        # seqchromloader expects a label column; -1 marks "unlabeled"
        chrom_df["label"] = -1
        if "strand" not in chrom_df.columns:
            chrom_df["strand"] = "+"
        chrom_df = scl.filter_chromosomes(
            chrom_df, to_keep=Fasta(self.genome_fasta).keys()
        )
        dl = scl.SeqChromDatasetByDataFrame(
            chrom_df,
            genome_fasta=self.genome_fasta,
            bigwig_filelist=self.bws,
            transforms=self.transforms,
            dataloader_kws={
                "batch_size": self.batch_size,
                "num_workers": 0,
                "collate_fn": collate_chrom,
                "drop_last": True,
            },
        )
        yield from dl

    def dataloader(self):
        """Yield (seq_batch, chrom_batch) pairs from the two cached streams.

        Generators are created lazily on first call and cached; zip stops
        at whichever stream is exhausted first.
        """
        if (
            self.seq_bed is not None or self.seq_fa is not None
        ) and self.seq_dl is None:
            self.seq_dl = self.seq_dataloader()
        if (self.chrom_bed is not None) and self.chrom_dl is None:
            self.chrom_dl = self.chrom_dataloader()
        yield from zip(self.seq_dl, self.chrom_dl)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def sample_from_pwm(motif, n_seqs=1, rng=None):
    """Sample sequences position-by-position from a motif's PWM.

    Parameters
    ----------
    motif : Bio.motifs.Motif
        Object exposing ``.alphabet``, ``.length`` and ``.pwm`` (a mapping
        from letter to per-position probabilities).
    n_seqs : int, default 1
        Number of independent sequences to draw.
    rng : numpy.random.Generator or None
        Randomness source; None uses a fresh ``np.random.default_rng()``.

    Returns
    -------
    str | list[str]
        A single string when ``n_seqs == 1``, otherwise a list of strings.
    """
    if not rng:
        rng = np.random.default_rng()

    letters = list(motif.alphabet)  # e.g. ['A', 'C', 'G', 'T']
    width = motif.length

    # (width, n_letters): row = motif position, column = letter probability
    probs = np.column_stack([motif.pwm[letter] for letter in letters])

    # Inverse-CDF sampling, vectorized: compare uniform draws against the
    # cumulative letter probabilities and take the first crossing index.
    cum_probs = probs.cumsum(axis=1)  # (width, n_letters)
    draws = rng.random((n_seqs, width))
    choice_idx = (draws[..., None] < cum_probs).argmax(axis=2)  # (n_seqs, width)

    # Map indices back to letters and join each row into a string.
    letter_arr = np.asarray(letters, dtype="U1")
    sampled = ["".join(row) for row in letter_arr[choice_idx]]
    return sampled[0] if n_seqs == 1 else sampled
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def insert_motif_into_seq(
    seq,
    motif,
    num_motifs=3,
    start_buffer=50,
    end_buffer=50,
    rng=None,
    mode="consensus",
):
    """Insert up to ``num_motifs`` non-overlapping copies of ``motif`` into ``seq``.

    Parameters
    ----------
    seq : str
        Background sequence to modify.
    motif : Motif-like or PairedMotif
        Must support ``len()`` and either ``.consensus`` (mode="consensus")
        or PWM sampling via :func:`sample_from_pwm` (mode="pwm").
    num_motifs : int
        Target number of insertions; fewer are made when space runs out
        (a warning is logged in that case).
    start_buffer, end_buffer : int
        Exclusion zones (bp) at the ends where no motif may start.
    rng : numpy.random.Generator or None
        Randomness source; None creates a fresh ``default_rng(1)`` per
        call (previously a single generator was shared across all calls
        via a mutable default argument).
    mode : {"consensus", "pwm"}
        Insert the consensus string, or sample each copy from the PWM.

    Returns
    -------
    str
        The modified sequence, same length as ``seq``.
    """
    assert mode in ["consensus", "pwm"]
    if rng is None:
        rng = np.random.default_rng(1)

    # str is immutable, so list(seq) is already an independent copy;
    # the previous deepcopy was redundant
    seq_ins = list(seq)

    # positions where a motif may still start; 1 = available
    pos_motif_overlap = np.ones(len(seq))
    pos_motif_overlap[:start_buffer] = 0
    # block the tail so an insertion cannot run past the end buffer
    pos_motif_overlap[-(end_buffer + len(motif)) :] = 0

    num_insert_motifs = 0
    for _ in range(num_motifs):
        try:
            motif_start = rng.choice(
                np.where(pos_motif_overlap > 0)[0]
            ).item()  # randomly pick insert location
        except ValueError:
            # no free positions remain; stop inserting
            break

        if isinstance(motif, PairedMotif):
            # write the two half-sites with the configured spacer between them
            seq_ins[motif_start : (motif_start + len(motif.motif1))] = (
                list(motif.motif1.consensus)
                if mode == "consensus"
                else sample_from_pwm(motif.motif1)
            )
            seq_ins[
                (motif_start + len(motif.motif1) + motif.spacing) : (
                    motif_start + len(motif)
                )
            ] = (
                list(motif.motif2.consensus)
                if mode == "consensus"
                else sample_from_pwm(motif.motif2)
            )
        else:
            seq_ins[motif_start : (motif_start + len(motif))] = (
                list(motif.consensus) if mode == "consensus" else sample_from_pwm(motif)
            )

        # mark a full motif-length margin on both sides as occupied; clamp
        # the lower bound at 0 — previously motif_start - len(motif) could
        # go negative and the slice silently wrapped to the sequence end,
        # leaving this insertion unmarked and allowing overlaps
        pos_motif_overlap[
            max(0, motif_start - len(motif)) : (motif_start + len(motif))
        ] = 0
        num_insert_motifs += 1

    if num_insert_motifs < num_motifs:
        logger.warning(
            f"Only inserted {num_insert_motifs} out of {num_motifs} motifs for motif {motif.name}"
        )
    return "".join(seq_ins)
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
class CustomMotif:
    """A minimal motif defined directly by a consensus string.

    Mirrors the attributes of Bio.motifs entries that this module relies
    on (``name``, ``matrix_id``, ``consensus``, ``rc``, ``len()``).
    """

    def __init__(self, name, consensus):
        self.name = name
        self.matrix_id = "custom"
        # stored uppercase so comparisons against genome sequence match
        self.consensus = consensus.upper()
        self.rc = False

    def __len__(self):
        return len(self.consensus)

    def reverse_complement(self):
        """Reverse-complement the consensus in place and return self.

        Appends "_rc" to the name on every call.
        NOTE(review): the ``rc`` flag is not toggled here — confirm that
        is intentional.
        """
        self.consensus = Bio.Seq.reverse_complement(self.consensus)
        self.name = self.name + "_rc"
        return self
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
class PairedMotif:
    """Two motifs separated by a fixed spacer, treated as one composite motif.

    ``len()`` spans both half-sites plus the spacer; ``name`` and
    ``matrix_id`` are derived from the two parts at construction time.
    """

    def __init__(self, motif1, motif2, spacing=0):
        self.motif1 = motif1
        self.motif2 = motif2
        self.rc = False
        # gap (bp) between the end of motif1 and the start of motif2
        self.spacing = spacing
        # composite identifiers, fixed at construction
        self.pname = f"{self.motif1.name}_and_{self.motif2.name}"
        self.pmatrix_id = f"{self.motif1.matrix_id}_and_{self.motif2.matrix_id}"

    def reverse_complement(self):
        """Reverse-complement both half-sites in place and return self."""
        self.rc = True
        self.motif1 = self.motif1.reverse_complement()
        self.motif2 = self.motif2.reverse_complement()
        return self

    @property
    def name(self):
        return self.pname

    @property
    def matrix_id(self):
        return self.pmatrix_id

    def __len__(self):
        return len(self.motif1) + len(self.motif2) + self.spacing
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: tpcav
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Testing with PCA projected Concept Activation Vectors
|
|
5
|
+
Author-email: Jianyu Yang <yztxwd@gmail.com>
|
|
6
|
+
License-Expression: MIT AND (Apache-2.0 OR BSD-2-Clause)
|
|
7
|
+
Project-URL: Homepage, https://github.com/seqcode/TPCAV
|
|
8
|
+
Keywords: interpretation,attribution,concept,genomics,deep learning
|
|
9
|
+
Requires-Python: >=3.8
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: torch
|
|
13
|
+
Requires-Dist: pandas
|
|
14
|
+
Requires-Dist: numpy
|
|
15
|
+
Requires-Dist: seqchromloader
|
|
16
|
+
Requires-Dist: deeplift
|
|
17
|
+
Requires-Dist: pyfaidx
|
|
18
|
+
Requires-Dist: pybedtools
|
|
19
|
+
Requires-Dist: captum
|
|
20
|
+
Requires-Dist: scikit-learn
|
|
21
|
+
Requires-Dist: biopython
|
|
22
|
+
Requires-Dist: seaborn
|
|
23
|
+
Requires-Dist: matplotlib
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
|
|
26
|
+
# TPCAV (Testing with PCA projected Concept Activation Vectors)
|
|
27
|
+
|
|
28
|
+
Analysis pipeline for TPCAV
|
|
29
|
+
|
|
30
|
+
## Dependencies
|
|
31
|
+
|
|
32
|
+
You can use your own environment for the model; in addition, you need to install the following packages:
|
|
33
|
+
|
|
34
|
+
- captum 0.7
|
|
35
|
+
- seqchromloader 0.8.5
|
|
36
|
+
- scikit-learn 1.5.2
|
|
37
|
+
|
|
38
|
+
## Workflow
|
|
39
|
+
|
|
40
|
+
1. Since not every saved pytorch model stores the computation graph, you need to manually add functions to let the script know how to get the activations of the intermediate layer and how to proceed from there.
|
|
41
|
+
|
|
42
|
+
There are 3 places you need to insert your own code.
|
|
43
|
+
|
|
44
|
+
- Model class definition in models.py
|
|
45
|
+
- Please first copy your class definition into `Model_Class` in the script, it already has several pre-defined class functions, you need to fill in the following two functions:
|
|
46
|
+
- `forward_until_select_layer`: this is the function that takes your model input and forward until the layer you want to compute TPCAV score on
|
|
47
|
+
- `resume_forward_from_select_layer`: this is the function that starts from the activations of your select layer and forward all the way until the end
|
|
48
|
+
- There are also functions necessary for TPCAV computation, don't change them:
|
|
49
|
+
- `forward_from_start`: this function calls `forward_until_select_layer` and `resume_forward_from_select_layer` to do a full forward pass
|
|
50
|
+
- `forward_from_projected_and_residual`: this function takes the PCA projected activations and unexplained residual to do the forward pass
|
|
51
|
+
- `project_avs_to_pca`: this function takes care of the PCA projection
|
|
52
|
+
|
|
53
|
+
> NOTE: you can modify your final output tensor to specifically explain certain transformation of your output, for example, you can take weighted sum of base pair resolution signal prediction to emphasize high signal region.
|
|
54
|
+
|
|
55
|
+
- Function `load_model` in utils.py
|
|
56
|
+
- Take care of the model initialization and load saved parameters in `load_model`, return the model instance.
|
|
57
|
+
> NOTE: you need to use your own model class definition in models.py, as we need the functions defined in step 1.
|
|
58
|
+
|
|
59
|
+
- Function `seq_transform_fn` in utils.py
|
|
60
|
+
- By default the dataloader provides one hot coded DNA array of shape (batch_size, 4, len), coded in the order [A, C, G, T], if your model takes a different kind of input, modify `seq_transform_fn` to transform the input
|
|
61
|
+
|
|
62
|
+
- Function `chrom_transform_fn` in utils.py
|
|
63
|
+
- By default the dataloader provides signal array from bigwig files of shape (batch_size, # bigwigs, len), if your model takes a different kind of chromatin input, modify `chrom_transform_fn` to transform the input, if your model is sequence only, leave it to return None.
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
2. Compute CAVs on your model, example command:
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
srun -n1 -c8 --gres=gpu:1 --mem=128G python scripts/run_tcav_sgd_pca.py \
|
|
70
|
+
cavs_test 1024 data/hg19.fa data/hg19.fa.fai \
|
|
71
|
+
--meme-motifs data/motif-clustering-v2.1beta_consensus_pwms.test.meme \
|
|
72
|
+
--bed-chrom-concepts data/ENCODE_DNase_peaks.bed
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
3. Then compute the layer attributions, example command:
|
|
76
|
+
|
|
77
|
+
```bash
|
|
78
|
+
srun -n1 -c8 --gres=gpu:1 --mem=128G \
|
|
79
|
+
python scripts/compute_layer_attrs_only.py cavs_test/tpcav_model.pt \
|
|
80
|
+
data/ChIPseq.H1-hESC.MAX.conservative.all.shuf1k.narrowPeak \
|
|
81
|
+
1024 data/hg19.fa data/hg19.fa.fai cavs_test/test
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
4. Run the Jupyter notebook to generate a summary of your results
|
|
85
|
+
|
|
86
|
+
```bash
|
|
87
|
+
papermill -f scripts/compute_tcav_v2_pwm.example.yaml scripts/compute_tcav_v2_pwm.py.ipynb cavs_test/tcav_report.py.ipynb
|
|
88
|
+
```
|
|
89
|
+
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
tpcav/__init__.py,sha256=GbO0qDy-VJjnBMZAl5TXh27znwnwEHLIadsPoWH-gY8,985
|
|
2
|
+
tpcav/cavs.py,sha256=DDe7vAUdewosU6wur5qDUp2OsR0Bg-k_8R4VXFjcheI,11587
|
|
3
|
+
tpcav/concepts.py,sha256=3HIybk5xrAru7OiOb3tBPKyWtfcfnA8DGa3DDCJXBxc,11775
|
|
4
|
+
tpcav/helper.py,sha256=qvEmvIwm-qMKa8_8z_uhWdlYotwzMFx-8EPUPSKoveg,5014
|
|
5
|
+
tpcav/logging_utils.py,sha256=wug7O_5IjxjhOpQr-aq90qKMEUp1EgcPkrv26d8li6Q,281
|
|
6
|
+
tpcav/tpcav_model.py,sha256=gnM2RkBsv6mSFS2SYonziVBjHqdXoRX4cuYFmi9ITr0,16514
|
|
7
|
+
tpcav/utils.py,sha256=sftnhLqeY5ExZIvXnICY0pP27jjowSRCqtPyDi0t5Yg,18509
|
|
8
|
+
tpcav-0.1.0.dist-info/licenses/LICENSE,sha256=uC-2s0ObLnQzWFKH5aokHXo6CzxlJgeI0P3bIUCZgfU,1064
|
|
9
|
+
tpcav-0.1.0.dist-info/METADATA,sha256=EW5LGdtqL6x6jge-oob_qgYDBuhDznxyKMkq-_YrMVA,4260
|
|
10
|
+
tpcav-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
+
tpcav-0.1.0.dist-info/top_level.txt,sha256=I9veSE_WsuFYrXlcfRevqtatDyWWZNsWA3dV0CeBXVg,6
|
|
12
|
+
tpcav-0.1.0.dist-info/RECORD,,
|