SURE-tools 2.0.10__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/PerturbFlow.py +1315 -0
- SURE/SURE.py +11 -99
- SURE/__init__.py +3 -2
- {sure_tools-2.0.10.dist-info → sure_tools-2.1.1.dist-info}/METADATA +1 -1
- {sure_tools-2.0.10.dist-info → sure_tools-2.1.1.dist-info}/RECORD +9 -8
- sure_tools-2.1.1.dist-info/entry_points.txt +3 -0
- sure_tools-2.0.10.dist-info/entry_points.txt +0 -2
- {sure_tools-2.0.10.dist-info → sure_tools-2.1.1.dist-info}/WHEEL +0 -0
- {sure_tools-2.0.10.dist-info → sure_tools-2.1.1.dist-info}/licenses/LICENSE +0 -0
- {sure_tools-2.0.10.dist-info → sure_tools-2.1.1.dist-info}/top_level.txt +0 -0
SURE/SURE.py
CHANGED
|
@@ -97,7 +97,6 @@ class SURE(nn.Module):
|
|
|
97
97
|
input_size: int,
|
|
98
98
|
codebook_size: int = 200,
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
|
-
cell_factor_names: list = None,
|
|
101
100
|
supervised_mode: bool = False,
|
|
102
101
|
z_dim: int = 10,
|
|
103
102
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
@@ -135,7 +134,6 @@ class SURE(nn.Module):
|
|
|
135
134
|
self.post_layer_fct = post_layer_fct
|
|
136
135
|
self.post_act_fct = post_act_fct
|
|
137
136
|
self.hidden_layer_activation = hidden_layer_activation
|
|
138
|
-
self.cell_factor_names = cell_factor_names
|
|
139
137
|
|
|
140
138
|
self.codebook_weights = None
|
|
141
139
|
|
|
@@ -234,18 +232,15 @@ class SURE(nn.Module):
|
|
|
234
232
|
)
|
|
235
233
|
|
|
236
234
|
if self.cell_factor_size>0:
|
|
237
|
-
self.cell_factor_effect =
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
use_cuda=self.use_cuda,
|
|
247
|
-
)
|
|
248
|
-
)
|
|
235
|
+
self.cell_factor_effect = MLP(
|
|
236
|
+
[self.z_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.z_dim],
|
|
237
|
+
activation=activate_fct,
|
|
238
|
+
output_activation=None,
|
|
239
|
+
post_layer_fct=post_layer_fct,
|
|
240
|
+
post_act_fct=post_act_fct,
|
|
241
|
+
allow_broadcast=self.allow_broadcast,
|
|
242
|
+
use_cuda=self.use_cuda,
|
|
243
|
+
)
|
|
249
244
|
|
|
250
245
|
self.decoder_concentrate = MLP(
|
|
251
246
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -449,13 +444,7 @@ class SURE(nn.Module):
|
|
|
449
444
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
450
445
|
|
|
451
446
|
if self.cell_factor_size>0:
|
|
452
|
-
|
|
453
|
-
zus = None
|
|
454
|
-
for i in np.arange(self.cell_factor_size):
|
|
455
|
-
if i==0:
|
|
456
|
-
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
457
|
-
else:
|
|
458
|
-
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
447
|
+
zus = self.cell_factor_effect([zns,us])
|
|
459
448
|
zs = zns+zus
|
|
460
449
|
else:
|
|
461
450
|
zs = zns
|
|
@@ -645,13 +634,7 @@ class SURE(nn.Module):
|
|
|
645
634
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
|
|
646
635
|
|
|
647
636
|
if self.cell_factor_size>0:
|
|
648
|
-
|
|
649
|
-
zus = None
|
|
650
|
-
for i in np.arange(self.cell_factor_size):
|
|
651
|
-
if i==0:
|
|
652
|
-
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
653
|
-
else:
|
|
654
|
-
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
637
|
+
zus = self.decoder_undesired([zns,us])
|
|
655
638
|
zs = zns+zus
|
|
656
639
|
else:
|
|
657
640
|
zs = zns
|
|
@@ -813,77 +796,6 @@ class SURE(nn.Module):
|
|
|
813
796
|
A = np.concatenate(A)
|
|
814
797
|
return A
|
|
815
798
|
|
|
816
|
-
def _cell_move(self, xs, factor_idx, perturb):
|
|
817
|
-
zns,_ = self.encoder_zn(xs)
|
|
818
|
-
if type(factor_idx) == str:
|
|
819
|
-
factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
|
|
820
|
-
|
|
821
|
-
if perturb.ndim==2:
|
|
822
|
-
ms = self.cell_factor_effect[factor_idx]([zns, perturb])
|
|
823
|
-
else:
|
|
824
|
-
ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
|
|
825
|
-
|
|
826
|
-
return ms
|
|
827
|
-
|
|
828
|
-
def get_cell_move(self,
|
|
829
|
-
xs,
|
|
830
|
-
factor_idx,
|
|
831
|
-
perturb,
|
|
832
|
-
batch_size: int = 1024):
|
|
833
|
-
"""
|
|
834
|
-
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
835
|
-
|
|
836
|
-
"""
|
|
837
|
-
xs = self.preprocess(xs)
|
|
838
|
-
xs = convert_to_tensor(xs, device=self.get_device())
|
|
839
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
840
|
-
dataset = CustomDataset2(xs,ps)
|
|
841
|
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
842
|
-
|
|
843
|
-
Z = []
|
|
844
|
-
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
845
|
-
for X_batch, P_batch, _ in dataloader:
|
|
846
|
-
zns = self._cell_move(X_batch, factor_idx, P_batch)
|
|
847
|
-
Z.append(tensor_to_numpy(zns))
|
|
848
|
-
pbar.update(1)
|
|
849
|
-
|
|
850
|
-
Z = np.concatenate(Z)
|
|
851
|
-
return Z
|
|
852
|
-
|
|
853
|
-
def get_metacell_move(self, factor_idx, perturb):
|
|
854
|
-
zs = self._get_codebook()
|
|
855
|
-
ps = convert_to_tensor(perturb, device=self.get_device())
|
|
856
|
-
|
|
857
|
-
if type(factor_idx) == str:
|
|
858
|
-
factor_idx = int(np.where(self.cell_factor_names==factor_idx)[0])
|
|
859
|
-
|
|
860
|
-
ms = self.cell_factor_effect[factor_idx]([zs,ps])
|
|
861
|
-
return tensor_to_numpy(ms)
|
|
862
|
-
|
|
863
|
-
def _get_expression_responses(self, delta_zs):
|
|
864
|
-
return self.decoder_concentrate(delta_zs)
|
|
865
|
-
|
|
866
|
-
def get_expression_responses(self,
|
|
867
|
-
delta_zs,
|
|
868
|
-
batch_size: int = 1024):
|
|
869
|
-
"""
|
|
870
|
-
Return cells' changes in the latent space induced by specific perturbation of a factor
|
|
871
|
-
|
|
872
|
-
"""
|
|
873
|
-
delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
|
|
874
|
-
dataset = CustomDataset(delta_zs)
|
|
875
|
-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
876
|
-
|
|
877
|
-
R = []
|
|
878
|
-
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
879
|
-
for delta_Z_batch, _ in dataloader:
|
|
880
|
-
r = self._cell_move(delta_Z_batch)
|
|
881
|
-
R.append(tensor_to_numpy(r))
|
|
882
|
-
pbar.update(1)
|
|
883
|
-
|
|
884
|
-
R = np.concatenate(R)
|
|
885
|
-
return R
|
|
886
|
-
|
|
887
799
|
def preprocess(self, xs, threshold=0):
|
|
888
800
|
if self.loss_func == 'bernoulli':
|
|
889
801
|
ad = sc.AnnData(xs)
|
SURE/__init__.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from .SURE import SURE
|
|
2
|
-
from .
|
|
2
|
+
from .PerturbFlow import PerturbFlow
|
|
3
3
|
|
|
4
4
|
from . import utils
|
|
5
5
|
from . import codebook
|
|
6
6
|
from . import SURE
|
|
7
|
+
from . import PerturbFlow
|
|
7
8
|
from . import atac
|
|
8
9
|
from . import flow
|
|
9
10
|
|
|
10
|
-
__all__ = ['SURE', 'flow', 'atac', 'utils', 'codebook']
|
|
11
|
+
__all__ = ['SURE', 'PerturbFlow', 'flow', 'atac', 'utils', 'codebook']
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
SURE/
|
|
1
|
+
SURE/PerturbFlow.py,sha256=RQoIhYQJdQpHdY_sMeDuqurbwvm6IX1XH7SVWG6SmS0,51658
|
|
2
|
+
SURE/SURE.py,sha256=_ZOymj24DLQju0Lb90lKspHPmqIUDDzjIEr9t4qgqCI,48364
|
|
2
3
|
SURE/SURE2.py,sha256=8wlnMwb1xuf9QUksNkWdWx5ZWq-xIy9NLx8RdUnE82o,48501
|
|
3
|
-
SURE/__init__.py,sha256=
|
|
4
|
+
SURE/__init__.py,sha256=xV10iBbh69g4mjBMb1cQxjuHe8e3Aq7pDzkZmx5G754,260
|
|
4
5
|
SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
|
|
5
6
|
SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
|
|
6
7
|
SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
|
|
@@ -15,9 +16,9 @@ SURE/utils/__init__.py,sha256=Htqv4KqVKcRiaaTBsR-6yZ4LSlbhbzutjNKXGD9-uds,660
|
|
|
15
16
|
SURE/utils/custom_mlp.py,sha256=07TYX1HgxfEjb_3i5MpiZfNhOhx3dKntuwGkrpteWiM,7036
|
|
16
17
|
SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
|
|
17
18
|
SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
|
|
18
|
-
sure_tools-2.
|
|
19
|
-
sure_tools-2.
|
|
20
|
-
sure_tools-2.
|
|
21
|
-
sure_tools-2.
|
|
22
|
-
sure_tools-2.
|
|
23
|
-
sure_tools-2.
|
|
19
|
+
sure_tools-2.1.1.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
|
|
20
|
+
sure_tools-2.1.1.dist-info/METADATA,sha256=ET1LmoMzkRak6WiRpuGf2dcoY7cLgGoZtNmMkcqi6DU,2650
|
|
21
|
+
sure_tools-2.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
22
|
+
sure_tools-2.1.1.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
|
|
23
|
+
sure_tools-2.1.1.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
|
|
24
|
+
sure_tools-2.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|