tpcav 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tpcav-0.2.0/PKG-INFO ADDED
@@ -0,0 +1,91 @@
1
+ Metadata-Version: 2.4
2
+ Name: tpcav
3
+ Version: 0.2.0
4
+ Summary: Testing with PCA projected Concept Activation Vectors
5
+ Author-email: Jianyu Yang <yztxwd@gmail.com>
6
+ License-Expression: MIT AND (Apache-2.0 OR BSD-2-Clause)
7
+ Project-URL: Homepage, https://github.com/seqcode/TPCAV
8
+ Keywords: interpretation,attribution,concept,genomics,deep learning
9
+ Requires-Python: >=3.8
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Requires-Dist: torch
13
+ Requires-Dist: pandas
14
+ Requires-Dist: numpy
15
+ Requires-Dist: seqchromloader
16
+ Requires-Dist: deeplift
17
+ Requires-Dist: pyfaidx
18
+ Requires-Dist: pybedtools
19
+ Requires-Dist: captum
20
+ Requires-Dist: scikit-learn
21
+ Requires-Dist: biopython
22
+ Requires-Dist: seaborn
23
+ Requires-Dist: matplotlib
24
+ Dynamic: license-file
25
+
26
+ # TPCAV (Testing with PCA projected Concept Activation Vectors)
27
+
28
+ This repository contains code to compute TPCAV (Testing with PCA projected Concept Activation Vectors) on deep learning models. TPCAV extends the original TCAV method by using PCA to reduce the dimensionality of the activations at a selected intermediate layer before computing Concept Activation Vectors (CAVs), which improves the consistency of the results.
29
+
30
+ ## Installation
31
+
32
+ `pip install tpcav`
33
+
34
+ ## Quick start
35
+
36
+ > `tpcav` only works with PyTorch models; if your model is built using another library, you should port the model to PyTorch first. For TensorFlow models, you can use [tf2onnx](https://github.com/onnx/tensorflow-onnx) and [onnx2pytorch](https://github.com/Talmaj/onnx2pytorch) for the conversion.
37
+
38
+ ```python
39
+ import torch
40
+ from tpcav import run_tpcav
41
+
42
+ class DummyModelSeq(torch.nn.Module):
43
+ def __init__(self):
44
+ super().__init__()
45
+ self.layer1 = torch.nn.Linear(1024, 1)
46
+ self.layer2 = torch.nn.Linear(4, 1)
47
+
48
+ def forward(self, seq):
49
+ y_hat = self.layer1(seq)
50
+ y_hat = y_hat.squeeze(-1)
51
+ y_hat = self.layer2(y_hat)
52
+ return y_hat
53
+
54
+ # transformation function to obtain one-hot encoded sequences
55
+ def transform_fasta_to_one_hot_seq(seq, chrom):
56
+ # `seq` is a list of fasta sequences
57
+ # `chrom` is a numpy array of bigwig signals of shape [-1, # bigwigs, len]
58
+ return (helper.fasta_to_one_hot_sequences(seq),) # it has to return a tuple of inputs, even if there is only one input
59
+
60
+ motif_path = "data/motif-clustering-v2.1beta_consensus_pwms.test.meme"
61
+ bed_seq_concept = "data/hg38_rmsk.head500k.bed"
62
+ genome_fasta = "data/hg38.analysisSet.fa"
63
+ model = DummyModelSeq() # load the model
64
+ layer_name = "layer1" # name of the layer to be interpreted
65
+
66
+ # concept_fscores_dataframe: fscores of each concept
67
+ # motif_cav_trainers: each trainer contains the cav weights of motifs inserted different number of times
68
+ # bed_cav_trainer: trainer contains the cav weights of the sequence concepts provided in bed file
69
+ concept_fscores_dataframe, motif_cav_trainers, bed_cav_trainer = run_tpcav(
70
+ model=model,
71
+ layer_name=layer_name,
72
+ meme_motif_file=motif_path,
73
+ genome_fasta=genome_fasta,
74
+ num_motif_insertions=[4, 8],
75
+ bed_seq_file=bed_seq_concept,
76
+ output_dir="test_run_tpcav_output/",
77
+ input_transform_func=transform_fasta_to_one_hot_seq
78
+ )
79
+
80
+ # check each trainer for detailed weights
81
+ print(bed_cav_trainer.cav_weights)
82
+
83
+ ```
84
+
85
+
86
+ ## Detailed Usage
87
+
88
+ For detailed usage, please refer to this [jupyter notebook](https://github.com/seqcode/TPCAV/tree/main/examples/tpcav_detailed_usage.ipynb)
89
+
90
+ If you find any issue, feel free to open an issue (strongly suggested) or contact [Jianyu Yang](mailto:jmy5455@psu.edu).
91
+
tpcav-0.2.0/README.md ADDED
@@ -0,0 +1,66 @@
1
+ # TPCAV (Testing with PCA projected Concept Activation Vectors)
2
+
3
+ This repository contains code to compute TPCAV (Testing with PCA projected Concept Activation Vectors) on deep learning models. TPCAV extends the original TCAV method by using PCA to reduce the dimensionality of the activations at a selected intermediate layer before computing Concept Activation Vectors (CAVs), which improves the consistency of the results.
4
+
5
+ ## Installation
6
+
7
+ `pip install tpcav`
8
+
9
+ ## Quick start
10
+
11
+ > `tpcav` only works with PyTorch models; if your model is built using another library, you should port the model to PyTorch first. For TensorFlow models, you can use [tf2onnx](https://github.com/onnx/tensorflow-onnx) and [onnx2pytorch](https://github.com/Talmaj/onnx2pytorch) for the conversion.
12
+
13
+ ```python
14
+ import torch
15
+ from tpcav import run_tpcav
16
+
17
+ class DummyModelSeq(torch.nn.Module):
18
+ def __init__(self):
19
+ super().__init__()
20
+ self.layer1 = torch.nn.Linear(1024, 1)
21
+ self.layer2 = torch.nn.Linear(4, 1)
22
+
23
+ def forward(self, seq):
24
+ y_hat = self.layer1(seq)
25
+ y_hat = y_hat.squeeze(-1)
26
+ y_hat = self.layer2(y_hat)
27
+ return y_hat
28
+
29
+ # transformation function to obtain one-hot encoded sequences
30
+ def transform_fasta_to_one_hot_seq(seq, chrom):
31
+ # `seq` is a list of fasta sequences
32
+ # `chrom` is a numpy array of bigwig signals of shape [-1, # bigwigs, len]
33
+ return (helper.fasta_to_one_hot_sequences(seq),) # it has to return a tuple of inputs, even if there is only one input
34
+
35
+ motif_path = "data/motif-clustering-v2.1beta_consensus_pwms.test.meme"
36
+ bed_seq_concept = "data/hg38_rmsk.head500k.bed"
37
+ genome_fasta = "data/hg38.analysisSet.fa"
38
+ model = DummyModelSeq() # load the model
39
+ layer_name = "layer1" # name of the layer to be interpreted
40
+
41
+ # concept_fscores_dataframe: fscores of each concept
42
+ # motif_cav_trainers: each trainer contains the cav weights of motifs inserted different number of times
43
+ # bed_cav_trainer: trainer contains the cav weights of the sequence concepts provided in bed file
44
+ concept_fscores_dataframe, motif_cav_trainers, bed_cav_trainer = run_tpcav(
45
+ model=model,
46
+ layer_name=layer_name,
47
+ meme_motif_file=motif_path,
48
+ genome_fasta=genome_fasta,
49
+ num_motif_insertions=[4, 8],
50
+ bed_seq_file=bed_seq_concept,
51
+ output_dir="test_run_tpcav_output/",
52
+ input_transform_func=transform_fasta_to_one_hot_seq
53
+ )
54
+
55
+ # check each trainer for detailed weights
56
+ print(bed_cav_trainer.cav_weights)
57
+
58
+ ```
59
+
60
+
61
+ ## Detailed Usage
62
+
63
+ For detailed usage, please refer to this [jupyter notebook](https://github.com/seqcode/TPCAV/tree/main/examples/tpcav_detailed_usage.ipynb)
64
+
65
+ If you find any issue, feel free to open an issue (strongly suggested) or contact [Jianyu Yang](mailto:jmy5455@psu.edu).
66
+
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "tpcav"
7
- version = "0.1.0"
7
+ version = "0.2.0"
8
8
  description = "Testing with PCA projected Concept Activation Vectors"
9
9
  authors = [{name = "Jianyu Yang", email = "yztxwd@gmail.com"},]
10
10
  readme = "README.md"
@@ -6,7 +6,7 @@ import torch
6
6
  from Bio import motifs as Bio_motifs
7
7
  from captum.attr import DeepLift
8
8
 
9
- from tpcav import helper
9
+ from tpcav import helper, run_tpcav
10
10
  from tpcav.cavs import CavTrainer
11
11
  from tpcav.concepts import ConceptBuilder
12
12
  from tpcav.tpcav_model import TPCAV, _abs_attribution_func
@@ -47,15 +47,14 @@ def transform_fasta_to_one_hot_seq(seq, chrom):
47
47
  return (helper.fasta_to_one_hot_sequences(seq),)
48
48
 
49
49
 
50
- class CavTrainerIntegrationTest(unittest.TestCase):
50
+ class TPCAVTest(unittest.TestCase):
51
51
 
52
- def test_motif_concepts(self):
52
+ def test_motif_concepts_insertion(self):
53
53
  motif_path = Path("data") / "motif-clustering-v2.1beta_consensus_pwms.test.meme"
54
54
  self.assertTrue(motif_path.exists(), "Motif file is missing")
55
55
 
56
56
  builder = ConceptBuilder(
57
57
  genome_fasta="data/hg38.analysisSet.fa",
58
- genome_size_file="data/hg38.analysisSet.fa.fai",
59
58
  input_window_length=1024,
60
59
  bws=None,
61
60
  num_motifs=16,
@@ -102,6 +101,62 @@ class CavTrainerIntegrationTest(unittest.TestCase):
102
101
  f"Control concept has more motif matches than Motif concept, motif concept: {len(matches)}, control concept: {len(control_matches)}",
103
102
  )
104
103
 
104
+ def test_run_tpcav(self):
105
+ motif_path = Path("data") / "motif-clustering-v2.1beta_consensus_pwms.test.meme"
106
+ genome_fasta = "data/hg38.analysisSet.fa"
107
+ model = DummyModelSeq()
108
+ layer_name = "layer1"
109
+
110
+ cavs_fscores_df, motif_cav_trainers, bed_cav_trainer = run_tpcav(
111
+ model=model,
112
+ layer_name=layer_name,
113
+ meme_motif_file=str(motif_path),
114
+ genome_fasta=genome_fasta,
115
+ num_motif_insertions=[4, 8],
116
+ bed_seq_file="data/hg38_rmsk.head50k.bed",
117
+ output_dir="data/test_run_tpcav_output/",
118
+ )
119
+
120
+ def test_write_bw(self):
121
+ random_regions_1 = helper.random_regions_dataframe(
122
+ "data/hg38.analysisSet.fa.fai", 1024, 100, seed=1
123
+ )
124
+ helper.write_attrs_to_bw(torch.rand((100, 1024)).numpy(),
125
+ random_regions_1.apply(lambda x: f"{x.chrom}:{x.start}-{x.end}", axis=1).tolist(),
126
+ "data/hg38.analysisSet.fa.fai", "data/test_run_tpcav_output/input_attrs.bw")
127
+
128
+ def test_motif_concepts_against_permute_control(self):
129
+ motif_path = Path("data") / "motif-clustering-v2.1beta_consensus_pwms.test.meme"
130
+ self.assertTrue(motif_path.exists(), "Motif file is missing")
131
+
132
+ builder = ConceptBuilder(
133
+ genome_fasta="data/hg38.analysisSet.fa",
134
+ input_window_length=1024,
135
+ bws=None,
136
+ num_motifs=16,
137
+ include_reverse_complement=True,
138
+ min_samples=1000,
139
+ batch_size=8,
140
+ )
141
+
142
+ builder.build_control()
143
+
144
+ concepts_pairs = builder.add_meme_motif_concepts(str(motif_path), build_permute_control=True)
145
+ builder.apply_transform(transform_fasta_to_one_hot_seq)
146
+
147
+ tpcav_model = TPCAV(DummyModelSeq(), layer_name="layer1")
148
+ tpcav_model.fit_pca(
149
+ concepts=builder.all_concepts(),
150
+ num_samples_per_concept=10,
151
+ num_pc="full",
152
+ )
153
+ cav_trainer = CavTrainer(tpcav_model)
154
+
155
+ for motif_concept, permute_concept in concepts_pairs:
156
+ cav_trainer.set_control(permute_concept, 200)
157
+ cav_trainer.train_concepts([motif_concept,], 200, output_dir="data/cavs_permute/", num_processes=2)
158
+
159
+
105
160
  def test_all(self):
106
161
 
107
162
  motif_path = Path("data") / "motif-clustering-v2.1beta_consensus_pwms.test.meme"
@@ -109,7 +164,6 @@ class CavTrainerIntegrationTest(unittest.TestCase):
109
164
 
110
165
  builder = ConceptBuilder(
111
166
  genome_fasta="data/hg38.analysisSet.fa",
112
- genome_size_file="data/hg38.analysisSet.fa.fai",
113
167
  input_window_length=1024,
114
168
  bws=None,
115
169
  num_motifs=12,
@@ -166,7 +220,7 @@ class CavTrainerIntegrationTest(unittest.TestCase):
166
220
 
167
221
  attributions = tpcav_model.layer_attributions(
168
222
  pack_data_iters(random_regions_1), pack_data_iters(random_regions_2)
169
- )["attributions"]
223
+ )["attributions"].cpu()
170
224
 
171
225
  cav_trainer.tpcav_score("AC0001:GATA-PROP:GATA", attributions)
172
226
 
@@ -220,9 +274,12 @@ class CavTrainerIntegrationTest(unittest.TestCase):
220
274
  custom_attribution_func=_abs_attribution_func,
221
275
  )
222
276
  attr_residual, attr_projected = attributions_old
223
- attributions_old = torch.cat((attr_projected, attr_residual), dim=1)
277
+ attributions_old = torch.cat((attr_projected, attr_residual), dim=1).cpu()
224
278
 
225
- self.assertTrue(torch.allclose(attributions.cpu(), attributions_old.cpu()))
279
+ self.assertTrue(
280
+ torch.allclose(attributions.cpu(), attributions_old.cpu(), atol=1e-6),
281
+ f"Attributions do not match, max difference is {torch.abs(attributions - attributions_old).max()}",
282
+ )
226
283
 
227
284
 
228
285
  if __name__ == "__main__":
@@ -10,7 +10,7 @@ import logging
10
10
  # Set the logging level to INFO
11
11
  logging.basicConfig(level=logging.INFO)
12
12
 
13
- from .cavs import CavTrainer
13
+ from .cavs import CavTrainer, run_tpcav
14
14
  from .concepts import ConceptBuilder
15
15
  from .helper import (
16
16
  bed_to_chrom_tracks_iter,
@@ -5,11 +5,15 @@ CAV training and attribution utilities built on TPCAV.
5
5
 
6
6
  import logging
7
7
  import multiprocessing
8
+ from collections import defaultdict
9
+ import os
8
10
  from pathlib import Path
9
- from typing import Iterable, List, Optional, Tuple
11
+ from typing import Iterable, List, Optional, Tuple, Dict
10
12
 
13
+ from Bio import motifs
11
14
  import matplotlib.pyplot as plt
12
15
  import numpy as np
16
+ import pandas as pd
13
17
  import seaborn as sns
14
18
  import torch
15
19
  from sklearn.linear_model import SGDClassifier
@@ -17,8 +21,11 @@ from sklearn.metrics import precision_recall_fscore_support
17
21
  from sklearn.metrics.pairwise import cosine_similarity
18
22
  from sklearn.model_selection import GridSearchCV
19
23
  from torch.utils.data import DataLoader, TensorDataset, random_split
24
+ from sklearn.linear_model import LinearRegression
20
25
 
21
- from tpcav.tpcav_model import TPCAV
26
+ from . import helper, utils
27
+ from .concepts import ConceptBuilder
28
+ from .tpcav_model import TPCAV
22
29
 
23
30
  logger = logging.getLogger(__name__)
24
31
 
@@ -246,6 +253,16 @@ class CavTrainer:
246
253
 
247
254
  return scores
248
255
 
256
+ def tpcav_score_all_concepts(self, attributions: torch.Tensor) -> dict:
257
+ """
258
+ Compute TCAV scores for all trained concepts.
259
+ """
260
+ scores_dict = {}
261
+ for concept_name in self.cav_weights.keys():
262
+ scores = self.tpcav_score(concept_name, attributions)
263
+ scores_dict[concept_name] = scores
264
+ return scores_dict
265
+
249
266
  def tpcav_score_binary_log_ratio(
250
267
  self, concept_name: str, attributions: torch.Tensor, pseudocount: float = 1.0
251
268
  ) -> float:
@@ -259,6 +276,20 @@ class CavTrainer:
259
276
 
260
277
  return np.log((pos_count + pseudocount) / (neg_count + pseudocount))
261
278
 
279
+ def tpcav_score_all_concepts_log_ratio(
280
+ self, attributions: torch.Tensor, pseudocount: float = 1.0
281
+ ) -> dict:
282
+ """
283
+ Compute TCAV log ratio scores for all trained concepts.
284
+ """
285
+ log_ratio_dict = {}
286
+ for concept_name in self.cav_weights.keys():
287
+ log_ratio = self.tpcav_score_binary_log_ratio(
288
+ concept_name, attributions, pseudocount
289
+ )
290
+ log_ratio_dict[concept_name] = log_ratio
291
+ return log_ratio_dict
292
+
262
293
  def plot_cavs_similaritiy_heatmap(
263
294
  self,
264
295
  attributions: torch.Tensor,
@@ -274,7 +305,7 @@ class CavTrainer:
274
305
  cavs_names_pass = []
275
306
  for cname in cavs_names:
276
307
  if self.cavs_fscores[cname] >= fscore_thresh:
277
- cavs_pass.append(self.cav_weights[cname])
308
+ cavs_pass.append(self.cav_weights[cname].cpu().numpy())
278
309
  cavs_names_pass.append(cname)
279
310
  else:
280
311
  logger.info(
@@ -332,3 +363,156 @@ class CavTrainer:
332
363
  ax_log.set_title("TCAV log ratio")
333
364
 
334
365
  plt.savefig(output_path, dpi=300, bbox_inches="tight")
366
+
367
+ def load_motifs_from_meme(motif_meme_file):
368
+ return {utils.clean_motif_name(m.name): m for m in motifs.parse(open(motif_meme_file), fmt="MINIMAL")}
369
+
370
+ def compute_motif_auc_fscore(num_motif_insertions: List[int], cav_trainers: List[CavTrainer], meme_motif_file: str | None = None):
371
+ cavs_fscores_df = pd.DataFrame({nm: cav_trainer.cavs_fscores for nm, cav_trainer in zip(num_motif_insertions, cav_trainers)})
372
+ cavs_fscores_df['concept'] = list(cav_trainers[0].cavs_fscores.keys())
373
+
374
+ def compute_auc_fscore(row):
375
+ y = [row[nm] for nm in num_motif_insertions]
376
+ return np.trapz(y, num_motif_insertions) / (
377
+ num_motif_insertions[-1] - num_motif_insertions[0]
378
+ )
379
+
380
+ cavs_fscores_df["AUC_fscores"] = cavs_fscores_df.apply(compute_auc_fscore, axis=1)
381
+
382
+ # if motif instances are provided, fit linear regression curve to remove the dependency of f-scores on information content and motif length
383
+ if meme_motif_file is not None:
384
+ motifs_dict = load_motifs_from_meme(meme_motif_file)
385
+ cavs_fscores_df['information_content'] = cavs_fscores_df.apply(lambda x: motifs_dict[x['concept']].relative_entropy.sum(), axis=1)
386
+ cavs_fscores_df['motif_len'] = cavs_fscores_df.apply(lambda x: len(motifs_dict[x['concept']].consensus), axis=1)
387
+
388
+ model = LinearRegression()
389
+ model.fit(cavs_fscores_df[['information_content', 'motif_len']].to_numpy(), cavs_fscores_df['AUC_fscores'].to_numpy()[:, np.newaxis])
390
+
391
+ y_pred = model.predict(cavs_fscores_df[['information_content', 'motif_len']].to_numpy())
392
+ residuals = cavs_fscores_df['AUC_fscores'].to_numpy() - y_pred.flatten()
393
+ cavs_fscores_df['AUC_fscores_residual'] = residuals
394
+
395
+ cavs_fscores_df.sort_values("AUC_fscores_residual", ascending=False, inplace=True)
396
+ else:
397
+ cavs_fscores_df.sort_values("AUC_fscores", ascending=False, inplace=True)
398
+
399
+ return cavs_fscores_df
400
+
401
+ def run_tpcav(
402
+ model,
403
+ layer_name: str,
404
+ meme_motif_file: str,
405
+ genome_fasta: str,
406
+ num_motif_insertions: List[int] = [4, 8, 16],
407
+ bed_seq_file: Optional[str] = None,
408
+ bed_chrom_file: Optional[str] = None,
409
+ output_dir: str = "tpcav/",
410
+ num_samples_for_pca=10,
411
+ num_samples_for_cav=1000,
412
+ bws=None,
413
+ input_transform_func=helper.fasta_chrom_to_one_hot_seq,
414
+ ):
415
+ """
416
+ One-stop function to compute CAVs on motif concepts and bed concepts, compute AUC of motif concept f-scores after correction
417
+ """
418
+ if not os.path.exists(output_dir):
419
+ os.makedirs(output_dir)
420
+
421
+ output_path = Path(output_dir)
422
+ # create concept builder to generate concepts
423
+ ## motif concepts
424
+ motif_concepts_pairs = {}
425
+ motif_concept_builders = []
426
+ num_motif_insertions.sort()
427
+ for nm in num_motif_insertions:
428
+ builder = ConceptBuilder(
429
+ genome_fasta=genome_fasta,
430
+ input_window_length=1024,
431
+ bws=bws,
432
+ num_motifs=nm,
433
+ include_reverse_complement=True,
434
+ min_samples=num_samples_for_cav,
435
+ batch_size=8,
436
+ )
437
+ # use random regions as control
438
+ builder.build_control()
439
+ # use meme motif PWMs to build motif concepts, one concept per motif
440
+ concepts_pairs = builder.add_meme_motif_concepts(str(meme_motif_file))
441
+
442
+ # apply transform to convert fasta sequences to one-hot encoded sequences
443
+ builder.apply_transform(input_transform_func)
444
+
445
+ motif_concepts_pairs[nm] = concepts_pairs
446
+ motif_concept_builders.append(builder)
447
+
448
+ ## bed concepts (optional)
449
+ if bed_seq_file is not None or bed_chrom_file is not None:
450
+ bed_builder = ConceptBuilder(
451
+ genome_fasta=genome_fasta,
452
+ input_window_length=1024,
453
+ bws=bws,
454
+ num_motifs=0,
455
+ include_reverse_complement=True,
456
+ min_samples=num_samples_for_cav,
457
+ batch_size=8,
458
+ )
459
+ # use random regions as control
460
+ bed_builder.build_control()
461
+ if bed_seq_file is not None:
462
+ # build concepts from fasta sequences in bed file
463
+ bed_builder.add_bed_sequence_concepts(bed_seq_file)
464
+ if bed_chrom_file is not None:
465
+ # build concepts from chromatin tracks in bed file
466
+ bed_builder.add_bed_chrom_concepts(bed_chrom_file)
467
+ # apply transform to convert fasta sequences to one-hot encoded sequences
468
+ bed_builder.apply_transform(input_transform_func)
469
+ else:
470
+ bed_builder = None
471
+
472
+ # create TPCAV model on top of the given model
473
+ tpcav_model = TPCAV(model, layer_name=layer_name)
474
+ # fit PCA on sampled activations of all concepts from the last builder (should have the most motifs)
475
+ tpcav_model.fit_pca(
476
+ concepts=motif_concept_builders[-1].all_concepts() + bed_builder.concepts if bed_builder is not None else motif_concept_builders[-1].all_concepts(),
477
+ num_samples_per_concept=num_samples_for_pca,
478
+ num_pc="full",
479
+ )
480
+ #torch.save(tpcav_model, output_path / "tpcav_model.pt")
481
+
482
+ # create trainer for computing CAVs
483
+ motif_cav_trainers = {}
484
+ for nm in num_motif_insertions:
485
+ cav_trainer = CavTrainer(tpcav_model, penalty="l2")
486
+ for motif_concept, permuted_concept in motif_concepts_pairs[nm]:
487
+ # set control concept for CAV training
488
+ cav_trainer.set_control(
489
+ permuted_concept, num_samples=num_samples_for_cav
490
+ )
491
+ # train CAVs for all concepts
492
+ cav_trainer.train_concepts(
493
+ [motif_concept,],
494
+ num_samples_for_cav,
495
+ output_dir=str(output_path / f"cavs_{nm}_motifs/"),
496
+ num_processes=4,
497
+ )
498
+ motif_cav_trainers[nm] = cav_trainer
499
+ if bed_builder is not None:
500
+ bed_cav_trainer = CavTrainer(tpcav_model, penalty="l2")
501
+ bed_cav_trainer.set_control(
502
+ bed_builder.control_concepts[0], num_samples=num_samples_for_cav
503
+ )
504
+ bed_cav_trainer.train_concepts(
505
+ bed_builder.concepts,
506
+ num_samples_for_cav,
507
+ output_dir=str(output_path / f"cavs_bed_concepts/"),
508
+ num_processes=4,
509
+ )
510
+ else:
511
+ bed_cav_trainer = None
512
+
513
+ if len(num_motif_insertions) > 1:
514
+ cavs_fscores_df = compute_motif_auc_fscore(num_motif_insertions, list(motif_cav_trainers.values()), meme_motif_file=meme_motif_file)
515
+ else:
516
+ cavs_fscores_df = None
517
+
518
+ return cavs_fscores_df, motif_cav_trainers, bed_cav_trainer