SURE-tools 2.1.0__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +5 -14
- SURE/SURE.py +11 -28
- {sure_tools-2.1.0.dist-info → sure_tools-2.1.1.dist-info}/METADATA +1 -1
- {sure_tools-2.1.0.dist-info → sure_tools-2.1.1.dist-info}/RECORD +8 -8
- {sure_tools-2.1.0.dist-info → sure_tools-2.1.1.dist-info}/WHEEL +0 -0
- {sure_tools-2.1.0.dist-info → sure_tools-2.1.1.dist-info}/entry_points.txt +0 -0
- {sure_tools-2.1.0.dist-info → sure_tools-2.1.1.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.1.0.dist-info → sure_tools-2.1.1.dist-info}/top_level.txt +0 -0
SURE/PerturbFlow.py
CHANGED
|
@@ -97,7 +97,6 @@ class PerturbFlow(nn.Module):
|
|
|
97
97
|
input_size: int,
|
|
98
98
|
codebook_size: int = 200,
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
|
-
cell_factor_names: list = None,
|
|
101
100
|
supervised_mode: bool = False,
|
|
102
101
|
z_dim: int = 10,
|
|
103
102
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
@@ -135,7 +134,6 @@ class PerturbFlow(nn.Module):
|
|
|
135
134
|
self.post_layer_fct = post_layer_fct
|
|
136
135
|
self.post_act_fct = post_act_fct
|
|
137
136
|
self.hidden_layer_activation = hidden_layer_activation
|
|
138
|
-
self.cell_factor_names = cell_factor_names
|
|
139
137
|
|
|
140
138
|
self.codebook_weights = None
|
|
141
139
|
|
|
@@ -813,11 +811,8 @@ class PerturbFlow(nn.Module):
|
|
|
813
811
|
A = np.concatenate(A)
|
|
814
812
|
return A
|
|
815
813
|
|
|
816
|
-
def
|
|
817
|
-
zns,_ = self.encoder_zn(xs)
|
|
818
|
-
if type(factor_idx) == str:
|
|
819
|
-
factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
|
|
820
|
-
|
|
814
|
+
def _cell_response(self, xs, factor_idx, perturb):
|
|
815
|
+
zns,_ = self.encoder_zn(xs)
|
|
821
816
|
if perturb.ndim==2:
|
|
822
817
|
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
823
818
|
else:
|
|
@@ -825,7 +820,7 @@ class PerturbFlow(nn.Module):
|
|
|
825
820
|
|
|
826
821
|
return ms
|
|
827
822
|
|
|
828
|
-
def
|
|
823
|
+
def get_cell_response(self,
|
|
829
824
|
xs,
|
|
830
825
|
factor_idx,
|
|
831
826
|
perturb,
|
|
@@ -843,7 +838,7 @@ class PerturbFlow(nn.Module):
|
|
|
843
838
|
Z = []
|
|
844
839
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
845
840
|
for X_batch, P_batch, _ in dataloader:
|
|
846
|
-
zns = self.
|
|
841
|
+
zns = self._cell_response(X_batch, factor_idx, P_batch)
|
|
847
842
|
Z.append(tensor_to_numpy(zns))
|
|
848
843
|
pbar.update(1)
|
|
849
844
|
|
|
@@ -852,11 +847,7 @@ class PerturbFlow(nn.Module):
|
|
|
852
847
|
|
|
853
848
|
def get_metacell_response(self, factor_idx, perturb):
|
|
854
849
|
zs = self._get_codebook()
|
|
855
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
856
|
-
|
|
857
|
-
if type(factor_idx) == str:
|
|
858
|
-
factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
|
|
859
|
-
|
|
850
|
+
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
860
851
|
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
861
852
|
return tensor_to_numpy(ms)
|
|
862
853
|
|
SURE/SURE.py
CHANGED
|
@@ -97,7 +97,6 @@ class SURE(nn.Module):
|
|
|
97
97
|
input_size: int,
|
|
98
98
|
codebook_size: int = 200,
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
|
-
cell_factor_names: list = None,
|
|
101
100
|
supervised_mode: bool = False,
|
|
102
101
|
z_dim: int = 10,
|
|
103
102
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
@@ -135,7 +134,6 @@ class SURE(nn.Module):
|
|
|
135
134
|
self.post_layer_fct = post_layer_fct
|
|
136
135
|
self.post_act_fct = post_act_fct
|
|
137
136
|
self.hidden_layer_activation = hidden_layer_activation
|
|
138
|
-
self.cell_factor_names = cell_factor_names
|
|
139
137
|
|
|
140
138
|
self.codebook_weights = None
|
|
141
139
|
|
|
@@ -234,18 +232,15 @@ class SURE(nn.Module):
|
|
|
234
232
|
)
|
|
235
233
|
|
|
236
234
|
if self.cell_factor_size>0:
|
|
237
|
-
self.cell_factor_effect =
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
use_cuda=self.use_cuda,
|
|
247
|
-
)
|
|
248
|
-
)
|
|
235
|
+
self.cell_factor_effect = MLP(
|
|
236
|
+
[self.z_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.z_dim],
|
|
237
|
+
activation=activate_fct,
|
|
238
|
+
output_activation=None,
|
|
239
|
+
post_layer_fct=post_layer_fct,
|
|
240
|
+
post_act_fct=post_act_fct,
|
|
241
|
+
allow_broadcast=self.allow_broadcast,
|
|
242
|
+
use_cuda=self.use_cuda,
|
|
243
|
+
)
|
|
249
244
|
|
|
250
245
|
self.decoder_concentrate = MLP(
|
|
251
246
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -449,13 +444,7 @@ class SURE(nn.Module):
|
|
|
449
444
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
450
445
|
|
|
451
446
|
if self.cell_factor_size>0:
|
|
452
|
-
|
|
453
|
-
zus = None
|
|
454
|
-
for i in np.arange(self.cell_factor_size):
|
|
455
|
-
if i==0:
|
|
456
|
-
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
457
|
-
else:
|
|
458
|
-
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
447
|
+
zus = self.cell_factor_effect([zns,us])
|
|
459
448
|
zs = zns+zus
|
|
460
449
|
else:
|
|
461
450
|
zs = zns
|
|
@@ -645,13 +634,7 @@ class SURE(nn.Module):
|
|
|
645
634
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
|
|
646
635
|
|
|
647
636
|
if self.cell_factor_size>0:
|
|
648
|
-
|
|
649
|
-
zus = None
|
|
650
|
-
for i in np.arange(self.cell_factor_size):
|
|
651
|
-
if i==0:
|
|
652
|
-
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
653
|
-
else:
|
|
654
|
-
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
637
|
+
zus = self.decoder_undesired([zns,us])
|
|
655
638
|
zs = zns+zus
|
|
656
639
|
else:
|
|
657
640
|
zs = zns
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
SURE/PerturbFlow.py,sha256=
|
|
2
|
-
SURE/SURE.py,sha256=
|
|
1
|
+
SURE/PerturbFlow.py,sha256=RQoIhYQJdQpHdY_sMeDuqurbwvm6IX1XH7SVWG6SmS0,51658
|
|
2
|
+
SURE/SURE.py,sha256=_ZOymj24DLQju0Lb90lKspHPmqIUDDzjIEr9t4qgqCI,48364
|
|
3
3
|
SURE/SURE2.py,sha256=8wlnMwb1xuf9QUksNkWdWx5ZWq-xIy9NLx8RdUnE82o,48501
|
|
4
4
|
SURE/__init__.py,sha256=xV10iBbh69g4mjBMb1cQxjuHe8e3Aq7pDzkZmx5G754,260
|
|
5
5
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
@@ -16,9 +16,9 @@ SURE/utils/__init__.py,sha256=Htqv4KqVKcRiaaTBsR-6yZ4LSlbhbzutjNKXGD9-uds,660
|
|
|
16
16
|
SURE/utils/custom_mlp.py,sha256=07TYX1HgxfEjb_3i5MpiZfNhOhx3dKntuwGkrpteWiM,7036
|
|
17
17
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
18
18
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
19
|
-
sure_tools-2.1.
|
|
20
|
-
sure_tools-2.1.
|
|
21
|
-
sure_tools-2.1.
|
|
22
|
-
sure_tools-2.1.
|
|
23
|
-
sure_tools-2.1.
|
|
24
|
-
sure_tools-2.1.
|
|
19
|
+
sure_tools-2.1.1.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
20
|
+
sure_tools-2.1.1.dist-info/METADATA,sha256=ET1LmoMzkRak6WiRpuGf2dcoY7cLgGoZtNmMkcqi6DU,2650
|
|
21
|
+
sure_tools-2.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
22
|
+
sure_tools-2.1.1.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
23
|
+
sure_tools-2.1.1.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
24
|
+
sure_tools-2.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|