SURE-tools 2.1.42__py3-none-any.whl → 2.1.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +58 -13
- {sure_tools-2.1.42.dist-info → sure_tools-2.1.44.dist-info}/METADATA +1 -1
- {sure_tools-2.1.42.dist-info → sure_tools-2.1.44.dist-info}/RECORD +7 -7
- {sure_tools-2.1.42.dist-info → sure_tools-2.1.44.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.42.dist-info → sure_tools-2.1.44.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.42.dist-info → sure_tools-2.1.44.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.42.dist-info → sure_tools-2.1.44.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -73,7 +73,7 @@ class PerturbFlow(nn.Module):
|
|
|
73
73
|
config_enum: str = 'parallel',
|
|
74
74
|
use_cuda: bool = False,
|
|
75
75
|
seed: int = 42,
|
|
76
|
-
zero_bias: bool = True,
|
|
76
|
+
zero_bias: bool|list = True,
|
|
77
77
|
dtype = torch.float32, # type: ignore
|
|
78
78
|
):
|
|
79
79
|
super().__init__()
|
|
@@ -97,7 +97,11 @@ class PerturbFlow(nn.Module):
|
|
|
97
97
|
self.post_layer_fct = post_layer_fct
|
|
98
98
|
self.post_act_fct = post_act_fct
|
|
99
99
|
self.hidden_layer_activation = hidden_layer_activation
|
|
100
|
-
|
|
100
|
+
if type(zero_bias) == list:
|
|
101
|
+
self.use_bias = [not x for x in zero_bias]
|
|
102
|
+
else:
|
|
103
|
+
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
|
+
#self.use_bias = not zero_bias
|
|
101
105
|
|
|
102
106
|
self.codebook_weights = None
|
|
103
107
|
|
|
@@ -185,10 +189,29 @@ class PerturbFlow(nn.Module):
|
|
|
185
189
|
use_cuda=self.use_cuda,
|
|
186
190
|
)
|
|
187
191
|
|
|
188
|
-
self.encoder_zn = MLP(
|
|
189
|
-
|
|
192
|
+
#self.encoder_zn = MLP(
|
|
193
|
+
# [self.input_size] + hidden_sizes + [[latent_dim, latent_dim]],
|
|
194
|
+
# activation=activate_fct,
|
|
195
|
+
# output_activation=[None, Exp],
|
|
196
|
+
# post_layer_fct=post_layer_fct,
|
|
197
|
+
# post_act_fct=post_act_fct,
|
|
198
|
+
# allow_broadcast=self.allow_broadcast,
|
|
199
|
+
# use_cuda=self.use_cuda,
|
|
200
|
+
#)
|
|
201
|
+
|
|
202
|
+
self.encoder_zn_loc = MLP(
|
|
203
|
+
[self.input_size] + hidden_sizes + [self.input_size * latent_dim],
|
|
204
|
+
activation=activate_fct,
|
|
205
|
+
output_activation=None,
|
|
206
|
+
post_layer_fct=post_layer_fct,
|
|
207
|
+
post_act_fct=post_act_fct,
|
|
208
|
+
allow_broadcast=self.allow_broadcast,
|
|
209
|
+
use_cuda=self.use_cuda,
|
|
210
|
+
)
|
|
211
|
+
self.encoder_zn_scale = MLP(
|
|
212
|
+
[self.input_size] + hidden_sizes + [latent_dim],
|
|
190
213
|
activation=activate_fct,
|
|
191
|
-
output_activation=
|
|
214
|
+
output_activation=Exp,
|
|
192
215
|
post_layer_fct=post_layer_fct,
|
|
193
216
|
post_act_fct=post_act_fct,
|
|
194
217
|
allow_broadcast=self.allow_broadcast,
|
|
@@ -198,7 +221,7 @@ class PerturbFlow(nn.Module):
|
|
|
198
221
|
if self.cell_factor_size>0:
|
|
199
222
|
self.cell_factor_effect = nn.ModuleList()
|
|
200
223
|
for i in np.arange(self.cell_factor_size):
|
|
201
|
-
if self.use_bias:
|
|
224
|
+
if self.use_bias[i]:
|
|
202
225
|
self.cell_factor_effect.append(MLP(
|
|
203
226
|
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
204
227
|
activation=activate_fct,
|
|
@@ -374,7 +397,11 @@ class PerturbFlow(nn.Module):
|
|
|
374
397
|
|
|
375
398
|
def guide1(self, xs):
|
|
376
399
|
with pyro.plate('data'):
|
|
377
|
-
|
|
400
|
+
batch_size = xs.shape[0]
|
|
401
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
402
|
+
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
403
|
+
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
404
|
+
zn_loc = zn_loc_3d.sum(dim=1)
|
|
378
405
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
379
406
|
|
|
380
407
|
alpha = self.encoder_n(zns)
|
|
@@ -462,7 +489,11 @@ class PerturbFlow(nn.Module):
|
|
|
462
489
|
|
|
463
490
|
def guide2(self, xs, us=None):
|
|
464
491
|
with pyro.plate('data'):
|
|
465
|
-
|
|
492
|
+
batch_size = xs.shape[0]
|
|
493
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
494
|
+
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
495
|
+
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
496
|
+
zn_loc = zn_loc_3d.sum(dim=1)
|
|
466
497
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
467
498
|
|
|
468
499
|
alpha = self.encoder_n(zns)
|
|
@@ -557,7 +588,11 @@ class PerturbFlow(nn.Module):
|
|
|
557
588
|
def guide3(self, xs, ys, embeds=None):
|
|
558
589
|
with pyro.plate('data'):
|
|
559
590
|
if embeds is None:
|
|
560
|
-
|
|
591
|
+
batch_size = xs.shape[0]
|
|
592
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
593
|
+
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
594
|
+
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
595
|
+
zn_loc = zn_loc_3d.sum(dim=1)
|
|
561
596
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
562
597
|
|
|
563
598
|
def model4(self, xs, us, ys, embeds=None):
|
|
@@ -659,7 +694,11 @@ class PerturbFlow(nn.Module):
|
|
|
659
694
|
def guide4(self, xs, us, ys, embeds=None):
|
|
660
695
|
with pyro.plate('data'):
|
|
661
696
|
if embeds is None:
|
|
662
|
-
|
|
697
|
+
batch_size = xs.shape[0]
|
|
698
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
699
|
+
zn_loc, zn_scale = self.encoder_zn_loc(xs), self.encoder_zn_scale(xs)
|
|
700
|
+
zn_loc_3d = zn_loc.view(batch_size, self.input_size, self.latent_dim)
|
|
701
|
+
zn_loc = zn_loc_3d.sum(dim=1)
|
|
663
702
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
664
703
|
|
|
665
704
|
def _total_effects(self, zns, us):
|
|
@@ -688,7 +727,11 @@ class PerturbFlow(nn.Module):
|
|
|
688
727
|
return cb
|
|
689
728
|
|
|
690
729
|
def _get_basal_embedding(self, xs):
|
|
691
|
-
zns, _ = self.encoder_zn(xs)
|
|
730
|
+
#zns, _ = self.encoder_zn(xs)
|
|
731
|
+
batch_size = xs.shape[0]
|
|
732
|
+
zns = self.encoder_zn_loc(xs)
|
|
733
|
+
zns_3d = zns.view(batch_size, self.input_size, self.latent_dim)
|
|
734
|
+
zns = zns_3d.sum(dim=1)
|
|
692
735
|
return zns
|
|
693
736
|
|
|
694
737
|
def get_basal_embedding(self,
|
|
@@ -727,7 +770,8 @@ class PerturbFlow(nn.Module):
|
|
|
727
770
|
if self.supervised_mode:
|
|
728
771
|
alpha = self.encoder_n(xs)
|
|
729
772
|
else:
|
|
730
|
-
zns,_ = self.encoder_zn(xs)
|
|
773
|
+
#zns,_ = self.encoder_zn(xs)
|
|
774
|
+
zns = self._get_basal_embedding(xs)
|
|
731
775
|
alpha = self.encoder_n(zns)
|
|
732
776
|
return alpha
|
|
733
777
|
|
|
@@ -797,7 +841,8 @@ class PerturbFlow(nn.Module):
|
|
|
797
841
|
return A
|
|
798
842
|
|
|
799
843
|
def _cell_response(self, xs, factor_idx, perturb):
|
|
800
|
-
zns,_ = self.encoder_zn(xs)
|
|
844
|
+
#zns,_ = self.encoder_zn(xs)
|
|
845
|
+
zns = self._get_basal_embedding(xs)
|
|
801
846
|
if perturb.ndim==2:
|
|
802
847
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
803
848
|
else:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=EbbEQBf9HQ6XeUw3g4DKFbe2IVCJRGmCwjQcHwCnQxU,54309
|
|
2
2
|
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -17,9 +17,9 @@ SURE/utils/__init__.py,sha256=QJUOfrXzdWSmoM0P3LH8oKEHttzCWqpDy2UF0F0dtN4,673
|
|
|
17
17
|
SURE/utils/custom_mlp.py,sha256=rHnx9jEef02zfCUdbYVCmbuHcDdIBmRgt__wpdpZvYg,8104
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.44.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.44.dist-info/METADATA,sha256=GOqZOvGRfh--3bd3Bgpkiv7E05gfRoDNflkE3c4Mt6s,2651
|
|
22
|
+
sure_tools-2.1.44.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.44.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.44.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.44.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|