tpcav 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tpcav/__init__.py CHANGED
@@ -10,7 +10,7 @@ import logging
 # Set the logging level to INFO
 logging.basicConfig(level=logging.INFO)
 
-from .cavs import CavTrainer
+from .cavs import CavTrainer, run_tpcav
 from .concepts import ConceptBuilder
 from .helper import (
     bed_to_chrom_tracks_iter,
tpcav/cavs.py CHANGED
@@ -5,11 +5,15 @@ CAV training and attribution utilities built on TPCAV.
 
 import logging
 import multiprocessing
+from collections import defaultdict
+import os
 from pathlib import Path
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Dict
 
+from Bio import motifs
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 import seaborn as sns
 import torch
 from sklearn.linear_model import SGDClassifier
@@ -17,8 +21,11 @@ from sklearn.metrics import precision_recall_fscore_support
 from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.model_selection import GridSearchCV
 from torch.utils.data import DataLoader, TensorDataset, random_split
+from sklearn.linear_model import LinearRegression
 
-from tpcav.tpcav_model import TPCAV
+from . import helper, utils
+from .concepts import ConceptBuilder
+from .tpcav_model import TPCAV
 
 logger = logging.getLogger(__name__)
 
@@ -246,6 +253,16 @@ class CavTrainer:
 
         return scores
 
+    def tpcav_score_all_concepts(self, attributions: torch.Tensor) -> dict:
+        """
+        Compute TCAV scores for all trained concepts.
+        """
+        scores_dict = {}
+        for concept_name in self.cav_weights.keys():
+            scores = self.tpcav_score(concept_name, attributions)
+            scores_dict[concept_name] = scores
+        return scores_dict
+
     def tpcav_score_binary_log_ratio(
         self, concept_name: str, attributions: torch.Tensor, pseudocount: float = 1.0
     ) -> float:
@@ -259,6 +276,20 @@ class CavTrainer:
 
         return np.log((pos_count + pseudocount) / (neg_count + pseudocount))
 
+    def tpcav_score_all_concepts_log_ratio(
+        self, attributions: torch.Tensor, pseudocount: float = 1.0
+    ) -> dict:
+        """
+        Compute TCAV log ratio scores for all trained concepts.
+        """
+        log_ratio_dict = {}
+        for concept_name in self.cav_weights.keys():
+            log_ratio = self.tpcav_score_binary_log_ratio(
+                concept_name, attributions, pseudocount
+            )
+            log_ratio_dict[concept_name] = log_ratio
+        return log_ratio_dict
+
     def plot_cavs_similaritiy_heatmap(
         self,
         attributions: torch.Tensor,
@@ -274,7 +305,7 @@ class CavTrainer:
         cavs_names_pass = []
         for cname in cavs_names:
             if self.cavs_fscores[cname] >= fscore_thresh:
-                cavs_pass.append(self.cav_weights[cname])
+                cavs_pass.append(self.cav_weights[cname].cpu().numpy())
                 cavs_names_pass.append(cname)
             else:
                 logger.info(
@@ -332,3 +363,156 @@ class CavTrainer:
         ax_log.set_title("TCAV log ratio")
 
         plt.savefig(output_path, dpi=300, bbox_inches="tight")
+
+
+def load_motifs_from_meme(motif_meme_file):
+    return {utils.clean_motif_name(m.name): m for m in motifs.parse(open(motif_meme_file), fmt="MINIMAL")}
+
+
+def compute_motif_auc_fscore(num_motif_insertions: List[int], cav_trainers: List[CavTrainer], meme_motif_file: str | None = None):
+    cavs_fscores_df = pd.DataFrame({nm: cav_trainer.cavs_fscores for nm, cav_trainer in zip(num_motif_insertions, cav_trainers)})
+    cavs_fscores_df['concept'] = list(cav_trainers[0].cavs_fscores.keys())
+
+    def compute_auc_fscore(row):
+        y = [row[nm] for nm in num_motif_insertions]
+        return np.trapz(y, num_motif_insertions) / (
+            num_motif_insertions[-1] - num_motif_insertions[0]
+        )
+
+    cavs_fscores_df["AUC_fscores"] = cavs_fscores_df.apply(compute_auc_fscore, axis=1)
+
+    # if motif PWMs are provided, fit a linear regression to remove the dependency of f-scores on information content and motif length
+    if meme_motif_file is not None:
+        motifs_dict = load_motifs_from_meme(meme_motif_file)
+        cavs_fscores_df['information_content'] = cavs_fscores_df.apply(lambda x: motifs_dict[x['concept']].relative_entropy.sum(), axis=1)
+        cavs_fscores_df['motif_len'] = cavs_fscores_df.apply(lambda x: len(motifs_dict[x['concept']].consensus), axis=1)
+
+        model = LinearRegression()
+        model.fit(cavs_fscores_df[['information_content', 'motif_len']].to_numpy(), cavs_fscores_df['AUC_fscores'].to_numpy()[:, np.newaxis])
+
+        y_pred = model.predict(cavs_fscores_df[['information_content', 'motif_len']].to_numpy())
+        residuals = cavs_fscores_df['AUC_fscores'].to_numpy() - y_pred.flatten()
+        cavs_fscores_df['AUC_fscores_residual'] = residuals
+
+        cavs_fscores_df.sort_values("AUC_fscores_residual", ascending=False, inplace=True)
+    else:
+        cavs_fscores_df.sort_values("AUC_fscores", ascending=False, inplace=True)
+
+    return cavs_fscores_df
+
+
+def run_tpcav(
+    model,
+    layer_name: str,
+    meme_motif_file: str,
+    genome_fasta: str,
+    num_motif_insertions: List[int] = [4, 8, 16],
+    bed_seq_file: Optional[str] = None,
+    bed_chrom_file: Optional[str] = None,
+    output_dir: str = "tpcav/",
+    num_samples_for_pca=10,
+    num_samples_for_cav=1000,
+    bws=None,
+    input_transform_func=helper.fasta_chrom_to_one_hot_seq,
+):
+    """
+    One-stop function to compute CAVs for motif concepts and bed concepts, and the AUC of motif-concept f-scores after correction.
+    """
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    output_path = Path(output_dir)
+    # create concept builders to generate concepts
+    ## motif concepts
+    motif_concepts_pairs = {}
+    motif_concept_builders = []
+    num_motif_insertions.sort()
+    for nm in num_motif_insertions:
+        builder = ConceptBuilder(
+            genome_fasta=genome_fasta,
+            input_window_length=1024,
+            bws=bws,
+            num_motifs=nm,
+            include_reverse_complement=True,
+            min_samples=num_samples_for_cav,
+            batch_size=8,
+        )
+        # use random regions as control
+        builder.build_control()
+        # use meme motif PWMs to build motif concepts, one concept per motif
+        concepts_pairs = builder.add_meme_motif_concepts(str(meme_motif_file))
+
+        # apply transform to convert fasta sequences to one-hot encoded sequences
+        builder.apply_transform(input_transform_func)
+
+        motif_concepts_pairs[nm] = concepts_pairs
+        motif_concept_builders.append(builder)
+
+    ## bed concepts (optional)
+    if bed_seq_file is not None or bed_chrom_file is not None:
+        bed_builder = ConceptBuilder(
+            genome_fasta=genome_fasta,
+            input_window_length=1024,
+            bws=bws,
+            num_motifs=0,
+            include_reverse_complement=True,
+            min_samples=num_samples_for_cav,
+            batch_size=8,
+        )
+        # use random regions as control
+        bed_builder.build_control()
+        if bed_seq_file is not None:
+            # build concepts from fasta sequences in bed file
+            bed_builder.add_bed_sequence_concepts(bed_seq_file)
+        if bed_chrom_file is not None:
+            # build concepts from chromatin tracks in bed file
+            bed_builder.add_bed_chrom_concepts(bed_chrom_file)
+        # apply transform to convert fasta sequences to one-hot encoded sequences
+        bed_builder.apply_transform(input_transform_func)
+    else:
+        bed_builder = None
+
+    # create TPCAV model on top of the given model
+    tpcav_model = TPCAV(model, layer_name=layer_name)
+    # fit PCA on activations sampled from all concepts of the last builder (which has the highest number of motif insertions)
+    tpcav_model.fit_pca(
+        concepts=motif_concept_builders[-1].all_concepts() + bed_builder.concepts if bed_builder is not None else motif_concept_builders[-1].all_concepts(),
+        num_samples_per_concept=num_samples_for_pca,
+        num_pc="full",
+    )
+    #torch.save(tpcav_model, output_path / "tpcav_model.pt")
+
+    # create trainers for computing CAVs
+    motif_cav_trainers = {}
+    for nm in num_motif_insertions:
+        cav_trainer = CavTrainer(tpcav_model, penalty="l2")
+        for motif_concept, permuted_concept in motif_concepts_pairs[nm]:
+            # set control concept for CAV training
+            cav_trainer.set_control(
+                permuted_concept, num_samples=num_samples_for_cav
+            )
+            # train CAVs for all concepts
+            cav_trainer.train_concepts(
+                [motif_concept,],
+                num_samples_for_cav,
+                output_dir=str(output_path / f"cavs_{nm}_motifs/"),
+                num_processes=4,
+            )
+        motif_cav_trainers[nm] = cav_trainer
+    if bed_builder is not None:
+        bed_cav_trainer = CavTrainer(tpcav_model, penalty="l2")
+        bed_cav_trainer.set_control(
+            bed_builder.control_concepts[0], num_samples=num_samples_for_cav
+        )
+        bed_cav_trainer.train_concepts(
+            bed_builder.concepts,
+            num_samples_for_cav,
+            output_dir=str(output_path / "cavs_bed_concepts/"),
+            num_processes=4,
+        )
+    else:
+        bed_cav_trainer = None
+
+    if len(num_motif_insertions) > 1:
+        cavs_fscores_df = compute_motif_auc_fscore(num_motif_insertions, list(motif_cav_trainers.values()), meme_motif_file=meme_motif_file)
+    else:
+        cavs_fscores_df = None
+
+    return cavs_fscores_df, motif_cav_trainers, bed_cav_trainer
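The new `tpcav_score_all_concepts` and `tpcav_score_all_concepts_log_ratio` methods simply loop `tpcav_score` / `tpcav_score_binary_log_ratio` over every trained CAV. A minimal sketch of how the two might be combined, assuming `trainer` is a `CavTrainer` with trained CAVs and `attrs` is an attribution tensor from the TPCAV model; `rank_concepts` and the variable names are illustrative, not part of the package:

```python
import torch

# Hypothetical helper, not from tpcav: rank all trained concepts by TCAV log ratio.
def rank_concepts(trainer, attrs: torch.Tensor, pseudocount: float = 1.0):
    scores = trainer.tpcav_score_all_concepts(attrs)          # concept -> TCAV scores
    log_ratios = trainer.tpcav_score_all_concepts_log_ratio(
        attrs, pseudocount=pseudocount
    )                                                         # concept -> log ratio
    ranked = sorted(log_ratios, key=log_ratios.get, reverse=True)  # most positive first
    return ranked, scores, log_ratios
```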
tpcav/concepts.py CHANGED
@@ -1,9 +1,9 @@
 import logging
-from copy import deepcopy
 from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
 import numpy as np
 import pandas as pd
+import pyfaidx
 import seqchromloader as scl
 import webdataset as wds
 from Bio import motifs as Bio_motifs
@@ -21,15 +21,15 @@ class _PairedLoader:
     def __init__(self, seq_dl: Iterable, chrom_dl: Iterable) -> None:
         self.seq_dl = seq_dl
         self.chrom_dl = chrom_dl
-        self.apply_func = None
+        self.apply_func_list = []
 
     def apply(self, apply_func):
-        self.apply_func = apply_func
+        self.apply_func_list.append(apply_func)
 
     def __iter__(self):
         for inputs in zip(self.seq_dl, self.chrom_dl):
-            if self.apply_func:
-                inputs = self.apply_func(*inputs)
+            for apply_func in self.apply_func_list:
+                inputs = apply_func(*inputs)
             yield inputs
 
 
@@ -74,7 +74,6 @@ class ConceptBuilder:
     def __init__(
         self,
         genome_fasta: str,
-        genome_size_file: str,
         input_window_length: int = 1024,
         bws: Optional[List[str]] = None,
         batch_size: int = 8,
@@ -83,9 +82,13 @@ class ConceptBuilder:
         include_reverse_complement: bool = False,
         min_samples: int = 5000,
         rng_seed: int = 1001,
+        concept_name_suffix: str = "",
     ) -> None:
         self.genome_fasta = genome_fasta
-        self.genome_size_file = genome_size_file
+        pyfaidx.Fasta(
+            genome_fasta, build_index=True
+        )  # validate genome fasta file and build index if needed
+        self.genome_size_file = self.genome_fasta + ".fai"
         self.input_window_length = input_window_length
         self.bws = bws or []
         self.batch_size = batch_size
@@ -94,6 +97,7 @@ class ConceptBuilder:
         self.include_reverse_complement = include_reverse_complement
         self.min_samples = min_samples
         self.rng_seed = rng_seed
+        self.concept_name_suffix = concept_name_suffix
 
         self.control_regions: pd.DataFrame | None = None
         self.control_concepts: List[Concept] = []
@@ -116,7 +120,7 @@ class ConceptBuilder:
 
         concept = Concept(
             id=self._reserve_id(is_control=True),
-            name=name,
+            name=name + self.concept_name_suffix,
             data_iter=_PairedLoader(self._control_seq_dl(), self._control_chrom_dl()),
         )
         self.control_concepts = [concept]
@@ -142,144 +146,153 @@ class ConceptBuilder:
         return chrom_iter
 
     def add_custom_motif_concepts(
-        self, motif_table: str, control_regions: Optional[pd.DataFrame] = None
-    ) -> List[Concept]:
+        self, motif_table: str, control_regions: Optional[pd.DataFrame] = None, build_permute_control=True
+    ) -> List[Concept] | List[Tuple[Concept]]:
         """Add concepts from a tab-delimited motif table: motif_name<TAB>consensus."""
-        if control_regions is None:
-            if not self.control_concepts:
-                raise ValueError("Call build_control or pass control_regions first.")
-            control_regions = self.metadata.get("control_regions")
-        assert control_regions is not None
         df = pd.read_table(motif_table, names=["motif_name", "consensus_seq"])
-        added: List[Concept] = []
+        added = []
         for motif_name in np.unique(df.motif_name):
+            motif_name = utils.clean_motif_name(motif_name)
             consensus = df.loc[df.motif_name == motif_name, "consensus_seq"].tolist()
             motifs = []
             for idx, cons in enumerate(consensus):
                 motif = utils.CustomMotif(f"{motif_name}_{idx}", cons)
                 motifs.append(motif)
-                if self.include_reverse_complement:
-                    motifs.append(motif.reverse_complement())
-            seq_dl = _construct_motif_concept_dataloader_from_control(
-                control_regions,
-                self.genome_fasta,
-                motifs=motifs,
-                num_motifs=self.num_motifs,
-                motif_mode="consensus",
-                batch_size=self.batch_size,
-                num_workers=self.num_workers,
-            )
-            concept = Concept(
-                id=self._reserve_id(),
-                name=motif_name,
-                data_iter=_PairedLoader(seq_dl, self._control_chrom_dl()),
-            )
+            concept = self.build_motif_concept(motifs, motif_name, control_regions=control_regions, motif_mode="consensus")
             self.concepts.append(concept)
-            added.append(concept)
+            # build permute control if specified
+            if build_permute_control:
+                motifs_permuted = [m.permute() for m in motifs]
+                concept_permuted = self.build_motif_concept(motifs_permuted, motif_name + "_perm", control_regions=control_regions, motif_mode="consensus")
+                self.control_concepts.append(concept_permuted)
+                added.append((concept, concept_permuted))
+            else:
+                added.append(concept)
         return added
 
     def add_meme_motif_concepts(
-        self, meme_file: str, control_regions: Optional[pd.DataFrame] = None
-    ) -> List[Concept]:
+        self, meme_file: str, control_regions: Optional[pd.DataFrame] = None, build_permute_control=True) -> List[Concept] | List[Tuple[Concept]]:
         """Add concepts from a MEME minimal-format motif file."""
+
+        added = []
+        with open(meme_file) as handle:
+            for motif in Bio_motifs.parse(handle, fmt="MINIMAL"):
+                motif_name = utils.clean_motif_name(motif.name)
+                concept = self.build_motif_concept([motif,], motif_name, control_regions=control_regions, motif_mode="pwm")
+                self.concepts.append(concept)
+                # build permute control if specified
+                if build_permute_control:
+                    motif_permuted = utils.PermutedPWMMotif(motif)
+                    concept_permuted = self.build_motif_concept([motif_permuted,], motif_name + "_perm", control_regions=control_regions, motif_mode="pwm")
+                    self.control_concepts.append(concept_permuted)
+                    added.append((concept, concept_permuted))
+                else:
+                    added.append(concept)
+        return added
+
+    def build_motif_concept(self, motifs, concept_name, control_regions=None, motif_mode="pwm"):
         if control_regions is None:
             if not self.control_concepts:
                 raise ValueError("Call build_control or pass control_regions first.")
             control_regions = self.metadata.get("control_regions")
         assert control_regions is not None
 
+        if self.include_reverse_complement:
+            motifs.extend([m.reverse_complement() for m in motifs])
+        seq_dl = _construct_motif_concept_dataloader_from_control(
+            control_regions,
+            self.genome_fasta,
+            motifs=motifs,
+            num_motifs=self.num_motifs,
+            motif_mode=motif_mode,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+        concept = Concept(
+            id=self._reserve_id(),
+            name=concept_name + self.concept_name_suffix,
+            data_iter=_PairedLoader(seq_dl, self._control_chrom_dl()),
+        )
+        return concept
+
+    def add_bed_sequence_concepts(self, bed_path: str) -> List[Concept]:
+        """Add concepts backed by BED sequences with concept_name in column 5."""
         added: List[Concept] = []
-        with open(meme_file) as handle:
-            for motif in Bio_motifs.parse(handle, fmt="MINIMAL"):
-                motifs = [motif]
-                if self.include_reverse_complement:
-                    motifs.append(motif.reverse_complement())
-                motif_name = motif.name.replace("/", "-")
-                seq_dl = _construct_motif_concept_dataloader_from_control(
-                    control_regions,
-                    self.genome_fasta,
-                    motifs=motifs,
-                    num_motifs=self.num_motifs,
-                    motif_mode="pwm",
-                    batch_size=self.batch_size,
-                    num_workers=self.num_workers,
-                )
-                concept = Concept(
-                    id=self._reserve_id(),
-                    name=motif_name,
-                    data_iter=_PairedLoader(seq_dl, self._control_chrom_dl()),
-                )
-                self.concepts.append(concept)
-                added.append(concept)
+        bed_df = pd.read_table(
+            bed_path,
+            header=None,
+            usecols=[0, 1, 2, 3, 4],
+            names=["chrom", "start", "end", "strand", "concept_name"],
+        )
+        added.extend(self.add_dataframe_sequence_concepts(bed_df))
         return added
 
-    def add_bed_sequence_concepts(self, bed_paths: Iterable[str]) -> List[Concept]:
+    def add_dataframe_sequence_concepts(self, dataframe: pd.DataFrame) -> List[Concept]:
         """Add concepts backed by BED sequences with concept_name in column 5."""
+        dataframe = helper.center_dataframe_regions(dataframe, self.input_window_length)
         added: List[Concept] = []
-        for bed in bed_paths:
-            bed_df = pd.read_table(
-                bed,
-                header=None,
-                usecols=[0, 1, 2, 3, 4],
-                names=["chrom", "start", "end", "strand", "concept_name"],
-            )
-            for concept_name in bed_df.concept_name.unique():
-                concept_df = bed_df.loc[bed_df.concept_name == concept_name]
-                if len(concept_df) < self.min_samples:
-                    logger.warning(
-                        "Concept %s has %s samples, fewer than min_samples=%s; skipping",
-                        concept_name,
-                        len(concept_df),
-                        self.min_samples,
-                    )
-                    continue
-                seq_fasta_iter = helper.dataframe_to_fasta_iter(
-                    concept_df.sample(n=self.min_samples, random_state=self.rng_seed),
-                    self.genome_fasta,
-                    batch_size=self.batch_size,
-                )
-                concept = Concept(
-                    id=self._reserve_id(),
-                    name=concept_name,
-                    data_iter=_PairedLoader(seq_fasta_iter, self._control_chrom_dl()),
-                )
-                self.concepts.append(concept)
-                added.append(concept)
+        for concept_name in dataframe.concept_name.unique():
+            concept_df = dataframe.loc[dataframe.concept_name == concept_name]
+            if len(concept_df) < self.min_samples:
+                logger.warning(
+                    "Concept %s has %s samples, fewer than min_samples=%s; skipping",
+                    concept_name,
+                    len(concept_df),
+                    self.min_samples,
+                )
+                continue
+            seq_fasta_iter = helper.dataframe_to_fasta_iter(
+                concept_df.sample(n=self.min_samples, random_state=self.rng_seed),
+                self.genome_fasta,
+                batch_size=self.batch_size,
+            )
+            concept = Concept(
+                id=self._reserve_id(),
+                name=concept_name + self.concept_name_suffix,
+                data_iter=_PairedLoader(seq_fasta_iter, self._control_chrom_dl()),
+            )
+            self.concepts.append(concept)
+            added.append(concept)
         return added
 
-    def add_bed_chrom_concepts(self, bed_paths: Iterable[str]) -> List[Concept]:
+    def add_bed_chrom_concepts(self, bed_path: str) -> List[Concept]:
         """Add concepts backed by chromatin signal bigwigs and BED coordinates."""
         added: List[Concept] = []
-        for bed in bed_paths:
-            bed_df = pd.read_table(
-                bed,
-                header=None,
-                usecols=[0, 1, 2, 3, 4],
-                names=["chrom", "start", "end", "strand", "concept_name"],
-            )
-            for concept_name in bed_df.concept_name.unique():
-                concept_df = bed_df.loc[bed_df.concept_name == concept_name]
-                if len(concept_df) < self.min_samples:
-                    logger.warning(
-                        "Concept %s has %s samples, fewer than min_samples=%s; skipping",
-                        concept_name,
-                        len(concept_df),
-                        self.min_samples,
-                    )
-                    continue
-                chrom_dl = helper.dataframe_to_chrom_tracks_iter(
-                    concept_df.sample(n=self.min_samples, random_state=self.rng_seed),
-                    self.genome_fasta,
-                    self.bws,
-                    batch_size=self.batch_size,
-                )
-                concept = Concept(
-                    id=self._reserve_id(),
-                    name=concept_name,
-                    data_iter=_PairedLoader(self._control_seq_dl(), chrom_dl),
-                )
-                self.concepts.append(concept)
-                added.append(concept)
+        bed_df = pd.read_table(
+            bed_path,
+            header=None,
+            usecols=[0, 1, 2, 3, 4],
+            names=["chrom", "start", "end", "strand", "concept_name"],
+        )
+        added.extend(self.add_dataframe_chrom_concepts(bed_df))
+        return added
+
+    def add_dataframe_chrom_concepts(self, dataframe) -> List[Concept]:
+        """Add concepts backed by chromatin signal bigwigs and BED coordinates."""
+        dataframe = helper.center_dataframe_regions(dataframe, self.input_window_length)
+        added: List[Concept] = []
+        for concept_name in dataframe.concept_name.unique():
+            concept_df = dataframe.loc[dataframe.concept_name == concept_name]
+            if len(concept_df) < self.min_samples:
+                logger.warning(
+                    "Concept %s has %s samples, fewer than min_samples=%s; skipping",
+                    concept_name,
+                    len(concept_df),
+                    self.min_samples,
+                )
+                continue
+            chrom_dl = helper.dataframe_to_chrom_tracks_iter(
+                concept_df.sample(n=self.min_samples, random_state=self.rng_seed),
+                self.bws,
+                batch_size=self.batch_size,
+            )
+            concept = Concept(
+                id=self._reserve_id(),
+                name=concept_name + self.concept_name_suffix,
+                data_iter=_PairedLoader(self._control_seq_dl(), chrom_dl),
+            )
+            self.concepts.append(concept)
+            added.append(concept)
         return added
 
     def all_concepts(self) -> List[Concept]:
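With this change `ConceptBuilder` no longer takes `genome_size_file` (the `.fai` index is built from `genome_fasta` via pyfaidx), motif concepts come back as `(concept, permuted_control)` pairs by default, and concepts can be added straight from an in-memory DataFrame. A hedged sketch of the new call pattern; the file paths and DataFrame contents below are placeholders, not from the package:

```python
import pandas as pd
from tpcav import ConceptBuilder

# genome_size_file is gone; the .fai index is derived from genome_fasta
builder = ConceptBuilder(genome_fasta="data/genome.fa", num_motifs=8,
                         include_reverse_complement=True, min_samples=100)
builder.build_control()

# each entry is a (motif_concept, permuted_control_concept) pair by default
pairs = builder.add_meme_motif_concepts("data/motifs.meme")

# concepts can now also come from an in-memory DataFrame instead of a BED path
regions = pd.DataFrame({
    "chrom": ["chr1"] * 200,
    "start": range(10_000, 210_000, 1_000),
    "end": range(11_000, 211_000, 1_000),
    "strand": ["+"] * 200,
    "concept_name": ["my_concept"] * 200,
})
builder.add_dataframe_sequence_concepts(regions)  # recentered to input_window_length internally
```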
tpcav/helper.py CHANGED
@@ -5,6 +5,11 @@ Lightweight data loading helpers for sequences and chromatin tracks.
 
 from typing import Iterable, List, Optional
 
+import itertools
+import logging
+import pyBigWig
+import re
+import sys
 import numpy as np
 import pandas as pd
 import seqchromloader as scl
@@ -13,17 +18,31 @@ from deeplift.dinuc_shuffle import dinuc_shuffle
 from pyfaidx import Fasta
 from seqchromloader.utils import dna2OneHot, extract_bw
 
+logger = logging.getLogger(__name__)
+
 
 def load_bed_and_center(bed_file: str, window: int) -> pd.DataFrame:
     """
     Load a BED file and center the regions to a fixed window size.
     """
     bed_df = pd.read_table(bed_file, usecols=[0, 1, 2], names=["chrom", "start", "end"])
-    bed_df["center"] = ((bed_df["start"] + bed_df["end"]) // 2).astype(int)
-    bed_df["start"] = bed_df["center"] - (window // 2)
-    bed_df["end"] = bed_df["start"] + window
-    bed_df = bed_df[["chrom", "start", "end"]]
-    return bed_df
+    return center_dataframe_regions(bed_df, window)
+
+
+def center_dataframe_regions(df: pd.DataFrame, window: int) -> pd.DataFrame:
+    """
+    Center the regions in a DataFrame to a fixed window size, keeping the other columns. Put chrom, start, end as the first 3 columns.
+    """
+    df = df.copy()
+    df["center"] = ((df["start"] + df["end"]) // 2).astype(int)
+    df["start"] = df["center"] - (window // 2)
+    df["end"] = df["start"] + window
+    df = df.drop(columns=["center"])
+    cols = ["chrom", "start", "end"] + [
+        col for col in df.columns if col not in ["chrom", "start", "end"]
+    ]
+    df = df[cols]
+    return df
 
 
 def bed_to_fasta_iter(
@@ -46,6 +65,10 @@ def dataframe_to_fasta_iter(
     fasta_seqs = []
     for row in df.itertuples(index=False):
         seq = str(genome[row.chrom][row.start : row.end]).upper()
+        if len(seq) != (row.end - row.start):
+            raise ValueError(
+                f"Extracted fasta sequence length does not match the region coordinate length for {row.chrom}:{row.start}-{row.end}"
+            )
         fasta_seqs.append(seq)
         if len(fasta_seqs) == batch_size:
             yield fasta_seqs
@@ -163,3 +186,117 @@ def dinuc_shuffle_sequences(
         )
         results.append(shuffles)
     return results
+
+
+def fasta_chrom_to_one_hot_seq(seq, chrom):
+    return (fasta_to_one_hot_sequences(seq),)
+
+
+def write_attrs_to_bw(arrs, regions, genome_info, bigwig_fn, smooth=False):
+    """
+    Write the attributions into bigwig files.
+    The shape of arrs should be (# samples, length).
+    Note: if regions overlap with each other, only base pairs not covered by previous regions are assigned the current region's attribution score.
+    """
+    # write header into bigwig
+    bw = pyBigWig.open(bigwig_fn, "w")
+    heads = []
+    with open(genome_info, "r") as f:
+        for line in f:
+            chrom, length = line.strip().split("\t")[:2]
+            heads.append((chrom, int(length)))
+    heads = sorted(heads, key=lambda x: x[0])
+    bw.addHeader(heads)
+
+    # sort regions and arrs
+    assert len(regions) == len(arrs)
+
+    def get_key(x):
+        chrom, start, end = re.split("[:-]", regions[x])
+        start = int(start)
+        return chrom, start
+
+    idx_sort = sorted(range(len(regions)), key=get_key)
+    regions = [regions[i] for i in idx_sort]
+    arrs = arrs[idx_sort]
+    # construct iterables
+    it = zip(arrs, regions)
+    it = itertools.chain(
+        it, zip([np.array([-1000])], ["chrNone:10-100"])
+    )  # add pseudo region to make sure the last entry will be added to bw file
+    arr, lastRegion = next(it)
+    lastChrom, start, end = re.split(r"[:-]", lastRegion)
+
+    start = int(start)
+    end = int(end)
+    # extend coordinates if attribution arr is larger than interval length
+    if end - start < len(arr):
+        logger.warning(
+            "Interval length is smaller than attribution array length, expanding it!"
+        )
+        diff = len(arr) - (end - start)
+        if diff % 2 != 0:
+            raise Exception(
+                "The difference between the attribution array length and the interval length is not even; can't do a symmetric extension in this case, exiting..."
+            )
+        start -= int(diff / 2)
+        end += int(diff / 2)
+    elif end - start == len(arr):
+        diff = 0
+    else:
+        raise Exception(
+            "Interval length is larger than attribution array length; this is not an expected situation, exiting..."
+        )
+    arr_store_tmp = arr
+    for arr, region in it:
+        rchrom, rstart, rend = re.split(r"[:-]", region)
+        rstart = int(rstart)
+        rend = int(rend)
+        # extend coordinates if attribution arr is larger than interval length
+        rstart -= int(diff / 2)
+        rend += int(diff / 2)
+        if rstart < 0:
+            break
+        if end <= rstart or rchrom != lastChrom:
+            arr_store_tmp = (
+                np.convolve(arr_store_tmp, np.ones(10) / 10, mode="same")
+                if smooth
+                else arr_store_tmp
+            )
+            try:
+                bw.addEntries(
+                    lastChrom,
+                    np.arange(start, end, dtype=np.int64),
+                    values=arr_store_tmp.astype(np.float64),
+                    span=1,
+                )
+            except:
+                print(lastChrom)
+                print(start)
+                print(end)
+                print(arr_store_tmp.shape, arr_store_tmp.dtype)
+                print(rchrom)
+                print(rstart)
+                print(rend)
+                raise Exception(
+                    "Runtime error when adding entries to bigwig file, see above messages for region info"
+                )
+            lastChrom = rchrom
+            start = rstart
+            end = rend
+            arr_store_tmp = arr
+        # get uncovered interval (defined by start coordinate `start` and relative start coordinate `start_idx`)
+        else:
+            assert (
+                end > rstart and rchrom == lastChrom
+            )  # just a double check to make sure the two intervals overlap
+            start_idx = end - rstart
+            end = rend
+            try:
+                arr_store_tmp = np.concatenate([arr_store_tmp, arr[start_idx:]])
+            except TypeError:
+                print(start_idx)
+                print(rstart, rend, rchrom, start, end, lastChrom)
+                print(arr_store_tmp.shape, print(arr.shape))
+                sys.exit(1)
+    bw.close()
+
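`center_dataframe_regions` generalizes the old `load_bed_and_center` body: it recenters each interval on its midpoint, keeps any extra columns, and moves `chrom`/`start`/`end` to the front. A quick illustration on made-up coordinates:

```python
import pandas as pd
from tpcav import helper

df = pd.DataFrame({
    "concept_name": ["a", "b"],
    "chrom": ["chr1", "chr2"],
    "start": [100, 500],
    "end": [200, 700],
})
centered = helper.center_dataframe_regions(df, window=128)
# chr1: center 150 -> start 86, end 214; chr2: center 600 -> start 536, end 664
print(centered.columns.tolist())  # ['chrom', 'start', 'end', 'concept_name']
```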
tpcav/tpcav_model.py CHANGED
@@ -2,6 +2,7 @@ import logging
 from functools import partial
 from typing import Dict, Iterable, List, Optional, Tuple
 
+import numpy as np
 import torch
 from captum.attr import DeepLift
 from scipy.linalg import svd
@@ -11,10 +12,6 @@ logger = logging.getLogger(__name__)
 
 def _abs_attribution_func(multipliers, inputs, baselines):
     "Multiplier x abs(inputs - baselines) to avoid double-sign effects."
-    # print(f"inputs: {inputs[1][:5]}")
-    # print(f"baselines: {baselines[1][:5]}")
-    # print(f"multipliers: {multipliers[0][:5]}")
-    # print(f"multipliers: {multipliers[1][:5]}")
     return tuple(
         (input_ - baseline).abs() * multiplier
         for input_, baseline, multiplier in zip(inputs, baselines, multipliers)
@@ -52,7 +49,7 @@ class TPCAV(torch.nn.Module):
             "layer_name": self.layer_name,
             "zscore_mean": getattr(self, "zscore_mean", None),
             "zscore_std": getattr(self, "zscore_std", None),
-            "pca_inv": getattr(self, "pca_inv", None),
+            "Vh": getattr(self, "Vh", None),
             "orig_shape": getattr(self, "orig_shape", None),
         }
 
@@ -61,7 +58,7 @@ class TPCAV(torch.nn.Module):
         self.layer_name = tpcav_state_dict["layer_name"]
         self._set_buffer("zscore_mean", tpcav_state_dict["zscore_mean"])
         self._set_buffer("zscore_std", tpcav_state_dict["zscore_std"])
-        self._set_buffer("pca_inv", tpcav_state_dict["pca_inv"])
+        self._set_buffer("Vh", tpcav_state_dict["Vh"])
         self._set_buffer("orig_shape", tpcav_state_dict["orig_shape"])
         self.fitted = True
         logger.warning(
@@ -91,20 +88,22 @@ class TPCAV(torch.nn.Module):
         std[std == 0] = -1
         standardized = (flat - mean) / std
 
-        v_inverse = None
         if num_pc is None or num_pc == "full":
-            _, _, v = svd(standardized, lapack_driver="gesvd", full_matrices=False)
-            v_inverse = torch.tensor(v)
+            _, S, Vh = svd(standardized, lapack_driver="gesvd", full_matrices=False)
+            Vh = torch.tensor(Vh)
         elif int(num_pc) == 0:
-            v_inverse = None
+            S = None
+            Vh = None
         else:
-            _, _, v = svd(standardized, lapack_driver="gesvd", full_matrices=False)
-            v_inverse = torch.tensor(v[: int(num_pc)])
+            _, S, Vh = svd(standardized, lapack_driver="gesvd", full_matrices=False)
+            Vh = torch.tensor(Vh[: int(num_pc)])
+
+        self.eigen_values = np.square(S) if S is not None else None
 
         self._set_buffer("zscore_mean", mean.to(self.device))
         self._set_buffer("zscore_std", std.to(self.device))
         self._set_buffer(
-            "pca_inv", v_inverse.to(self.device) if v_inverse is not None else None
+            "Vh", Vh.to(self.device) if Vh is not None else None
         )
         self._set_buffer("orig_shape", torch.tensor(orig_shape).to(self.device))
         self.fitted = True
@@ -112,7 +111,7 @@ class TPCAV(torch.nn.Module):
         return {
             "zscore_mean": mean,
             "zscore_std": std,
-            "pca_inv": v_inverse,
+            "Vh": Vh,
             "orig_shape": torch.tensor(orig_shape),
         }
 
@@ -124,13 +123,13 @@ class TPCAV(torch.nn.Module):
             raise RuntimeError("Call fit_pca before projecting activations.")
 
         y = activations.flatten(start_dim=1).to(self.device)
-        if self.pca_inv is not None:
-            V = self.pca_inv.T
+        if self.Vh is not None:
+            V = self.Vh.T
             zscore_mean = getattr(self, "zscore_mean", 0.0)
             zscore_std = getattr(self, "zscore_std", 1.0)
             y_standardized = (y - zscore_mean) / zscore_std
             y_projected = torch.matmul(y_standardized, V)
-            y_residual = y_standardized - torch.matmul(y_projected, self.pca_inv)
+            y_residual = y_standardized - torch.matmul(y_projected, self.Vh)
             return y_residual, y_projected
         else:
             return y, None
@@ -174,8 +173,12 @@ class TPCAV(torch.nn.Module):
         target_batches: Iterable,
         baseline_batches: Iterable,
         multiply_by_inputs: bool = True,
+        abs_inputs_diff: bool = True,
     ) -> Dict[str, torch.Tensor]:
-        """Compute DeepLift attributions on PCA embedding space.
+        """
+        Compute DeepLift attributions on PCA embedding space.
+
+        By default, it computes (input - baseline).abs() * multiplier to avoid double-sign effects (abs_inputs_diff=True).
 
         target_batches and baseline_batches should yield (seq, chrom) pairs of matching length.
         """
@@ -184,6 +187,8 @@ class TPCAV(torch.nn.Module):
         self.forward = self.forward_from_embeddings_at_layer
         deeplift = DeepLift(self, multiply_by_inputs=multiply_by_inputs)
 
+        custom_attr_func = _abs_attribution_func if abs_inputs_diff else None
+
         attributions = []
         for inputs, binputs in zip(target_batches, baseline_batches):
             avs = self._layer_output(*[i.to(self.device) for i in inputs])
@@ -205,7 +210,7 @@ class TPCAV(torch.nn.Module):
                 ),
                 additional_forward_args=(inputs,),
                 custom_attribution_func=(
-                    None if not multiply_by_inputs else _abs_attribution_func
+                    None if not multiply_by_inputs else custom_attr_func
                 ),
             )
             attr_residual, attr_projected = attribution
@@ -219,7 +224,7 @@ class TPCAV(torch.nn.Module):
                     inputs,
                 ),
                 custom_attribution_func=(
-                    None if not multiply_by_inputs else _abs_attribution_func
+                    None if not multiply_by_inputs else custom_attr_func
                 ),
             )[0]
 
@@ -324,7 +329,7 @@ class TPCAV(torch.nn.Module):
         Combine projected/residual embeddings into the layer activation space,
         mirroring scripts/models.py merge logic.
         """
-        y_hat = torch.matmul(avs_projected, self.pca_inv) + avs_residual
+        y_hat = torch.matmul(avs_projected, self.Vh) + avs_residual
         y_hat = y_hat * self.zscore_std + self.zscore_mean
 
         return y_hat.reshape((-1, *self.orig_shape[1:]))
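The `pca_inv` buffer is renamed to `Vh` to match SVD notation: for standardized activations `Z = U S Vh`, the rows of `Vh` are the principal axes, projection is `Z @ Vh.T`, and reconstruction (as in `forward_from_projected_and_residual` above) is `projected @ Vh + residual`. A small self-contained check of that algebra on toy data, using plain torch rather than the package:

```python
import torch

z = torch.randn(8, 16)                      # toy standardized activations
_, s, vh = torch.linalg.svd(z, full_matrices=False)
projected = z @ vh.T                        # coordinates in PC space
residual = z - projected @ vh               # part outside the PC subspace (~0 at full rank)
assert torch.allclose(projected @ vh + residual, z, atol=1e-5)
eigen_values = s.square()                   # mirrors the S**2 stored as self.eigen_values in fit_pca
```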
tpcav/utils.py CHANGED
@@ -18,6 +18,8 @@ from torch.utils.data import default_collate, get_worker_info
 
 logger = logging.getLogger(__name__)
 
+def clean_motif_name(motif_name):
+    return motif_name.replace("/", "-")
 
 def center_windows(df, window_len=1024):
     "Get center window_len bp region of the given coordinate dataframe."
@@ -568,11 +570,102 @@ class CustomMotif:
     def __len__(self):
         return len(self.consensus)
 
+    def permute(self, seed=None, min_shift=0.3, name_suffix="_perm", max_attempts=100):
+        """
+        Permute the consensus sequence and return the permuted object.
+        """
+        permuted = deepcopy(self)
+
+        rng = np.random.default_rng(seed)
+        L = len(self.consensus)
+
+        count = 0
+        while True:
+            perm = rng.permutation(L)
+            frac_moved = np.mean(perm != np.arange(L))
+            if frac_moved >= min_shift:
+                break
+            else:
+                count += 1
+                if count > max_attempts:
+                    raise ValueError(
+                        f"Could not generate a permutation with min_shift={min_shift} for motif {self.name}"
+                    )
+        permuted_consensus = ''.join([self.consensus[i] for i in perm])
+        permuted.consensus = permuted_consensus
+        permuted.name = self.name + name_suffix
+        permuted.matrix_id = self.name + name_suffix
+
+        return permuted
+
     def reverse_complement(self):
         self.consensus = Bio.Seq.reverse_complement(self.consensus)
         self.name = self.name + "_rc"
         return self
 
+class PermutedPWMMotif:
+    BASES = ("A", "C", "G", "T")
+    RC_MAP = {"A": "T", "C": "G", "G": "C", "T": "A"}
+
+    def __init__(self, motif, seed=None, min_shift=0.3, name_suffix="_perm"):
+        """
+        motif: Bio.motifs.Motif
+        seed: RNG seed
+        min_shift: fraction of positions that must move
+        """
+        self.original_motif = motif
+        self.name = motif.name + name_suffix if motif.name else "permuted_motif"
+        self.length = motif.length
+        self.alphabet = motif.alphabet
+
+        # extract PWM as dict of lists
+        pwm = {b: list(motif.pwm[b]) for b in self.BASES}
+
+        self.pwm, self.permutation = self._permute_pwm_positions(
+            pwm, seed=seed, min_shift=min_shift
+        )
+
+    def __len__(self):
+        return self.length
+
+    def _permute_pwm_positions(self, pwm, seed=None, min_shift=0.3):
+        rng = np.random.default_rng(seed)
+        L = len(pwm["A"])
+
+        count = 0
+        while True:
+            perm = rng.permutation(L)
+            frac_moved = np.mean(perm != np.arange(L))
+            if frac_moved >= min_shift:
+                break
+            else:
+                count += 1
+                if count > 100:
+                    raise ValueError(
+                        f"Could not generate a permutation with min_shift={min_shift} for motif {self.original_motif.name}"
+                    )
+
+        permuted_pwm = {b: [pwm[b][i] for i in perm] for b in self.BASES}
+
+        return permuted_pwm, perm
+
+    def reverse_complement(self):
+        """
+        Return a NEW PermutedPWMMotif with a reverse-complemented PWM.
+        """
+        rc_pwm = {b: [] for b in self.BASES}
+        L = len(self.pwm["A"])
+
+        for i in reversed(range(L)):
+            for b in self.BASES:
+                rc_base = self.RC_MAP[b]
+                rc_pwm[rc_base].append(self.pwm[b][i])
+
+        rc = deepcopy(self)
+        rc.pwm = rc_pwm
+        rc.name = self.name + "_rc"
+        return rc
+
 
 class PairedMotif:
     def __init__(self, motif1, motif2, spacing=0):
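`CustomMotif.permute` and the new `PermutedPWMMotif` both shuffle motif positions until at least a `min_shift` fraction of them move, giving control motifs that preserve base composition (or the PWM's column set) while destroying the motif pattern. A toy check of the consensus variant, using an invented motif name:

```python
from tpcav import utils

motif = utils.CustomMotif("toy_motif", "ACGTACGTAC")
shuffled = motif.permute(seed=0)
print(motif.consensus, "->", shuffled.consensus)
# same letters, but at least 30% of positions moved by default
assert sorted(shuffled.consensus) == sorted(motif.consensus)
assert shuffled.name == "toy_motif_perm"
```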
tpcav-0.2.0.dist-info/METADATA ADDED
@@ -0,0 +1,91 @@
+Metadata-Version: 2.4
+Name: tpcav
+Version: 0.2.0
+Summary: Testing with PCA projected Concept Activation Vectors
+Author-email: Jianyu Yang <yztxwd@gmail.com>
+License-Expression: MIT AND (Apache-2.0 OR BSD-2-Clause)
+Project-URL: Homepage, https://github.com/seqcode/TPCAV
+Keywords: interpretation,attribution,concept,genomics,deep learning
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch
+Requires-Dist: pandas
+Requires-Dist: numpy
+Requires-Dist: seqchromloader
+Requires-Dist: deeplift
+Requires-Dist: pyfaidx
+Requires-Dist: pybedtools
+Requires-Dist: captum
+Requires-Dist: scikit-learn
+Requires-Dist: biopython
+Requires-Dist: seaborn
+Requires-Dist: matplotlib
+Dynamic: license-file
+
+# TPCAV (Testing with PCA projected Concept Activation Vectors)
+
+This repository contains code to compute TPCAV (Testing with PCA projected Concept Activation Vectors) on deep learning models. TPCAV extends the original TCAV method by using PCA to reduce the dimensionality of the activations at a selected intermediate layer before computing Concept Activation Vectors (CAVs), which improves the consistency of the results.
+
+## Installation
+
+`pip install tpcav`
+
+## Quick start
+
+> `tpcav` only works with PyTorch models; if your model is built with another library, you should port it to PyTorch first. For TensorFlow models, you can use [tf2onnx](https://github.com/onnx/tensorflow-onnx) and [onnx2pytorch](https://github.com/Talmaj/onnx2pytorch) for the conversion.
+
+```python
+import torch
+from tpcav import helper, run_tpcav
+
+class DummyModelSeq(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer1 = torch.nn.Linear(1024, 1)
+        self.layer2 = torch.nn.Linear(4, 1)
+
+    def forward(self, seq):
+        y_hat = self.layer1(seq)
+        y_hat = y_hat.squeeze(-1)
+        y_hat = self.layer2(y_hat)
+        return y_hat
+
+# transformation function to obtain one-hot encoded sequences
+def transform_fasta_to_one_hot_seq(seq, chrom):
+    # `seq` is a list of fasta sequences
+    # `chrom` is a numpy array of bigwig signals of shape [-1, # bigwigs, len]
+    return (helper.fasta_to_one_hot_sequences(seq),)  # it has to return a tuple of inputs, even if there is only one input
+
+motif_path = "data/motif-clustering-v2.1beta_consensus_pwms.test.meme"
+bed_seq_concept = "data/hg38_rmsk.head500k.bed"
+genome_fasta = "data/hg38.analysisSet.fa"
+model = DummyModelSeq()  # load the model
+layer_name = "layer1"  # name of the layer to be interpreted
+
+# concept_fscores_dataframe: f-scores of each concept
+# motif_cav_trainers: each trainer contains the CAV weights of motifs inserted a different number of times
+# bed_cav_trainer: trainer containing the CAV weights of the sequence concepts provided in the bed file
+concept_fscores_dataframe, motif_cav_trainers, bed_cav_trainer = run_tpcav(
+    model=model,
+    layer_name=layer_name,
+    meme_motif_file=motif_path,
+    genome_fasta=genome_fasta,
+    num_motif_insertions=[4, 8],
+    bed_seq_file=bed_seq_concept,
+    output_dir="test_run_tpcav_output/",
+    input_transform_func=transform_fasta_to_one_hot_seq,
+)
+
+# check each trainer for detailed weights
+print(bed_cav_trainer.cav_weights)
+```
+
+## Detailed Usage
+
+For detailed usage, please refer to this [Jupyter notebook](https://github.com/seqcode/TPCAV/tree/main/examples/tpcav_detailed_usage.ipynb).
+
+If you find any issues, feel free to open an issue (strongly suggested) or contact [Jianyu Yang](mailto:jmy5455@psu.edu).
tpcav-0.2.0.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+tpcav/__init__.py,sha256=CpHijSyE1HMy8dlvdSaYrwN9gYMGDEJGDdsneNWnqdA,996
+tpcav/cavs.py,sha256=lqw-V45FCiCNUC8w7payuSPsbMBFy_qcOZaqPGA68js,19195
+tpcav/concepts.py,sha256=_ht4UTu2EVJh52JGnKT3PEgDHk4Q-JCpNuHfFOVmzCw,12884
+tpcav/helper.py,sha256=CcNFJEFG00pujUrthBoMInpIBz1mWIG3y5fztaiHO-c,9917
+tpcav/logging_utils.py,sha256=wug7O_5IjxjhOpQr-aq90qKMEUp1EgcPkrv26d8li6Q,281
+tpcav/tpcav_model.py,sha256=XgNLPXr6_B-Dyb7RdgsUsFnrSK6oNjqqFPOjpz1wXmM,16564
+tpcav/utils.py,sha256=s2TfC-YoH_xa73WuMqvtpuqzx6g3ne12hE90Yg9hToY,21502
+tpcav-0.2.0.dist-info/licenses/LICENSE,sha256=uC-2s0ObLnQzWFKH5aokHXo6CzxlJgeI0P3bIUCZgfU,1064
+tpcav-0.2.0.dist-info/METADATA,sha256=tz5MWTr_-veczwEOPDGUzLEwy3XU9WLlSw4IPOzddc0,3502
+tpcav-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+tpcav-0.2.0.dist-info/top_level.txt,sha256=I9veSE_WsuFYrXlcfRevqtatDyWWZNsWA3dV0CeBXVg,6
+tpcav-0.2.0.dist-info/RECORD,,
tpcav-0.2.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.9.0)
+Generator: setuptools (80.10.2)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
tpcav-0.1.0.dist-info/METADATA DELETED
@@ -1,89 +0,0 @@
-Metadata-Version: 2.4
-Name: tpcav
-Version: 0.1.0
-Summary: Testing with PCA projected Concept Activation Vectors
-Author-email: Jianyu Yang <yztxwd@gmail.com>
-License-Expression: MIT AND (Apache-2.0 OR BSD-2-Clause)
-Project-URL: Homepage, https://github.com/seqcode/TPCAV
-Keywords: interpretation,attribution,concept,genomics,deep learning
-Requires-Python: >=3.8
-Description-Content-Type: text/markdown
-License-File: LICENSE
-Requires-Dist: torch
-Requires-Dist: pandas
-Requires-Dist: numpy
-Requires-Dist: seqchromloader
-Requires-Dist: deeplift
-Requires-Dist: pyfaidx
-Requires-Dist: pybedtools
-Requires-Dist: captum
-Requires-Dist: scikit-learn
-Requires-Dist: biopython
-Requires-Dist: seaborn
-Requires-Dist: matplotlib
-Dynamic: license-file
-
-# TPCAV (Testing with PCA projected Concept Activation Vectors)
-
-Analysis pipeline for TPCAV
-
-## Dependencies
-
-You can use your own environment for the model, in addition, you need to install the following packages:
-
-- captum 0.7
-- seqchromloader 0.8.5
-- scikit-learn 1.5.2
-
-## Workflow
-
-1. Since not every saved pytorch model stores the computation graph, you need to manually add functions to let the script know how to get the activations of the intermediate layer and how to proceed from there.
-
-   There are 3 places you need to insert your own code.
-
-   - Model class definition in models.py
-     - Please first copy your class definition into `Model_Class` in the script, it already has several pre-defined class functions, you need to fill in the following two functions:
-       - `forward_until_select_layer`: this is the function that takes your model input and forward until the layer you want to compute TPCAV score on
-       - `resume_forward_from_select_layer`: this is the function that starts from the activations of your select layer and forward all the way until the end
-     - There are also functions necessary for TPCAV computation, don't change them:
-       - `forward_from_start`: this function calls `forward_until_select_layer` and `resume_forward_from_select_layer` to do a full forward pass
-       - `forward_from_projected_and_residual`: this function takes the PCA projected activations and unexplained residual to do the forward pass
-       - `project_avs_to_pca`: this function takes care of the PCA projection
-
-     > NOTE: you can modify your final output tensor to specifically explain certain transformation of your output, for example, you can take weighted sum of base pair resolution signal prediction to emphasize high signal region.
-
-   - Function `load_model` in utils.py
-     - Take care of the model initialization and load saved parameters in `load_model`, return the model instance.
-       > NOTE: you need to use your own model class definition in models.py, as we need the functions defined in step 1.
-
-   - Function `seq_transform_fn` in utils.py
-     - By default the dataloader provides one hot coded DNA array of shape (batch_size, 4, len), coded in the order [A, C, G, T], if your model takes a different kind of input, modify `seq_transform_fn` to transform the input
-
-   - Function `chrom_transform_fn` in utils.py
-     - By default the dataloader provides signal array from bigwig files of shape (batch_size, # bigwigs, len), if your model takes a different kind of chromatin input, modify `chrom_transform_fn` to transform the input, if your model is sequence only, leave it to return None.
-
-
-2. Compute CAVs on your model, example command:
-
-   ```bash
-   srun -n1 -c8 --gres=gpu:1 --mem=128G python scripts/run_tcav_sgd_pca.py \
-       cavs_test 1024 data/hg19.fa data/hg19.fa.fai \
-       --meme-motifs data/motif-clustering-v2.1beta_consensus_pwms.test.meme \
-       --bed-chrom-concepts data/ENCODE_DNase_peaks.bed
-   ```
-
-3. Then compute the layer attributions, example command:
-
-   ```bash
-   srun -n1 -c8 --gres=gpu:1 --mem=128G \
-       python scripts/compute_layer_attrs_only.py cavs_test/tpcav_model.pt \
-       data/ChIPseq.H1-hESC.MAX.conservative.all.shuf1k.narrowPeak \
-       1024 data/hg19.fa data/hg19.fa.fai cavs_test/test
-   ```
-
-4. run the jupyer notebook to generate summary of your results
-
-   ```bash
-   papermill -f scripts/compute_tcav_v2_pwm.example.yaml scripts/compute_tcav_v2_pwm.py.ipynb cavs_test/tcav_report.py.ipynb
-   ```
-
tpcav-0.1.0.dist-info/RECORD DELETED
@@ -1,12 +0,0 @@
-tpcav/__init__.py,sha256=GbO0qDy-VJjnBMZAl5TXh27znwnwEHLIadsPoWH-gY8,985
-tpcav/cavs.py,sha256=DDe7vAUdewosU6wur5qDUp2OsR0Bg-k_8R4VXFjcheI,11587
-tpcav/concepts.py,sha256=3HIybk5xrAru7OiOb3tBPKyWtfcfnA8DGa3DDCJXBxc,11775
-tpcav/helper.py,sha256=qvEmvIwm-qMKa8_8z_uhWdlYotwzMFx-8EPUPSKoveg,5014
-tpcav/logging_utils.py,sha256=wug7O_5IjxjhOpQr-aq90qKMEUp1EgcPkrv26d8li6Q,281
-tpcav/tpcav_model.py,sha256=gnM2RkBsv6mSFS2SYonziVBjHqdXoRX4cuYFmi9ITr0,16514
-tpcav/utils.py,sha256=sftnhLqeY5ExZIvXnICY0pP27jjowSRCqtPyDi0t5Yg,18509
-tpcav-0.1.0.dist-info/licenses/LICENSE,sha256=uC-2s0ObLnQzWFKH5aokHXo6CzxlJgeI0P3bIUCZgfU,1064
-tpcav-0.1.0.dist-info/METADATA,sha256=EW5LGdtqL6x6jge-oob_qgYDBuhDznxyKMkq-_YrMVA,4260
-tpcav-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tpcav-0.1.0.dist-info/top_level.txt,sha256=I9veSE_WsuFYrXlcfRevqtatDyWWZNsWA3dV0CeBXVg,6
-tpcav-0.1.0.dist-info/RECORD,,