tpcav 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tpcav/cavs.py +11 -9
- tpcav/tpcav_model.py +1 -1
- {tpcav-0.2.1.dist-info → tpcav-0.2.2.dist-info}/METADATA +1 -1
- tpcav-0.2.2.dist-info/RECORD +12 -0
- tpcav-0.2.1.dist-info/RECORD +0 -12
- {tpcav-0.2.1.dist-info → tpcav-0.2.2.dist-info}/WHEEL +0 -0
- {tpcav-0.2.1.dist-info → tpcav-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {tpcav-0.2.1.dist-info → tpcav-0.2.2.dist-info}/top_level.txt +0 -0
tpcav/cavs.py
CHANGED
|
@@ -205,8 +205,8 @@ class CavTrainer:
|
|
|
205
205
|
c, num_samples=num_samples
|
|
206
206
|
)
|
|
207
207
|
fscore, weight = _train(
|
|
208
|
-
concept_embeddings,
|
|
209
|
-
self.control_embeddings,
|
|
208
|
+
concept_embeddings.cpu(),
|
|
209
|
+
self.control_embeddings.cpu(),
|
|
210
210
|
Path(output_dir) / c.name,
|
|
211
211
|
self.penalty,
|
|
212
212
|
)
|
|
@@ -223,8 +223,8 @@ class CavTrainer:
|
|
|
223
223
|
res = pool.apply_async(
|
|
224
224
|
_train,
|
|
225
225
|
args=(
|
|
226
|
-
concept_embeddings,
|
|
227
|
-
self.control_embeddings,
|
|
226
|
+
concept_embeddings.cpu(),
|
|
227
|
+
self.control_embeddings.cpu(),
|
|
228
228
|
Path(output_dir) / c.name,
|
|
229
229
|
self.penalty,
|
|
230
230
|
),
|
|
@@ -413,6 +413,7 @@ def run_tpcav(
|
|
|
413
413
|
batch_size=8,
|
|
414
414
|
bws=None,
|
|
415
415
|
input_transform_func=helper.fasta_chrom_to_one_hot_seq,
|
|
416
|
+
fit_pca=True,
|
|
416
417
|
p=4
|
|
417
418
|
):
|
|
418
419
|
"""
|
|
@@ -475,11 +476,12 @@ def run_tpcav(
|
|
|
475
476
|
# create TPCAV model on top of the given model
|
|
476
477
|
tpcav_model = TPCAV(model, layer_name=layer_name)
|
|
477
478
|
# fit PCA on sampled all concept activations of the last builder (should have the most motifs)
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
479
|
+
if fit_pca:
|
|
480
|
+
tpcav_model.fit_pca(
|
|
481
|
+
concepts=motif_concept_builders[-1].all_concepts() + bed_builder.concepts if bed_builder is not None else motif_concept_builders[-1].all_concepts(),
|
|
482
|
+
num_samples_per_concept=num_samples_for_pca,
|
|
483
|
+
num_pc="full",
|
|
484
|
+
)
|
|
483
485
|
#torch.save(tpcav_model, output_path / "tpcav_model.pt")
|
|
484
486
|
|
|
485
487
|
# create trainer for computing CAVs
|
tpcav/tpcav_model.py
CHANGED
|
@@ -140,7 +140,7 @@ class TPCAV(torch.nn.Module):
|
|
|
140
140
|
residual, projected = self.project_activations(avs)
|
|
141
141
|
if projected is not None:
|
|
142
142
|
return torch.cat((projected, residual), dim=1)
|
|
143
|
-
return residual
|
|
143
|
+
return residual.detach()
|
|
144
144
|
|
|
145
145
|
def forward_from_embeddings_at_layer(
|
|
146
146
|
self,
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
tpcav/__init__.py,sha256=CpHijSyE1HMy8dlvdSaYrwN9gYMGDEJGDdsneNWnqdA,996
|
|
2
|
+
tpcav/cavs.py,sha256=m4Uhur5LwxdDfSpcznjoKCeqg4-UFjDEicpNglQ_7ss,19377
|
|
3
|
+
tpcav/concepts.py,sha256=_ht4UTu2EVJh52JGnKT3PEgDHk4Q-JCpNuHfFOVmzCw,12884
|
|
4
|
+
tpcav/helper.py,sha256=CcNFJEFG00pujUrthBoMInpIBz1mWIG3y5fztaiHO-c,9917
|
|
5
|
+
tpcav/logging_utils.py,sha256=wug7O_5IjxjhOpQr-aq90qKMEUp1EgcPkrv26d8li6Q,281
|
|
6
|
+
tpcav/tpcav_model.py,sha256=N6YmwqeBR8-QVxkV2uoqhYb3WGGMXo0ND6N8V3dIYug,16573
|
|
7
|
+
tpcav/utils.py,sha256=s2TfC-YoH_xa73WuMqvtpuqzx6g3ne12hE90Yg9hToY,21502
|
|
8
|
+
tpcav-0.2.2.dist-info/licenses/LICENSE,sha256=uC-2s0ObLnQzWFKH5aokHXo6CzxlJgeI0P3bIUCZgfU,1064
|
|
9
|
+
tpcav-0.2.2.dist-info/METADATA,sha256=lnI4YDN09v7vOcrc5a1qKfb_t-XNuKjRb31a1yILioY,3502
|
|
10
|
+
tpcav-0.2.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
11
|
+
tpcav-0.2.2.dist-info/top_level.txt,sha256=I9veSE_WsuFYrXlcfRevqtatDyWWZNsWA3dV0CeBXVg,6
|
|
12
|
+
tpcav-0.2.2.dist-info/RECORD,,
|
tpcav-0.2.1.dist-info/RECORD
DELETED
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
tpcav/__init__.py,sha256=CpHijSyE1HMy8dlvdSaYrwN9gYMGDEJGDdsneNWnqdA,996
|
|
2
|
-
tpcav/cavs.py,sha256=qXeNiTqlrCPb824ivVvZNrhHSZ6YRx2xmjdZ9JTlAgM,19299
|
|
3
|
-
tpcav/concepts.py,sha256=_ht4UTu2EVJh52JGnKT3PEgDHk4Q-JCpNuHfFOVmzCw,12884
|
|
4
|
-
tpcav/helper.py,sha256=CcNFJEFG00pujUrthBoMInpIBz1mWIG3y5fztaiHO-c,9917
|
|
5
|
-
tpcav/logging_utils.py,sha256=wug7O_5IjxjhOpQr-aq90qKMEUp1EgcPkrv26d8li6Q,281
|
|
6
|
-
tpcav/tpcav_model.py,sha256=XgNLPXr6_B-Dyb7RdgsUsFnrSK6oNjqqFPOjpz1wXmM,16564
|
|
7
|
-
tpcav/utils.py,sha256=s2TfC-YoH_xa73WuMqvtpuqzx6g3ne12hE90Yg9hToY,21502
|
|
8
|
-
tpcav-0.2.1.dist-info/licenses/LICENSE,sha256=uC-2s0ObLnQzWFKH5aokHXo6CzxlJgeI0P3bIUCZgfU,1064
|
|
9
|
-
tpcav-0.2.1.dist-info/METADATA,sha256=XaYcUWr6humOfiUhwgKrccufSqDl_XiAutlO_wCf4lo,3502
|
|
10
|
-
tpcav-0.2.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
11
|
-
tpcav-0.2.1.dist-info/top_level.txt,sha256=I9veSE_WsuFYrXlcfRevqtatDyWWZNsWA3dV0CeBXVg,6
|
|
12
|
-
tpcav-0.2.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|