tpcav 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tpcav/utils.py ADDED
@@ -0,0 +1,601 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Copied utility helpers from scripts/utils.py so the package can run standalone.
4
+ """
5
+
6
+ import logging
7
+ from copy import deepcopy
8
+
9
+ import Bio
10
+ import numpy as np
11
+ import pandas as pd
12
+ import pyfaidx
13
+ import seqchromloader as scl
14
+ import torch
15
+ from Bio import SeqIO
16
+ from pyfaidx import Fasta
17
+ from torch.utils.data import default_collate, get_worker_info
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def center_windows(df, window_len=1024):
23
+ "Get center window_len bp region of the given coordinate dataframe."
24
+ halfR = int(window_len / 2)
25
+ df = df.assign(mid=lambda x: ((x["start"] + x["end"]) / 2).astype(int)).assign(
26
+ start=lambda x: x["mid"] - halfR, end=lambda x: x["mid"] + halfR
27
+ )
28
+ if "strand" in df.columns:
29
+ return df[["chrom", "start", "end", "strand"]]
30
+ else:
31
+ return df[["chrom", "start", "end"]]
32
+
33
+
34
+ def collate_seq(batch):
35
+ seq, chrom, target, label = default_collate(batch)
36
+ return seq
37
+
38
+
39
+ def collate_chrom(batch):
40
+ seq, chrom, target, label = default_collate(batch)
41
+ return chrom
42
+
43
+
44
+ def collate_seqchrom(batch):
45
+ seq, chrom, target, label = default_collate(batch)
46
+ return seq, chrom
47
+
48
+
49
+ def seq_dataloader_from_bed(
50
+ seq_bed, genome_fasta, window_len=1024, batch_size=8, num_workers=0
51
+ ):
52
+ seq_df = pd.read_table(
53
+ seq_bed,
54
+ header=None,
55
+ usecols=[0, 1, 2],
56
+ names=["chrom", "start", "end"],
57
+ )
58
+ return seq_dataloader_from_dataframe(
59
+ seq_df, genome_fasta, window_len, batch_size, num_workers
60
+ )
61
+
62
+
63
+ def seq_dataloader_from_dataframe(
64
+ seq_df, genome_fasta, window_len=1024, batch_size=8, num_workers=0
65
+ ):
66
+ seq_df = center_windows(seq_df, window_len=window_len)
67
+ seq_df["label"] = -1
68
+ if "strand" not in seq_df.columns:
69
+ seq_df["strand"] = "+"
70
+ seq_df = scl.filter_chromosomes(seq_df, to_keep=Fasta(genome_fasta).keys())
71
+ dl = scl.SeqChromDatasetByDataFrame(
72
+ seq_df,
73
+ genome_fasta=genome_fasta,
74
+ dataloader_kws={
75
+ "batch_size": batch_size,
76
+ "num_workers": num_workers,
77
+ "collate_fn": collate_seq,
78
+ "drop_last": True,
79
+ },
80
+ )
81
+ return dl
82
+
83
+
84
+ def seq_dataloader_from_fa(seq_fa, input_window_length=1024, batch_size=8):
85
+ with open(seq_fa) as handle:
86
+ dnaSeqs = []
87
+ for record in SeqIO.parse(handle, "fasta"):
88
+ if len(record.seq) != input_window_length:
89
+ raise Exception(
90
+ f"Sequence length {len(record.seq)} != input_window_length {input_window_length}"
91
+ )
92
+ dnaSeqs.append(torch.tensor(scl.dna2OneHot(record.seq)))
93
+ if len(dnaSeqs) >= batch_size:
94
+ yield torch.stack(dnaSeqs)
95
+ dnaSeqs = []
96
+
97
+ if len(dnaSeqs) > 0:
98
+ yield torch.stack(dnaSeqs)
99
+
100
+
101
+ def chrom_dataloader_from_bed(
102
+ chrom_bed, genome_fasta, input_window_length=1024, bigwigs=None, batch_size=8
103
+ ):
104
+ chrom_df = pd.read_table(
105
+ chrom_bed,
106
+ header=None,
107
+ usecols=[0, 1, 2],
108
+ names=["chrom", "start", "end"],
109
+ )
110
+ return chrom_dataloader_from_dataframe(
111
+ chrom_df, genome_fasta, input_window_length, bigwigs or [], batch_size
112
+ )
113
+
114
+
115
+ def chrom_dataloader_from_dataframe(
116
+ chrom_df,
117
+ genome_fasta,
118
+ input_window_length=1024,
119
+ bigwigs=None,
120
+ batch_size=8,
121
+ num_workers=0,
122
+ ):
123
+ bigwigs = bigwigs or []
124
+ chrom_df = center_windows(chrom_df, window_len=input_window_length)
125
+ chrom_df["label"] = -1
126
+ if "strand" not in chrom_df.columns:
127
+ chrom_df["strand"] = "+"
128
+ chrom_df = scl.filter_chromosomes(chrom_df, to_keep=Fasta(genome_fasta).keys())
129
+ dl = scl.SeqChromDatasetByDataFrame(
130
+ chrom_df,
131
+ genome_fasta=genome_fasta,
132
+ bigwig_filelist=bigwigs,
133
+ dataloader_kws={
134
+ "batch_size": batch_size,
135
+ "num_workers": num_workers,
136
+ "collate_fn": collate_chrom,
137
+ "drop_last": True,
138
+ },
139
+ )
140
+ return dl
141
+
142
+
143
+ class IterateSeqDataFrame(torch.utils.data.IterableDataset):
144
+ def __init__(
145
+ self,
146
+ seq_df,
147
+ genome_fasta,
148
+ motif=None,
149
+ motif_mode="pwm",
150
+ num_motifs=128,
151
+ start_buffer=0,
152
+ end_buffer=0,
153
+ regions_insert=None,
154
+ return_region=False,
155
+ print_warning=True,
156
+ infinite=False,
157
+ ):
158
+ self.seq_df = seq_df
159
+ self.genome_fasta = genome_fasta
160
+ self.motif = motif
161
+ self.motif_mode = motif_mode
162
+ self.num_motifs = num_motifs
163
+ self.start_buffer = start_buffer
164
+ self.end_buffer = end_buffer
165
+ self.regions_insert = regions_insert
166
+ self.return_region = return_region
167
+ self.print_warning = print_warning
168
+ self.infinite = infinite
169
+
170
+ def __iter__(self):
171
+ worker_info = get_worker_info()
172
+ self.genome = pyfaidx.Fasta(self.genome_fasta)
173
+ rng = np.random.default_rng(worker_info.id if worker_info is not None else 1)
174
+
175
+ if self.infinite:
176
+ while True:
177
+ chunk = self.seq_df.sample(frac=1.0).reset_index(drop=True)
178
+ try:
179
+ yield from iterate_seq_df_chunk(
180
+ chunk,
181
+ genome=self.genome,
182
+ motif=self.motif,
183
+ motif_mode=self.motif_mode,
184
+ num_motifs=self.num_motifs,
185
+ start_buffer=self.start_buffer,
186
+ end_buffer=self.end_buffer,
187
+ regions_insert=self.regions_insert,
188
+ batch_size=None,
189
+ return_region=self.return_region,
190
+ print_warning=self.print_warning,
191
+ rng=rng,
192
+ )
193
+ except StopIteration:
194
+ continue
195
+ else:
196
+ if worker_info is None:
197
+ chunk = self.seq_df
198
+ else:
199
+ chunk = np.array_split(self.seq_df, worker_info.num_workers)[
200
+ worker_info.id
201
+ ]
202
+ yield from iterate_seq_df_chunk(
203
+ chunk,
204
+ genome=self.genome,
205
+ motif=self.motif,
206
+ motif_mode=self.motif_mode,
207
+ num_motifs=self.num_motifs,
208
+ start_buffer=self.start_buffer,
209
+ end_buffer=self.end_buffer,
210
+ regions_insert=self.regions_insert,
211
+ batch_size=None,
212
+ return_region=self.return_region,
213
+ print_warning=self.print_warning,
214
+ rng=rng,
215
+ )
216
+
217
+
218
+ def iterate_seq_df_chunk(
219
+ chunk,
220
+ genome,
221
+ motif=None,
222
+ motif_mode="pwm",
223
+ num_motifs=128,
224
+ start_buffer=0,
225
+ end_buffer=0,
226
+ regions_insert=None,
227
+ return_region=False,
228
+ batch_size=None,
229
+ print_warning=True,
230
+ rng=np.random.default_rng(1),
231
+ ):
232
+ seqs = []
233
+ regions = []
234
+ for item in chunk.itertuples():
235
+ try:
236
+ seq = str(genome[item.chrom][item.start : item.end]).upper()
237
+ except KeyError:
238
+ if print_warning:
239
+ print(f"catch KeyError in region {item.chrom}:{item.start}-{item.end}")
240
+ continue
241
+ except pyfaidx.FetchError:
242
+ if print_warning:
243
+ print(
244
+ f"catch FetchError in region {item.chrom}:{item.start}-{item.end}, probably start coordinate negative"
245
+ )
246
+ continue
247
+ unique_chars = np.unique(list(seq))
248
+ if "N" in seq:
249
+ if print_warning:
250
+ print(f"Skip {item.chrom}:{item.start}-{item.end} due to containing N")
251
+ continue
252
+ elif len(unique_chars) == 0:
253
+ if print_warning:
254
+ print(
255
+ f"Skip region {item.chrom}:{item.start}-{item.end} due to no sequences available"
256
+ )
257
+ continue
258
+
259
+ if motif is not None:
260
+ seq = insert_motif_into_seq(
261
+ seq,
262
+ motif,
263
+ num_motifs=num_motifs,
264
+ start_buffer=start_buffer,
265
+ end_buffer=end_buffer,
266
+ mode=motif_mode,
267
+ rng=rng,
268
+ )
269
+ elif regions_insert is not None:
270
+ seq = insert_region_into_seq(seq, regions_insert, genome, rng=rng)
271
+
272
+ if batch_size is None:
273
+ if return_region:
274
+ yield f"{item.chrom}:{item.start}-{item.end}", seq
275
+ else:
276
+ yield seq
277
+ else:
278
+ seqs.append(seq)
279
+ regions.append(f"{item.chrom}:{item.start}-{item.end}")
280
+ if len(seqs) >= batch_size:
281
+ if return_region:
282
+ yield regions, seqs
283
+ else:
284
+ yield seqs
285
+ seqs = []
286
+ regions = []
287
+ if (len(seqs) > 0) and batch_size is not None:
288
+ if return_region:
289
+ yield regions, seqs
290
+ else:
291
+ yield seqs
292
+
293
+
294
+ def iterate_seq_df(
295
+ seq_df,
296
+ genome_fasta,
297
+ motif=None,
298
+ num_motifs=128,
299
+ motif_mode="pwm",
300
+ start_buffer=0,
301
+ end_buffer=0,
302
+ regions_insert=None,
303
+ batch_size=32,
304
+ return_region=False,
305
+ print_warning=True,
306
+ rng=np.random.default_rng(1),
307
+ ):
308
+ genome = Fasta(genome_fasta)
309
+ yield from iterate_seq_df_chunk(
310
+ chunk=seq_df,
311
+ genome=genome,
312
+ motif=motif,
313
+ num_motifs=num_motifs,
314
+ motif_mode=motif_mode,
315
+ start_buffer=start_buffer,
316
+ end_buffer=end_buffer,
317
+ regions_insert=regions_insert,
318
+ return_region=return_region,
319
+ batch_size=batch_size,
320
+ print_warning=print_warning,
321
+ rng=rng,
322
+ )
323
+
324
+
325
+ def iterate_seq_bed(
326
+ seq_bed,
327
+ genome_fasta,
328
+ motif=None,
329
+ num_motifs=128,
330
+ motif_mode="pwm",
331
+ start_buffer=0,
332
+ end_buffer=0,
333
+ regions_insert=None,
334
+ batch_size=32,
335
+ print_warning=True,
336
+ rng=np.random.default_rng(1),
337
+ ):
338
+ seq_df = pd.read_table(seq_bed, usecols=range(3), names=["chrom", "start", "end"])
339
+ yield from iterate_seq_df(
340
+ seq_df,
341
+ genome_fasta,
342
+ motif=motif,
343
+ motif_mode=motif_mode,
344
+ num_motifs=num_motifs,
345
+ start_buffer=start_buffer,
346
+ end_buffer=end_buffer,
347
+ regions_insert=regions_insert,
348
+ batch_size=batch_size,
349
+ print_warning=print_warning,
350
+ rng=rng,
351
+ )
352
+
353
+
354
+ class SeqChromConcept:
355
+ "Sequence + Chromatin concept given the bed files of sequences and chromatin regions."
356
+
357
+ def __init__(
358
+ self,
359
+ seq_bed,
360
+ seq_fa,
361
+ chrom_bed,
362
+ genome_fasta,
363
+ bws: list,
364
+ transforms=None,
365
+ window_len=1024,
366
+ batch_size=8,
367
+ ):
368
+ self.seq_bed = seq_bed
369
+ self.seq_fa = seq_fa
370
+ self.chrom_bed = chrom_bed
371
+ self.seq_dl = None
372
+ self.chrom_dl = None
373
+
374
+ self.genome_fasta = genome_fasta
375
+ self.bws = bws
376
+ self.transforms = transforms
377
+ self.window_len = window_len
378
+ self.batch_size = batch_size
379
+
380
+ def seq_dataloader(self):
381
+ if self.seq_bed is not None:
382
+ seq_df = pd.read_table(
383
+ self.seq_bed,
384
+ header=None,
385
+ usecols=[0, 1, 2],
386
+ names=["chrom", "start", "end"],
387
+ )
388
+ seq_df = center_windows(seq_df, window_len=self.window_len)
389
+ seq_df["label"] = -1
390
+ if "strand" not in seq_df.columns:
391
+ seq_df["strand"] = "+"
392
+ seq_df = scl.filter_chromosomes(
393
+ seq_df, to_keep=Fasta(self.genome_fasta).keys()
394
+ )
395
+ dl = scl.SeqChromDatasetByDataFrame(
396
+ seq_df,
397
+ genome_fasta=self.genome_fasta,
398
+ dataloader_kws={
399
+ "batch_size": self.batch_size,
400
+ "num_workers": 0,
401
+ "collate_fn": collate_seq,
402
+ "drop_last": True,
403
+ },
404
+ )
405
+ for seq in dl:
406
+ if isinstance(seq, list):
407
+ assert len(seq) == 1
408
+ seq = seq[0]
409
+ yield seq
410
+ else:
411
+ with open(self.seq_fa) as handle:
412
+ dnaSeqs = []
413
+ for record in SeqIO.parse(handle, "fasta"):
414
+ dnaSeqs.append(torch.tensor(scl.dna2OneHot(record.seq)))
415
+ if len(dnaSeqs) >= self.batch_size:
416
+ yield torch.stack(dnaSeqs)
417
+ dnaSeqs = []
418
+ if len(dnaSeqs) > 0:
419
+ yield torch.stack(dnaSeqs)
420
+
421
+ def chrom_dataloader(self):
422
+ chrom_df = pd.read_table(
423
+ self.chrom_bed,
424
+ header=None,
425
+ usecols=[0, 1, 2],
426
+ names=["chrom", "start", "end"],
427
+ )
428
+ chrom_df = center_windows(chrom_df, window_len=self.window_len)
429
+ chrom_df["label"] = -1
430
+ if "strand" not in chrom_df.columns:
431
+ chrom_df["strand"] = "+"
432
+ chrom_df = scl.filter_chromosomes(
433
+ chrom_df, to_keep=Fasta(self.genome_fasta).keys()
434
+ )
435
+ dl = scl.SeqChromDatasetByDataFrame(
436
+ chrom_df,
437
+ genome_fasta=self.genome_fasta,
438
+ bigwig_filelist=self.bws,
439
+ transforms=self.transforms,
440
+ dataloader_kws={
441
+ "batch_size": self.batch_size,
442
+ "num_workers": 0,
443
+ "collate_fn": collate_chrom,
444
+ "drop_last": True,
445
+ },
446
+ )
447
+ yield from dl
448
+
449
+ def dataloader(self):
450
+ if (
451
+ self.seq_bed is not None or self.seq_fa is not None
452
+ ) and self.seq_dl is None:
453
+ self.seq_dl = self.seq_dataloader()
454
+ if (self.chrom_bed is not None) and self.chrom_dl is None:
455
+ self.chrom_dl = self.chrom_dataloader()
456
+ yield from zip(self.seq_dl, self.chrom_dl)
457
+
458
+
459
+ def sample_from_pwm(motif, n_seqs=1, rng=None):
460
+ """
461
+ Draw `n_seqs` independent sequences from a Bio.motifs.Motif PWM.
462
+
463
+ Parameters
464
+ ----------
465
+ motif : Bio.motifs.Motif
466
+ Motif whose `.pwm` is used for sampling.
467
+ n_seqs : int, default 1
468
+ How many sequences to generate.
469
+ rng : numpy.random.Generator or None
470
+ Leave None for np.random.default_rng().
471
+
472
+ Returns
473
+ -------
474
+ str | list[str]
475
+ A single string if n_seqs==1, otherwise a list of strings.
476
+ """
477
+ rng = rng or np.random.default_rng()
478
+
479
+ # ---- 1. Build a (L, A) probability matrix --------------------------------
480
+ alphabet = list(motif.alphabet) # e.g. ['A', 'C', 'G', 'T']
481
+ L = motif.length
482
+ pwm_dict = motif.pwm # dict base → list(float)
483
+
484
+ # shape (L, A) with rows = positions, cols = alphabet order
485
+ prob_mat = np.column_stack([pwm_dict[b] for b in alphabet])
486
+
487
+ # ---- 2. Vectorised multinomial sampling -----------------------------------
488
+ # Draw U(0,1) numbers of shape (n_seqs, L)
489
+ u = rng.random((n_seqs, L))
490
+
491
+ # cumulative probabilities along alphabet axis
492
+ cum = np.cumsum(prob_mat, axis=1) # still (L, A)
493
+
494
+ # Broadcast cum to (n_seqs, L, A) and pick first index where cum > u
495
+ idx = (u[..., None] < cum).argmax(axis=2) # (n_seqs, L) int indices
496
+
497
+ # ---- 3. Convert indices back to letters -----------------------------------
498
+ letters = np.array(alphabet, dtype="U1")
499
+ seq_arr = letters[idx] # (n_seqs, L) array of chars
500
+
501
+ # Join per sequence
502
+ seqs = ["".join(row) for row in seq_arr]
503
+ return seqs[0] if n_seqs == 1 else seqs
504
+
505
+
506
+ def insert_motif_into_seq(
507
+ seq,
508
+ motif,
509
+ num_motifs=3,
510
+ start_buffer=50,
511
+ end_buffer=50,
512
+ rng=np.random.default_rng(1),
513
+ mode="consensus",
514
+ ):
515
+ assert mode in ["consensus", "pwm"]
516
+
517
+ seq_ins = list(deepcopy(seq))
518
+ pos_motif_overlap = np.ones(len(seq))
519
+ pos_motif_overlap[:start_buffer] = 0
520
+ pos_motif_overlap[(-end_buffer - len(motif)) :] = 0
521
+ num_insert_motifs = 0
522
+ for i in range(num_motifs):
523
+ try:
524
+ motif_start = rng.choice(
525
+ np.where(pos_motif_overlap > 0)[0]
526
+ ).item() # randomly pick insert location
527
+ except ValueError:
528
+ # print(
529
+ # f"No samples can be taken for motif {motif.name}, skip inserting the rest of motifs"
530
+ # )
531
+ break
532
+ if isinstance(motif, PairedMotif):
533
+ seq_ins[motif_start : (motif_start + len(motif.motif1))] = (
534
+ list(motif.motif1.consensus)
535
+ if mode == "consensus"
536
+ else sample_from_pwm(motif.motif1)
537
+ )
538
+ seq_ins[
539
+ (motif_start + len(motif.motif1) + motif.spacing) : (
540
+ motif_start + len(motif)
541
+ )
542
+ ] = (
543
+ list(motif.motif2.consensus)
544
+ if mode == "consensus"
545
+ else sample_from_pwm(motif.motif2)
546
+ )
547
+ else:
548
+ seq_ins[motif_start : (motif_start + len(motif))] = (
549
+ list(motif.consensus) if mode == "consensus" else sample_from_pwm(motif)
550
+ )
551
+
552
+ pos_motif_overlap[(motif_start - len(motif)) : (motif_start + len(motif))] = 0
553
+ num_insert_motifs += 1
554
+ if num_insert_motifs < num_motifs:
555
+ logger.warning(
556
+ f"Only inserted {num_insert_motifs} out of {num_motifs} motifs for motif {motif.name}"
557
+ )
558
+ return "".join(seq_ins)
559
+
560
+
561
+ class CustomMotif:
562
+ def __init__(self, name, consensus):
563
+ self.name = name
564
+ self.matrix_id = "custom"
565
+ self.consensus = consensus.upper()
566
+ self.rc = False
567
+
568
+ def __len__(self):
569
+ return len(self.consensus)
570
+
571
+ def reverse_complement(self):
572
+ self.consensus = Bio.Seq.reverse_complement(self.consensus)
573
+ self.name = self.name + "_rc"
574
+ return self
575
+
576
+
577
+ class PairedMotif:
578
+ def __init__(self, motif1, motif2, spacing=0):
579
+ self.motif1 = motif1
580
+ self.motif2 = motif2
581
+ self.rc = False
582
+ self.spacing = spacing
583
+ self.pname = f"{self.motif1.name}_and_{self.motif2.name}"
584
+ self.pmatrix_id = f"{self.motif1.matrix_id}_and_{self.motif2.matrix_id}"
585
+
586
+ def reverse_complement(self):
587
+ self.rc = True
588
+ self.motif1 = self.motif1.reverse_complement()
589
+ self.motif2 = self.motif2.reverse_complement()
590
+ return self
591
+
592
+ @property
593
+ def name(self):
594
+ return self.pname
595
+
596
+ @property
597
+ def matrix_id(self):
598
+ return self.pmatrix_id
599
+
600
+ def __len__(self):
601
+ return len(self.motif1) + len(self.motif2) + self.spacing
@@ -0,0 +1,89 @@
1
+ Metadata-Version: 2.4
2
+ Name: tpcav
3
+ Version: 0.1.0
4
+ Summary: Testing with PCA projected Concept Activation Vectors
5
+ Author-email: Jianyu Yang <yztxwd@gmail.com>
6
+ License-Expression: MIT AND (Apache-2.0 OR BSD-2-Clause)
7
+ Project-URL: Homepage, https://github.com/seqcode/TPCAV
8
+ Keywords: interpretation,attribution,concept,genomics,deep learning
9
+ Requires-Python: >=3.8
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Requires-Dist: torch
13
+ Requires-Dist: pandas
14
+ Requires-Dist: numpy
15
+ Requires-Dist: seqchromloader
16
+ Requires-Dist: deeplift
17
+ Requires-Dist: pyfaidx
18
+ Requires-Dist: pybedtools
19
+ Requires-Dist: captum
20
+ Requires-Dist: scikit-learn
21
+ Requires-Dist: biopython
22
+ Requires-Dist: seaborn
23
+ Requires-Dist: matplotlib
24
+ Dynamic: license-file
25
+
26
+ # TPCAV (Testing with PCA projected Concept Activation Vectors)
27
+
28
+ Analysis pipeline for TPCAV
29
+
30
+ ## Dependencies
31
+
32
+ You can use your own environment for the model; in addition, you need to install the following packages:
33
+
34
+ - captum 0.7
35
+ - seqchromloader 0.8.5
36
+ - scikit-learn 1.5.2
37
+
38
+ ## Workflow
39
+
40
+ 1. Since not every saved pytorch model stores the computation graph, you need to manually add functions to let the script know how to get the activations of the intermediate layer and how to proceed from there.
41
+
42
+ There are 3 places you need to insert your own code.
43
+
44
+ - Model class definition in models.py
45
+ - Please first copy your class definition into `Model_Class` in the script, it already has several pre-defined class functions, you need to fill in the following two functions:
46
+ - `forward_until_select_layer`: this is the function that takes your model input and forward until the layer you want to compute TPCAV score on
47
+ - `resume_forward_from_select_layer`: this is the function that starts from the activations of your select layer and forward all the way until the end
48
+ - There are also functions necessary for TPCAV computation, don't change them:
49
+ - `forward_from_start`: this function calls `forward_until_select_layer` and `resume_forward_from_select_layer` to do a full forward pass
50
+ - `forward_from_projected_and_residual`: this function takes the PCA projected activations and unexplained residual to do the forward pass
51
+ - `project_avs_to_pca`: this function takes care of the PCA projection
52
+
53
+ > NOTE: you can modify your final output tensor to specifically explain certain transformation of your output, for example, you can take weighted sum of base pair resolution signal prediction to emphasize high signal region.
54
+
55
+ - Function `load_model` in utils.py
56
+ - Take care of the model initialization and load saved parameters in `load_model`, return the model instance.
57
+ > NOTE: you need to use your own model class definition in models.py, as we need the functions defined in step 1.
58
+
59
+ - Function `seq_transform_fn` in utils.py
60
+ - By default the dataloader provides a one-hot coded DNA array of shape (batch_size, 4, len), coded in the order [A, C, G, T]. If your model takes a different kind of input, modify `seq_transform_fn` to transform the input.
61
+
62
+ - Function `chrom_transform_fn` in utils.py
63
+ - By default the dataloader provides a signal array from bigwig files of shape (batch_size, # bigwigs, len). If your model takes a different kind of chromatin input, modify `chrom_transform_fn` to transform the input; if your model is sequence-only, leave it returning None.
64
+
65
+
66
+ 2. Compute CAVs on your model, example command:
67
+
68
+ ```bash
69
+ srun -n1 -c8 --gres=gpu:1 --mem=128G python scripts/run_tcav_sgd_pca.py \
70
+ cavs_test 1024 data/hg19.fa data/hg19.fa.fai \
71
+ --meme-motifs data/motif-clustering-v2.1beta_consensus_pwms.test.meme \
72
+ --bed-chrom-concepts data/ENCODE_DNase_peaks.bed
73
+ ```
74
+
75
+ 3. Then compute the layer attributions, example command:
76
+
77
+ ```bash
78
+ srun -n1 -c8 --gres=gpu:1 --mem=128G \
79
+ python scripts/compute_layer_attrs_only.py cavs_test/tpcav_model.pt \
80
+ data/ChIPseq.H1-hESC.MAX.conservative.all.shuf1k.narrowPeak \
81
+ 1024 data/hg19.fa data/hg19.fa.fai cavs_test/test
82
+ ```
83
+
84
+ 4. Run the Jupyter notebook to generate a summary of your results
85
+
86
+ ```bash
87
+ papermill -f scripts/compute_tcav_v2_pwm.example.yaml scripts/compute_tcav_v2_pwm.py.ipynb cavs_test/tcav_report.py.ipynb
88
+ ```
89
+
@@ -0,0 +1,12 @@
1
+ tpcav/__init__.py,sha256=GbO0qDy-VJjnBMZAl5TXh27znwnwEHLIadsPoWH-gY8,985
2
+ tpcav/cavs.py,sha256=DDe7vAUdewosU6wur5qDUp2OsR0Bg-k_8R4VXFjcheI,11587
3
+ tpcav/concepts.py,sha256=3HIybk5xrAru7OiOb3tBPKyWtfcfnA8DGa3DDCJXBxc,11775
4
+ tpcav/helper.py,sha256=qvEmvIwm-qMKa8_8z_uhWdlYotwzMFx-8EPUPSKoveg,5014
5
+ tpcav/logging_utils.py,sha256=wug7O_5IjxjhOpQr-aq90qKMEUp1EgcPkrv26d8li6Q,281
6
+ tpcav/tpcav_model.py,sha256=gnM2RkBsv6mSFS2SYonziVBjHqdXoRX4cuYFmi9ITr0,16514
7
+ tpcav/utils.py,sha256=sftnhLqeY5ExZIvXnICY0pP27jjowSRCqtPyDi0t5Yg,18509
8
+ tpcav-0.1.0.dist-info/licenses/LICENSE,sha256=uC-2s0ObLnQzWFKH5aokHXo6CzxlJgeI0P3bIUCZgfU,1064
9
+ tpcav-0.1.0.dist-info/METADATA,sha256=EW5LGdtqL6x6jge-oob_qgYDBuhDznxyKMkq-_YrMVA,4260
10
+ tpcav-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
11
+ tpcav-0.1.0.dist-info/top_level.txt,sha256=I9veSE_WsuFYrXlcfRevqtatDyWWZNsWA3dV0CeBXVg,6
12
+ tpcav-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+