tpcav 0.2.1__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tpcav
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: Testing with PCA projected Concept Activation Vectors
5
5
  Author-email: Jianyu Yang <yztxwd@gmail.com>
6
6
  License-Expression: MIT AND (Apache-2.0 OR BSD-2-Clause)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "tpcav"
7
- version = "0.2.1"
7
+ version = "0.2.3"
8
8
  description = "Testing with PCA projected Concept Activation Vectors"
9
9
  authors = [{name = "Jianyu Yang", email = "yztxwd@gmail.com"},]
10
10
  readme = "README.md"
@@ -182,7 +182,9 @@ class TPCAVTest(unittest.TestCase):
182
182
 
183
183
  self.assertTupleEqual(batch[0].shape, (builder.batch_size, 4, 1024))
184
184
 
185
- tpcav_model = TPCAV(DummyModelSeq(), layer_name="layer1")
185
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
186
+
187
+ tpcav_model = TPCAV(DummyModelSeq(), layer_name="layer1").to(device)
186
188
  tpcav_model.fit_pca(
187
189
  concepts=builder.all_concepts(),
188
190
  num_samples_per_concept=10,
@@ -205,8 +205,8 @@ class CavTrainer:
205
205
  c, num_samples=num_samples
206
206
  )
207
207
  fscore, weight = _train(
208
- concept_embeddings,
209
- self.control_embeddings,
208
+ concept_embeddings.cpu(),
209
+ self.control_embeddings.cpu(),
210
210
  Path(output_dir) / c.name,
211
211
  self.penalty,
212
212
  )
@@ -223,8 +223,8 @@ class CavTrainer:
223
223
  res = pool.apply_async(
224
224
  _train,
225
225
  args=(
226
- concept_embeddings,
227
- self.control_embeddings,
226
+ concept_embeddings.cpu(),
227
+ self.control_embeddings.cpu(),
228
228
  Path(output_dir) / c.name,
229
229
  self.penalty,
230
230
  ),
@@ -413,6 +413,7 @@ def run_tpcav(
413
413
  batch_size=8,
414
414
  bws=None,
415
415
  input_transform_func=helper.fasta_chrom_to_one_hot_seq,
416
+ fit_pca=True,
416
417
  p=4
417
418
  ):
418
419
  """
@@ -475,11 +476,12 @@ def run_tpcav(
475
476
  # create TPCAV model on top of the given model
476
477
  tpcav_model = TPCAV(model, layer_name=layer_name)
477
478
  # fit PCA on sampled all concept activations of the last builder (should have the most motifs)
478
- tpcav_model.fit_pca(
479
- concepts=motif_concept_builders[-1].all_concepts() + bed_builder.concepts if bed_builder is not None else motif_concept_builders[-1].all_concepts(),
480
- num_samples_per_concept=num_samples_for_pca,
481
- num_pc="full",
482
- )
479
+ if fit_pca:
480
+ tpcav_model.fit_pca(
481
+ concepts=motif_concept_builders[-1].all_concepts() + bed_builder.concepts if bed_builder is not None else motif_concept_builders[-1].all_concepts(),
482
+ num_samples_per_concept=num_samples_for_pca,
483
+ num_pc="full",
484
+ )
483
485
  #torch.save(tpcav_model, output_path / "tpcav_model.pt")
484
486
 
485
487
  # create trainer for computing CAVs
@@ -41,6 +41,8 @@ def _construct_motif_concept_dataloader_from_control(
41
41
  motif_mode: str,
42
42
  batch_size: int,
43
43
  num_workers: int,
44
+ start_buffer=0,
45
+ end_buffer=0
44
46
  ) -> DataLoader:
45
47
  """Mirror the motif-based dataloader logic used in the TCAV script."""
46
48
  datasets = []
@@ -51,8 +53,8 @@ def _construct_motif_concept_dataloader_from_control(
51
53
  motif=motif,
52
54
  motif_mode=motif_mode,
53
55
  num_motifs=num_motifs,
54
- start_buffer=0,
55
- end_buffer=0,
56
+ start_buffer=start_buffer,
57
+ end_buffer=end_buffer,
56
58
  print_warning=False,
57
59
  infinite=False,
58
60
  )
@@ -190,7 +192,7 @@ class ConceptBuilder:
190
192
  added.append(concept)
191
193
  return added
192
194
 
193
- def build_motif_concept(self, motifs, concept_name, control_regions=None, motif_mode="pwm"):
195
+ def build_motif_concept(self, motifs, concept_name, control_regions=None, motif_mode="pwm", start_buffer=0, end_buffer=0):
194
196
  if control_regions is None:
195
197
  if not self.control_concepts:
196
198
  raise ValueError("Call build_control or pass control_regions first.")
@@ -207,6 +209,8 @@ class ConceptBuilder:
207
209
  motif_mode=motif_mode,
208
210
  batch_size=self.batch_size,
209
211
  num_workers=self.num_workers,
212
+ start_buffer=start_buffer,
213
+ end_buffer=end_buffer
210
214
  )
211
215
  concept = Concept(
212
216
  id=self._reserve_id(),
@@ -140,7 +140,7 @@ class TPCAV(torch.nn.Module):
140
140
  residual, projected = self.project_activations(avs)
141
141
  if projected is not None:
142
142
  return torch.cat((projected, residual), dim=1)
143
- return residual
143
+ return residual.detach()
144
144
 
145
145
  def forward_from_embeddings_at_layer(
146
146
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tpcav
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: Testing with PCA projected Concept Activation Vectors
5
5
  Author-email: Jianyu Yang <yztxwd@gmail.com>
6
6
  License-Expression: MIT AND (Apache-2.0 OR BSD-2-Clause)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes