SURE-tools 2.1.83__py3-none-any.whl → 2.1.85__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +14 -8
- SURE/utils/custom_mlp.py +4 -1
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.85.dist-info}/METADATA +1 -1
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.85.dist-info}/RECORD +8 -8
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.85.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.85.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.85.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.85.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -74,6 +74,7 @@ class PerturbFlow(nn.Module):
|
|
|
74
74
|
use_cuda: bool = True,
|
|
75
75
|
seed: int = 42,
|
|
76
76
|
zero_bias: bool|list = True,
|
|
77
|
+
enumrate: bool = False,
|
|
77
78
|
dtype = torch.float32, # type: ignore
|
|
78
79
|
):
|
|
79
80
|
super().__init__()
|
|
@@ -102,6 +103,7 @@ class PerturbFlow(nn.Module):
|
|
|
102
103
|
else:
|
|
103
104
|
self.use_bias = [not zero_bias] * self.cell_factor_size
|
|
104
105
|
#self.use_bias = not zero_bias
|
|
106
|
+
self.enumerate = enumerate
|
|
105
107
|
|
|
106
108
|
self.codebook_weights = None
|
|
107
109
|
|
|
@@ -204,7 +206,7 @@ class PerturbFlow(nn.Module):
|
|
|
204
206
|
for i in np.arange(self.cell_factor_size):
|
|
205
207
|
if self.use_bias[i]:
|
|
206
208
|
self.cell_factor_effect.append(MLP(
|
|
207
|
-
[self.
|
|
209
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
208
210
|
activation=activate_fct,
|
|
209
211
|
output_activation=None,
|
|
210
212
|
post_layer_fct=post_layer_fct,
|
|
@@ -215,7 +217,7 @@ class PerturbFlow(nn.Module):
|
|
|
215
217
|
)
|
|
216
218
|
else:
|
|
217
219
|
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
218
|
-
[self.
|
|
220
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
219
221
|
activation=activate_fct,
|
|
220
222
|
output_activation=None,
|
|
221
223
|
post_layer_fct=post_layer_fct,
|
|
@@ -429,7 +431,10 @@ class PerturbFlow(nn.Module):
|
|
|
429
431
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
430
432
|
|
|
431
433
|
if self.cell_factor_size>0:
|
|
432
|
-
|
|
434
|
+
if self.enumerate:
|
|
435
|
+
zus = self._total_effects(zn_loc, us)
|
|
436
|
+
else:
|
|
437
|
+
zus = self._total_effects(zns, us)
|
|
433
438
|
zs = zns+zus
|
|
434
439
|
else:
|
|
435
440
|
zs = zns
|
|
@@ -631,7 +636,10 @@ class PerturbFlow(nn.Module):
|
|
|
631
636
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
632
637
|
# else:
|
|
633
638
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
634
|
-
|
|
639
|
+
if self.enumerate:
|
|
640
|
+
zus = self._total_effects(zn_loc, us)
|
|
641
|
+
else:
|
|
642
|
+
zus = self._total_effects(zns, us)
|
|
635
643
|
zs = zns+zus
|
|
636
644
|
else:
|
|
637
645
|
zs = zns
|
|
@@ -842,8 +850,7 @@ class PerturbFlow(nn.Module):
|
|
|
842
850
|
|
|
843
851
|
def _cell_response(self, xs, factor_idx, perturb):
|
|
844
852
|
#zns,_ = self.encoder_zn(xs)
|
|
845
|
-
|
|
846
|
-
zns = self._soft_assignments(xs)
|
|
853
|
+
zns,_ = self._get_basal_embedding(xs)
|
|
847
854
|
if perturb.ndim==2:
|
|
848
855
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
849
856
|
else:
|
|
@@ -877,8 +884,7 @@ class PerturbFlow(nn.Module):
|
|
|
877
884
|
return Z
|
|
878
885
|
|
|
879
886
|
def get_metacell_response(self, factor_idx, perturb):
|
|
880
|
-
|
|
881
|
-
zs = self._get_codebook_identity()
|
|
887
|
+
zs = self._get_codebook()
|
|
882
888
|
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
883
889
|
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
884
890
|
return tensor_to_numpy(ms)
|
SURE/utils/custom_mlp.py
CHANGED
|
@@ -239,7 +239,10 @@ class ZeroBiasMLP(nn.Module):
|
|
|
239
239
|
def forward(self, x):
|
|
240
240
|
y = self.mlp(x)
|
|
241
241
|
mask = torch.zeros_like(y)
|
|
242
|
-
|
|
242
|
+
if len(y.shape)==2:
|
|
243
|
+
mask[x[1][:,0]>0,:] = 1
|
|
244
|
+
elif len(y.shape)==3:
|
|
245
|
+
mask[:,x[1][:,0]>0,:] = 1
|
|
243
246
|
return y*mask
|
|
244
247
|
|
|
245
248
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=hu4kD6PphVqAG8GlITk44trwVfE6jz72RhQlTYRbi-o,54991
|
|
2
2
|
SURE/SURE.py,sha256=g8EhovBxjfpbVJA0AkmVkQ_ZW_JFc8TtkTCg8FCybV4,47750
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -14,12 +14,12 @@ SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
|
|
|
14
14
|
SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
|
|
15
15
|
SURE/perturb/perturb.py,sha256=1iSsCePcwkA2CyM1nCdq_G8gogUNjhMH0BfhhvhpJQk,5037
|
|
16
16
|
SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
17
|
-
SURE/utils/custom_mlp.py,sha256=
|
|
17
|
+
SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.85.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.85.dist-info/METADATA,sha256=I2sMctindGA7kNf_TAdn64Oq6GL_QcGDuKLV6B9zQKY,2678
|
|
22
|
+
sure_tools-2.1.85.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.85.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.85.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.85.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|