SURE-tools 2.1.46__py3-none-any.whl → 2.1.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools has been flagged by the registry as potentially problematic; consult the registry's advisory page for this release for details.
- SURE/PerturbFlow.py +13 -9
- SURE/perturb/perturb.py +1 -1
- SURE/utils/__init__.py +1 -1
- SURE/utils/custom_mlp.py +35 -1
- {sure_tools-2.1.46.dist-info → sure_tools-2.1.48.dist-info}/METADATA +1 -1
- {sure_tools-2.1.46.dist-info → sure_tools-2.1.48.dist-info}/RECORD +10 -10
- {sure_tools-2.1.46.dist-info → sure_tools-2.1.48.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.46.dist-info → sure_tools-2.1.48.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.46.dist-info → sure_tools-2.1.48.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.46.dist-info → sure_tools-2.1.48.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -378,7 +378,8 @@ class PerturbFlow(nn.Module):
|
|
|
378
378
|
|
|
379
379
|
def guide1(self, xs):
|
|
380
380
|
with pyro.plate('data'):
|
|
381
|
-
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
381
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
382
|
+
zn_loc, zn_scale = self._get_basal_embedding(xs)
|
|
382
383
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
383
384
|
|
|
384
385
|
alpha = self.encoder_n(zns)
|
|
@@ -466,7 +467,8 @@ class PerturbFlow(nn.Module):
|
|
|
466
467
|
|
|
467
468
|
def guide2(self, xs, us=None):
|
|
468
469
|
with pyro.plate('data'):
|
|
469
|
-
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
470
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
471
|
+
zn_loc, zn_scale = self._get_basal_embedding(xs)
|
|
470
472
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
471
473
|
|
|
472
474
|
alpha = self.encoder_n(zns)
|
|
@@ -561,7 +563,8 @@ class PerturbFlow(nn.Module):
|
|
|
561
563
|
def guide3(self, xs, ys, embeds=None):
|
|
562
564
|
with pyro.plate('data'):
|
|
563
565
|
if embeds is None:
|
|
564
|
-
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
566
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
567
|
+
zn_loc, zn_scale = self._get_basal_embedding(xs)
|
|
565
568
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
566
569
|
|
|
567
570
|
def model4(self, xs, us, ys, embeds=None):
|
|
@@ -663,7 +666,8 @@ class PerturbFlow(nn.Module):
|
|
|
663
666
|
def guide4(self, xs, us, ys, embeds=None):
|
|
664
667
|
with pyro.plate('data'):
|
|
665
668
|
if embeds is None:
|
|
666
|
-
zn_loc, zn_scale = self.encoder_zn(xs)
|
|
669
|
+
#zn_loc, zn_scale = self.encoder_zn(xs)
|
|
670
|
+
zn_loc, zn_scale = self._get_basal_embedding(xs)
|
|
667
671
|
zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
|
|
668
672
|
|
|
669
673
|
def _total_effects(self, zns, us):
|
|
@@ -692,8 +696,8 @@ class PerturbFlow(nn.Module):
|
|
|
692
696
|
return cb
|
|
693
697
|
|
|
694
698
|
def _get_basal_embedding(self, xs):
|
|
695
|
-
|
|
696
|
-
return
|
|
699
|
+
loc, scale = self.encoder_zn(xs)
|
|
700
|
+
return loc, scale
|
|
697
701
|
|
|
698
702
|
def get_basal_embedding(self,
|
|
699
703
|
xs,
|
|
@@ -720,7 +724,7 @@ class PerturbFlow(nn.Module):
|
|
|
720
724
|
Z = []
|
|
721
725
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
722
726
|
for X_batch, _ in dataloader:
|
|
723
|
-
zns = self._get_basal_embedding(X_batch)
|
|
727
|
+
zns,_ = self._get_basal_embedding(X_batch)
|
|
724
728
|
Z.append(tensor_to_numpy(zns))
|
|
725
729
|
pbar.update(1)
|
|
726
730
|
|
|
@@ -732,7 +736,7 @@ class PerturbFlow(nn.Module):
|
|
|
732
736
|
alpha = self.encoder_n(xs)
|
|
733
737
|
else:
|
|
734
738
|
#zns,_ = self.encoder_zn(xs)
|
|
735
|
-
zns = self._get_basal_embedding(xs)
|
|
739
|
+
zns,_ = self._get_basal_embedding(xs)
|
|
736
740
|
alpha = self.encoder_n(zns)
|
|
737
741
|
return alpha
|
|
738
742
|
|
|
@@ -803,7 +807,7 @@ class PerturbFlow(nn.Module):
|
|
|
803
807
|
|
|
804
808
|
def _cell_response(self, xs, factor_idx, perturb):
|
|
805
809
|
#zns,_ = self.encoder_zn(xs)
|
|
806
|
-
zns = self._get_basal_embedding(xs)
|
|
810
|
+
zns,_ = self._get_basal_embedding(xs)
|
|
807
811
|
if perturb.ndim==2:
|
|
808
812
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
809
813
|
else:
|
SURE/perturb/perturb.py
CHANGED
|
@@ -9,7 +9,7 @@ class LabelMatrix:
|
|
|
9
9
|
def __init__(self):
|
|
10
10
|
self.labels_ = None
|
|
11
11
|
|
|
12
|
-
def fit_transform(self, labels, control_label=None, sep_pattern=r'[
|
|
12
|
+
def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
|
|
13
13
|
if speedup=='none':
|
|
14
14
|
mat, self.labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
|
|
15
15
|
elif speedup=='vectorize':
|
SURE/utils/__init__.py
CHANGED
SURE/utils/custom_mlp.py
CHANGED
|
@@ -241,4 +241,38 @@ class ZeroBiasMLP(nn.Module):
|
|
|
241
241
|
mask = torch.zeros_like(y)
|
|
242
242
|
mask[x[1][:,0]>0,:] = 1
|
|
243
243
|
return y*mask
|
|
244
|
-
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class HDMLP(nn.Module):
|
|
247
|
+
def __init__(
|
|
248
|
+
self,
|
|
249
|
+
input_size,
|
|
250
|
+
hidden_sizes,
|
|
251
|
+
output_depth,
|
|
252
|
+
activation=nn.ReLU,
|
|
253
|
+
output_activation=None,
|
|
254
|
+
post_layer_fct=lambda layer_ix, total_layers, layer: None,
|
|
255
|
+
post_act_fct=lambda layer_ix, total_layers, layer: None,
|
|
256
|
+
allow_broadcast=False,
|
|
257
|
+
use_cuda=False,
|
|
258
|
+
):
|
|
259
|
+
# init the module object
|
|
260
|
+
super().__init__()
|
|
261
|
+
self.mlp = MLP(mlp_sizes=[1] + hidden_sizes + [output_depth],
|
|
262
|
+
activation=activation,
|
|
263
|
+
output_activation=output_activation,
|
|
264
|
+
post_layer_fct=post_layer_fct,
|
|
265
|
+
post_act_fct=post_act_fct,
|
|
266
|
+
allow_broadcast=allow_broadcast,
|
|
267
|
+
use_cuda=use_cuda,
|
|
268
|
+
bias=True)
|
|
269
|
+
self.input_size=input_size
|
|
270
|
+
self.output_depth=output_depth
|
|
271
|
+
|
|
272
|
+
# pass through our sequential for the output!
|
|
273
|
+
def forward(self, x):
|
|
274
|
+
batch_size, n = x.shape
|
|
275
|
+
x = x.view(batch_size * n, 1)
|
|
276
|
+
out = self.mlp(x)
|
|
277
|
+
out = out.view(batch_size, n, self.output_depth)
|
|
278
|
+
return out
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=0Zalme94qlUBnom7lY3OmUZcyLPd2F69QJ29LjuCpd8,52667
|
|
2
2
|
SURE/SURE.py,sha256=ko15a9BhvUqHviogZ0YCdTQjM-2zqkO9OvHZSpnGbg0,47458
|
|
3
3
|
SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
|
|
4
4
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -12,14 +12,14 @@ SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
|
|
|
12
12
|
SURE/flow/flow_stats.py,sha256=_pF7m4-87SKlCHVtVmx3LG2bAGVXOnAfEgMzLhLx4Io,10910
|
|
13
13
|
SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
|
|
14
14
|
SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
|
|
15
|
-
SURE/perturb/perturb.py,sha256=
|
|
16
|
-
SURE/utils/__init__.py,sha256=
|
|
17
|
-
SURE/utils/custom_mlp.py,sha256=
|
|
15
|
+
SURE/perturb/perturb.py,sha256=1iSsCePcwkA2CyM1nCdq_G8gogUNjhMH0BfhhvhpJQk,5037
|
|
16
|
+
SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
|
|
17
|
+
SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
|
|
18
18
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
19
19
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
25
|
-
sure_tools-2.1.
|
|
20
|
+
sure_tools-2.1.48.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
21
|
+
sure_tools-2.1.48.dist-info/METADATA,sha256=sjCCzYJQAUe3MpF9a4fsEoXBRhs83VC5ez5ar4z-cX4,2651
|
|
22
|
+
sure_tools-2.1.48.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
23
|
+
sure_tools-2.1.48.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
24
|
+
sure_tools-2.1.48.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
25
|
+
sure_tools-2.1.48.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|