tpcav 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tpcav/cavs.py CHANGED
@@ -205,8 +205,8 @@ class CavTrainer:
205
205
  c, num_samples=num_samples
206
206
  )
207
207
  fscore, weight = _train(
208
- concept_embeddings,
209
- self.control_embeddings,
208
+ concept_embeddings.cpu(),
209
+ self.control_embeddings.cpu(),
210
210
  Path(output_dir) / c.name,
211
211
  self.penalty,
212
212
  )
@@ -223,8 +223,8 @@ class CavTrainer:
223
223
  res = pool.apply_async(
224
224
  _train,
225
225
  args=(
226
- concept_embeddings,
227
- self.control_embeddings,
226
+ concept_embeddings.cpu(),
227
+ self.control_embeddings.cpu(),
228
228
  Path(output_dir) / c.name,
229
229
  self.penalty,
230
230
  ),
@@ -409,8 +409,12 @@ def run_tpcav(
409
409
  output_dir: str = "tpcav/",
410
410
  num_samples_for_pca=10,
411
411
  num_samples_for_cav=1000,
412
+ input_window_length=1024,
413
+ batch_size=8,
412
414
  bws=None,
413
415
  input_transform_func=helper.fasta_chrom_to_one_hot_seq,
416
+ fit_pca=True,
417
+ p=4
414
418
  ):
415
419
  """
416
420
  One-stop function to compute CAVs on motif concepts and bed concepts, compute AUC of motif concept f-scores after correction
@@ -427,12 +431,12 @@ def run_tpcav(
427
431
  for nm in num_motif_insertions:
428
432
  builder = ConceptBuilder(
429
433
  genome_fasta=genome_fasta,
430
- input_window_length=1024,
434
+ input_window_length=input_window_length,
431
435
  bws=bws,
432
436
  num_motifs=nm,
433
437
  include_reverse_complement=True,
434
438
  min_samples=num_samples_for_cav,
435
- batch_size=8,
439
+ batch_size=batch_size,
436
440
  )
437
441
  # use random regions as control
438
442
  builder.build_control()
@@ -449,12 +453,12 @@ def run_tpcav(
449
453
  if bed_seq_file is not None or bed_chrom_file is not None:
450
454
  bed_builder = ConceptBuilder(
451
455
  genome_fasta=genome_fasta,
452
- input_window_length=1024,
456
+ input_window_length=input_window_length,
453
457
  bws=bws,
454
458
  num_motifs=0,
455
459
  include_reverse_complement=True,
456
460
  min_samples=num_samples_for_cav,
457
- batch_size=8,
461
+ batch_size=batch_size,
458
462
  )
459
463
  # use random regions as control
460
464
  bed_builder.build_control()
@@ -472,11 +476,12 @@ def run_tpcav(
472
476
  # create TPCAV model on top of the given model
473
477
  tpcav_model = TPCAV(model, layer_name=layer_name)
474
478
  # fit PCA on sampled all concept activations of the last builder (should have the most motifs)
475
- tpcav_model.fit_pca(
476
- concepts=motif_concept_builders[-1].all_concepts() + bed_builder.concepts if bed_builder is not None else motif_concept_builders[-1].all_concepts(),
477
- num_samples_per_concept=num_samples_for_pca,
478
- num_pc="full",
479
- )
479
+ if fit_pca:
480
+ tpcav_model.fit_pca(
481
+ concepts=motif_concept_builders[-1].all_concepts() + bed_builder.concepts if bed_builder is not None else motif_concept_builders[-1].all_concepts(),
482
+ num_samples_per_concept=num_samples_for_pca,
483
+ num_pc="full",
484
+ )
480
485
  #torch.save(tpcav_model, output_path / "tpcav_model.pt")
481
486
 
482
487
  # create trainer for computing CAVs
@@ -493,7 +498,7 @@ def run_tpcav(
493
498
  [motif_concept,],
494
499
  num_samples_for_cav,
495
500
  output_dir=str(output_path / f"cavs_{nm}_motifs/"),
496
- num_processes=4,
501
+ num_processes=p,
497
502
  )
498
503
  motif_cav_trainers[nm] = cav_trainer
499
504
  if bed_builder is not None:
@@ -505,7 +510,7 @@ def run_tpcav(
505
510
  bed_builder.concepts,
506
511
  num_samples_for_cav,
507
512
  output_dir=str(output_path / f"cavs_bed_concepts/"),
508
- num_processes=4,
513
+ num_processes=p,
509
514
  )
510
515
  else:
511
516
  bed_cav_trainer = None
tpcav/tpcav_model.py CHANGED
@@ -140,7 +140,7 @@ class TPCAV(torch.nn.Module):
140
140
  residual, projected = self.project_activations(avs)
141
141
  if projected is not None:
142
142
  return torch.cat((projected, residual), dim=1)
143
- return residual
143
+ return residual.detach()
144
144
 
145
145
  def forward_from_embeddings_at_layer(
146
146
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tpcav
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Testing with PCA projected Concept Activation Vectors
5
5
  Author-email: Jianyu Yang <yztxwd@gmail.com>
6
6
  License-Expression: MIT AND (Apache-2.0 OR BSD-2-Clause)
@@ -0,0 +1,12 @@
1
+ tpcav/__init__.py,sha256=CpHijSyE1HMy8dlvdSaYrwN9gYMGDEJGDdsneNWnqdA,996
2
+ tpcav/cavs.py,sha256=m4Uhur5LwxdDfSpcznjoKCeqg4-UFjDEicpNglQ_7ss,19377
3
+ tpcav/concepts.py,sha256=_ht4UTu2EVJh52JGnKT3PEgDHk4Q-JCpNuHfFOVmzCw,12884
4
+ tpcav/helper.py,sha256=CcNFJEFG00pujUrthBoMInpIBz1mWIG3y5fztaiHO-c,9917
5
+ tpcav/logging_utils.py,sha256=wug7O_5IjxjhOpQr-aq90qKMEUp1EgcPkrv26d8li6Q,281
6
+ tpcav/tpcav_model.py,sha256=N6YmwqeBR8-QVxkV2uoqhYb3WGGMXo0ND6N8V3dIYug,16573
7
+ tpcav/utils.py,sha256=s2TfC-YoH_xa73WuMqvtpuqzx6g3ne12hE90Yg9hToY,21502
8
+ tpcav-0.2.2.dist-info/licenses/LICENSE,sha256=uC-2s0ObLnQzWFKH5aokHXo6CzxlJgeI0P3bIUCZgfU,1064
9
+ tpcav-0.2.2.dist-info/METADATA,sha256=lnI4YDN09v7vOcrc5a1qKfb_t-XNuKjRb31a1yILioY,3502
10
+ tpcav-0.2.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
11
+ tpcav-0.2.2.dist-info/top_level.txt,sha256=I9veSE_WsuFYrXlcfRevqtatDyWWZNsWA3dV0CeBXVg,6
12
+ tpcav-0.2.2.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- tpcav/__init__.py,sha256=CpHijSyE1HMy8dlvdSaYrwN9gYMGDEJGDdsneNWnqdA,996
2
- tpcav/cavs.py,sha256=lqw-V45FCiCNUC8w7payuSPsbMBFy_qcOZaqPGA68js,19195
3
- tpcav/concepts.py,sha256=_ht4UTu2EVJh52JGnKT3PEgDHk4Q-JCpNuHfFOVmzCw,12884
4
- tpcav/helper.py,sha256=CcNFJEFG00pujUrthBoMInpIBz1mWIG3y5fztaiHO-c,9917
5
- tpcav/logging_utils.py,sha256=wug7O_5IjxjhOpQr-aq90qKMEUp1EgcPkrv26d8li6Q,281
6
- tpcav/tpcav_model.py,sha256=XgNLPXr6_B-Dyb7RdgsUsFnrSK6oNjqqFPOjpz1wXmM,16564
7
- tpcav/utils.py,sha256=s2TfC-YoH_xa73WuMqvtpuqzx6g3ne12hE90Yg9hToY,21502
8
- tpcav-0.2.0.dist-info/licenses/LICENSE,sha256=uC-2s0ObLnQzWFKH5aokHXo6CzxlJgeI0P3bIUCZgfU,1064
9
- tpcav-0.2.0.dist-info/METADATA,sha256=tz5MWTr_-veczwEOPDGUzLEwy3XU9WLlSw4IPOzddc0,3502
10
- tpcav-0.2.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
11
- tpcav-0.2.0.dist-info/top_level.txt,sha256=I9veSE_WsuFYrXlcfRevqtatDyWWZNsWA3dV0CeBXVg,6
12
- tpcav-0.2.0.dist-info/RECORD,,
File without changes