SURE-tools 1.0.10__tar.gz → 2.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic; see the package's registry page for more details.
- {sure_tools-1.0.10 → sure_tools-2.0.1}/PKG-INFO +1 -1
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/SURE.py +59 -6
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE_tools.egg-info/SOURCES.txt +0 -1
- {sure_tools-1.0.10 → sure_tools-2.0.1}/setup.py +1 -1
- sure_tools-1.0.10/SURE/SURE2.py +0 -1236
- {sure_tools-1.0.10 → sure_tools-2.0.1}/LICENSE +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/README.md +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/__init__.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/assembly/__init__.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/assembly/assembly.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/assembly/atlas.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/atac/__init__.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/atac/utils.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/codebook/__init__.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/codebook/codebook.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/utils/__init__.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/utils/queue.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE/utils/utils.py +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-1.0.10 → sure_tools-2.0.1}/setup.cfg +0 -0
|
@@ -97,6 +97,7 @@ class SURE(nn.Module):
|
|
|
97
97
|
input_size: int,
|
|
98
98
|
codebook_size: int = 200,
|
|
99
99
|
cell_factor_size: int = 0,
|
|
100
|
+
cell_factor_names: list = None,
|
|
100
101
|
supervised_mode: bool = False,
|
|
101
102
|
z_dim: int = 10,
|
|
102
103
|
z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
|
|
@@ -134,6 +135,7 @@ class SURE(nn.Module):
|
|
|
134
135
|
self.post_layer_fct = post_layer_fct
|
|
135
136
|
self.post_act_fct = post_act_fct
|
|
136
137
|
self.hidden_layer_activation = hidden_layer_activation
|
|
138
|
+
self.cell_factor_names = cell_factor_names
|
|
137
139
|
|
|
138
140
|
self.codebook_weights = None
|
|
139
141
|
|
|
@@ -232,8 +234,10 @@ class SURE(nn.Module):
|
|
|
232
234
|
)
|
|
233
235
|
|
|
234
236
|
if self.cell_factor_size>0:
|
|
235
|
-
self.cell_factor_effect =
|
|
236
|
-
|
|
237
|
+
self.cell_factor_effect = nn.ModuleList()
|
|
238
|
+
for i in np.arange(self.cell_factor_size):
|
|
239
|
+
self.cell_factor_effect.append(MLP(
|
|
240
|
+
[self.z_dim+1] + hidden_sizes + [self.z_dim],
|
|
237
241
|
activation=activate_fct,
|
|
238
242
|
output_activation=None,
|
|
239
243
|
post_layer_fct=post_layer_fct,
|
|
@@ -241,6 +245,7 @@ class SURE(nn.Module):
|
|
|
241
245
|
allow_broadcast=self.allow_broadcast,
|
|
242
246
|
use_cuda=self.use_cuda,
|
|
243
247
|
)
|
|
248
|
+
)
|
|
244
249
|
|
|
245
250
|
self.decoder_concentrate = MLP(
|
|
246
251
|
[self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
|
|
@@ -444,8 +449,14 @@ class SURE(nn.Module):
|
|
|
444
449
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
|
|
445
450
|
|
|
446
451
|
if self.cell_factor_size>0:
|
|
447
|
-
zus = self.
|
|
448
|
-
|
|
452
|
+
#zus = self.decoder_undesired([zns,us])
|
|
453
|
+
zus = None
|
|
454
|
+
for i in np.arange(self.cell_factor_size):
|
|
455
|
+
if i==0:
|
|
456
|
+
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
457
|
+
else:
|
|
458
|
+
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
459
|
+
zs = zns+zus
|
|
449
460
|
else:
|
|
450
461
|
zs = zns
|
|
451
462
|
|
|
@@ -634,8 +645,14 @@ class SURE(nn.Module):
|
|
|
634
645
|
zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1), obs=embeds)
|
|
635
646
|
|
|
636
647
|
if self.cell_factor_size>0:
|
|
637
|
-
zus = self.
|
|
638
|
-
|
|
648
|
+
#zus = self.decoder_undesired([zns,us])
|
|
649
|
+
zus = None
|
|
650
|
+
for i in np.arange(self.cell_factor_size):
|
|
651
|
+
if i==0:
|
|
652
|
+
zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
653
|
+
else:
|
|
654
|
+
zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
|
|
655
|
+
zs = zns+zus
|
|
639
656
|
else:
|
|
640
657
|
zs = zns
|
|
641
658
|
|
|
@@ -796,6 +813,42 @@ class SURE(nn.Module):
|
|
|
796
813
|
A = np.concatenate(A)
|
|
797
814
|
return A
|
|
798
815
|
|
|
816
|
+
def _cell_move(self, xs, factor_idx, perturb):
    """
    Apply one cell factor's effect network to a batch of encoded cells.

    Parameters
    ----------
    xs
        Batch of cells, passed directly to ``self.encoder_zn``.
    factor_idx
        Integer position of the factor in ``self.cell_factor_effect``, or the
        factor's name as listed in ``self.cell_factor_names``.
    perturb
        Perturbation value(s) for the factor. A 2-D input is used as-is; any
        other input is reshaped into a single column with ``reshape(-1, 1)``.

    Returns
    -------
    The output of ``self.cell_factor_effect[factor_idx]`` called on
    ``[zns, perturb]``, where ``zns = self.encoder_zn(xs)``.

    Raises
    ------
    ValueError
        If ``factor_idx`` is a name but ``self.cell_factor_names`` is unset
        or does not contain it.
    """
    zns = self.encoder_zn(xs)

    if isinstance(factor_idx, str):
        # Resolve a factor name to its integer position. list(...) accepts
        # both a Python list and a NumPy array of names. An np.where-based
        # lookup returns a tuple of index arrays, which cannot index the
        # effect modules.
        if self.cell_factor_names is None:
            raise ValueError(
                f"Cannot look up cell factor {factor_idx!r}: "
                "cell_factor_names is not set")
        names = list(self.cell_factor_names)
        if factor_idx not in names:
            raise ValueError(
                f"Unknown cell factor {factor_idx!r}; expected one of {names}")
        factor_idx = names.index(factor_idx)

    effect = self.cell_factor_effect[factor_idx]

    # Non-2-D perturbations (e.g. one value per cell) become a column vector
    # so they can be paired with zns as the effect network's extra input.
    if perturb.ndim == 2:
        ms = effect([zns, perturb])
    else:
        ms = effect([zns, perturb.reshape(-1, 1)])

    return ms
|
|
827
|
+
|
|
828
|
+
def get_cell_move(self,
                  xs,
                  factor_idx,
                  perturb,
                  batch_size: int = 1024):
    """
    Return cells' changes in the latent space induced by a perturbation of one cell factor.

    Parameters
    ----------
    xs
        Cells to evaluate; passed through ``self.preprocess`` and
        ``convert_to_tensor`` before being batched.
    factor_idx
        Integer position of the factor in ``self.cell_factor_effect``, or its
        name as listed in ``self.cell_factor_names``.
    perturb
        Perturbation value(s) for the factor. An array whose first dimension
        equals the number of cells is treated as per-cell and is sliced to
        match each batch; anything else is passed unchanged to every batch.
    batch_size
        Number of cells per DataLoader batch.

    Returns
    -------
    numpy.ndarray
        The per-batch outputs of ``self._cell_move`` concatenated along axis 0.
    """
    xs = self.preprocess(xs)
    xs = convert_to_tensor(xs, device=self.get_device())
    n_cells = xs.shape[0]

    # A per-cell perturbation has to follow the batches; passing the full
    # array to every batch would pair n_cells perturbation rows with only
    # batch_size encoded cells. NOTE(review): slicing assumes CustomDataset
    # yields the rows of xs in order (shuffle=False below) — confirm.
    per_cell = getattr(perturb, 'ndim', 0) >= 1 and perturb.shape[0] == n_cells

    dataset = CustomDataset(xs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    Z = []
    start = 0
    with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        for X_batch, _ in dataloader:
            stop = start + X_batch.shape[0]
            batch_perturb = perturb[start:stop] if per_cell else perturb
            zns = self._cell_move(X_batch, factor_idx, batch_perturb)
            Z.append(tensor_to_numpy(zns))
            start = stop
            pbar.update(1)

    Z = np.concatenate(Z)
    return Z
|
|
851
|
+
|
|
799
852
|
def preprocess(self, xs, threshold=0):
|
|
800
853
|
if self.loss_func == 'bernoulli':
|
|
801
854
|
ad = sc.AnnData(xs)
|