SURE-tools 2.1.43__py3-none-any.whl → 2.1.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +51 -10
- {sure_tools-2.1.43.dist-info → sure_tools-2.1.44.dist-info}/METADATA +1 -1
- {sure_tools-2.1.43.dist-info → sure_tools-2.1.44.dist-info}/RECORD +7 -7
- {sure_tools-2.1.43.dist-info → sure_tools-2.1.44.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.43.dist-info → sure_tools-2.1.44.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.43.dist-info → sure_tools-2.1.44.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.43.dist-info → sure_tools-2.1.44.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -189,10 +189,29 @@ class PerturbFlow(nn.Module):
|
|
|
189
189
|
use_cuda=self.use_cuda,
|
|
190
190
|
)
|
|
191
191
|
|
|
192
|
-
self.encoder_zn = MLP(
|
|
193
|
-
|
|
192
|
+
#self.encoder_zn = MLP(
|
|
193
|
+
# [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
|
|
194
|
+
# activation=activate_fct,
|
|
195
|
+
# output_activation=[None, Exp],
|
|
196
|
+
# post_layer_fct=post_layer_fct,
|
|
197
|
+
# post_act_fct=post_act_fct,
|
|
198
|
+
# allow_broadcast=self.allow_broadcast,
|
|
199
|
+
# use_cuda=self.use_cuda,
|
|
200
|
+
#)
|
|
201
|
+
|
|
202
|
+
self.encoder_zn_loc = MLP(
|
|
203
|
+
[self.input_size] + hidden_sizes + [self.input_size * latent_dim],
|
|
204
|
+
activation=activate_fct,
|
|
205
|
+
output_activation=None,
|
|
206
|
+
post_layer_fct=post_layer_fct,
|
|
207
|
+
post_act_fct=post_act_fct,
|
|
208
|
+
allow_broadcast=self.allow_broadcast,
|
|
209
|
+
use_cuda=self.use_cuda,
|
|
210
|
+
)
|
|
211
|
+
self.encoder_zn_scale = MLP(
|
|
212
|
+
[self.input_size] + hidden_sizes + [latent_dim],
|
|
194
213
|
activation=activate_fct,
|
|
195
|
-
output_activation=
|
|
214
|
+
output_activation=Exp,
|
|
196
215
|
post_layer_fct=post_layer_fct,
|
|
197
216
|
post_act_fct=post_act_fct,
|
|
198
217
|
allow_broadcast=self.allow_broadcast,
|
|
@@ -378,7 +397,11 @@ class PerturbFlow(nn.Module):
|
|
|
378
397
|
|
|
379
398
|
def guide1(self, xs):
|
|
380
399
|
with pyro.plate('data'):
|
|
381
|
-
|
|
400
|
+
batch_size = xs.shape[0]
|
|
401
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
402
|
+
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
403
|
+
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
404
|
+
zn_loc = zn_loc_3d.sum(dim=1)
|
|
382
405
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
383
406
|
|
|
384
407
|
alpha = self.encoder_n(zns)
|
|
@@ -466,7 +489,11 @@ class PerturbFlow(nn.Module):
|
|
|
466
489
|
|
|
467
490
|
def guide2(self, xs, us=None):
|
|
468
491
|
with pyro.plate('data'):
|
|
469
|
-
|
|
492
|
+
batch_size = xs.shape[0]
|
|
493
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
494
|
+
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
495
|
+
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
496
|
+
zn_loc = zn_loc_3d.sum(dim=1)
|
|
470
497
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
471
498
|
|
|
472
499
|
alpha = self.encoder_n(zns)
|
|
@@ -561,7 +588,11 @@ class PerturbFlow(nn.Module):
|
|
|
561
588
|
def guide3(self, xs, ys, embeds=None):
|
|
562
589
|
with pyro.plate('data'):
|
|
563
590
|
if embeds is None:
|
|
564
|
-
|
|
591
|
+
batch_size = xs.shape[0]
|
|
592
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
593
|
+
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
594
|
+
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
595
|
+
zn_loc = zn_loc_3d.sum(dim=1)
|
|
565
596
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
566
597
|
|
|
567
598
|
def model4(self, xs, us, ys, embeds=None):
|
|
@@ -663,7 +694,11 @@ class PerturbFlow(nn.Module):
|
|
|
663
694
|
def guide4(self, xs, us, ys, embeds=None):
|
|
664
695
|
with pyro.plate('data'):
|
|
665
696
|
if embeds is None:
|
|
666
|
-
|
|
697
|
+
batch_size = xs.shape[0]
|
|
698
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
699
|
+
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
700
|
+
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
701
|
+
zn_loc = zn_loc_3d.sum(dim=1)
|
|
667
702
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
668
703
|
|
|
669
704
|
def _total_effects(self, zns, us):
|
|
@@ -692,7 +727,11 @@ class PerturbFlow(nn.Module):
|
|
|
692
727
|
return cb
|
|
693
728
|
|
|
694
729
|
def _get_basal_embedding(self, xs):
|
|
695
|
-
zns, _ = self.encoder_zn(xs)
|
|
730
|
+
#zns, _ = self.encoder_zn(xs)
|
|
731
|
+
batch_size = xs.shape[0]
|
|
732
|
+
zns = self.encoder_zn_loc(xs)
|
|
733
|
+
zns_3d = zns.view(batch_size, self.input_size, self.latent_dim)
|
|
734
|
+
zns = zns_3d.sum(dim=1)
|
|
696
735
|
return zns
|
|
697
736
|
|
|
698
737
|
def get_basal_embedding(self,
|
|
@@ -731,7 +770,8 @@ class PerturbFlow(nn.Module):
|
|
|
731
770
|
if self.supervised_mode:
|
|
732
771
|
alpha = self.encoder_n(xs)
|
|
733
772
|
else:
|
|
734
|
-
zns,_ = self.encoder_zn(xs)
|
|
773
|
+
#zns,_ = self.encoder_zn(xs)
|
|
774
|
+
zns = self._get_basal_embedding(xs)
|
|
735
775
|
alpha = self.encoder_n(zns)
|
|
736
776
|
return alpha
|
|
737
777
|
|
|
@@ -801,7 +841,8 @@ class PerturbFlow(nn.Module):
|
|
|
801
841
|
return A
|
|
802
842
|
|
|
803
843
|
def _cell_response(self, xs, factor_idx, perturb):
|
|
804
|
-
zns,_ = self.encoder_zn(xs)
|
|
844
|
+
#zns,_ = self.encoder_zn(xs)
|
|
845
|
+
zns = self._get_basal_embedding(xs)
|
|
805
846
|
if perturb.ndim==2:
|
|
806
847
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
807
848
|
else:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=EbbEQBf9HQ6XeUw3g4DKFbe2IVCJRGmCwjQcHwCnQxU,54309
|
|
2
2
|
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=QJUOfrXzdWSmoM0P3LH8oKEHttzCWqpDy2UF0F0dtN4,673
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=rHnx9jEef02zfCUdbYVCmbuHcDdIBmRgt__wpdpZvYg,8104
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.44.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.44.dist-info/METADATA,sha256=GOqZOvGRfh--3bd3Bgpkiv7E05gfRoDNflkE3c4Mt6s,2651
|
|
22
|
+
sure_tools-2.1.44.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.44.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.44.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.44.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|