SURE-tools 2.1.83__py3-none-any.whl → 2.1.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +6 -8
- SURE/utils/custom_mlp.py +4 -1
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.84.dist-info}/METADATA +1 -1
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.84.dist-info}/RECORD +8 -8
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.84.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.84.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.84.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.83.dist-info → sure_tools-2.1.84.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -204,7 +204,7 @@ class PerturbFlow(nn.Module):
|
|
|
204
204
|
for i in np.arange(self.cell_factor_size):
|
|
205
205
|
if self.use_bias[i]:
|
|
206
206
|
self.cell_factor_effect.append(MLP(
|
|
207
|
-
[self.
|
|
207
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
208
208
|
activation=activate_fct,
|
|
209
209
|
output_activation=None,
|
|
210
210
|
post_layer_fct=post_layer_fct,
|
|
@@ -215,7 +215,7 @@ class PerturbFlow(nn.Module):
|
|
|
215
215
|
)
|
|
216
216
|
else:
|
|
217
217
|
self.cell_factor_effect.append(ZeroBiasMLP(
|
|
218
|
-
[self.
|
|
218
|
+
[self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
|
|
219
219
|
activation=activate_fct,
|
|
220
220
|
output_activation=None,
|
|
221
221
|
post_layer_fct=post_layer_fct,
|
|
@@ -429,7 +429,7 @@ class PerturbFlow(nn.Module):
|
|
|
429
429
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
430
430
|
|
|
431
431
|
if self.cell_factor_size>0:
|
|
432
|
-
zus = self._total_effects(
|
|
432
|
+
zus = self._total_effects(zn_loc, us)
|
|
433
433
|
zs = zns+zus
|
|
434
434
|
else:
|
|
435
435
|
zs = zns
|
|
@@ -631,7 +631,7 @@ class PerturbFlow(nn.Module):
|
|
|
631
631
|
# zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
632
632
|
# else:
|
|
633
633
|
# zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
634
|
-
zus = self._total_effects(
|
|
634
|
+
zus = self._total_effects(zn_loc, us)
|
|
635
635
|
zs = zns+zus
|
|
636
636
|
else:
|
|
637
637
|
zs = zns
|
|
@@ -842,8 +842,7 @@ class PerturbFlow(nn.Module):
|
|
|
842
842
|
|
|
843
843
|
def _cell_response(self, xs, factor_idx, perturb):
|
|
844
844
|
#zns,_ = self.encoder_zn(xs)
|
|
845
|
-
|
|
846
|
-
zns = self._soft_assignments(xs)
|
|
845
|
+
zns,_ = self._get_basal_embedding(xs)
|
|
847
846
|
if perturb.ndim==2:
|
|
848
847
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
849
848
|
else:
|
|
@@ -877,8 +876,7 @@ class PerturbFlow(nn.Module):
|
|
|
877
876
|
return Z
|
|
878
877
|
|
|
879
878
|
def get_metacell_response(self, factor_idx, perturb):
|
|
880
|
-
|
|
881
|
-
zs = self._get_codebook_identity()
|
|
879
|
+
zs = self._get_codebook()
|
|
882
880
|
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
883
881
|
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
884
882
|
return tensor_to_numpy(ms)
|
SURE/utils/custom_mlp.py
CHANGED
|
@@ -239,7 +239,10 @@ class ZeroBiasMLP(nn.Module):
|
|
|
239
239
|
def forward(self, x):
|
|
240
240
|
y = self.mlp(x)
|
|
241
241
|
mask = torch.zeros_like(y)
|
|
242
|
-
|
|
242
|
+
if len(y.shape)==2:
|
|
243
|
+
mask[x[1][:,0]>0,:] = 1
|
|
244
|
+
elif len(y.shape)==3:
|
|
245
|
+
mask[:,x[1][:,0]>0,:] = 1
|
|
243
246
|
return y*mask
|
|
244
247
|
|
|
245
248
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=11MmXAMHovmvAHHUvBYGMKo-GS2mOiWxBgkzhq-yZE4,54683
|
|
2
2
|
SURE/SURE.py,sha256=g8EhovBxjfpbVJA0AkmVkQ_ZW_JFc8TtkTCg8FCybV4,47750
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -14,12 +14,12 @@ SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
|
|
|
14
14
|
SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
|
|
15
15
|
SURE/perturb/perturb.py,sha256=1iSsCePcwkA2CyM1nCdq_G8gogUNjhMH0BfhhvhpJQk,5037
|
|
16
16
|
SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
17
|
-
SURE/utils/custom_mlp.py,sha256=
|
|
17
|
+
SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.84.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.84.dist-info/METADATA,sha256=l_eZxgCGKk6isE5ZgbyyRyMI1ZuGYgSK_Dpl_zFqOxs,2678
|
|
22
|
+
sure_tools-2.1.84.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.84.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.84.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.84.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|