SURE-tools 2.1.13__tar.gz → 2.1.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.1.13 → sure_tools-2.1.15}/PKG-INFO +1 -1
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/PerturbFlow.py +38 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/SURE.py +1 -1
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.13 → sure_tools-2.1.15}/setup.py +1 -1
- {sure_tools-2.1.13 → sure_tools-2.1.15}/LICENSE +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/README.md +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/__init__.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.13 → sure_tools-2.1.15}/setup.cfg +0 -0
|
@@ -835,6 +835,44 @@ class PerturbFlow(nn.Module):
|
|
|
835
835
|
R = np.concatenate(R)
|
|
836
836
|
return R
|
|
837
837
|
|
|
838
|
+
def _count(self,concentrate):
|
|
839
|
+
if self.loss_func == 'bernoulli':
|
|
840
|
+
counts = self.sigmoid(concentrate)
|
|
841
|
+
else:
|
|
842
|
+
counts = concentrate.exp()
|
|
843
|
+
return counts
|
|
844
|
+
|
|
845
|
+
def _count_sample(self, concentrate):
    """Draw one random sample of output values for ``concentrate``.

    With ``self.loss_func == 'bernoulli'`` the input is used directly as the
    logits of a Bernoulli distribution. Otherwise the input is first
    converted with ``self._count`` and the result is used as the rate of a
    Poisson distribution. In both cases the distribution is wrapped with
    ``to_event(1)`` before sampling.

    Args:
        concentrate: Decoder output to sample from.
            NOTE(review): assumed to be a 2-D (batch, feature) tensor --
            confirm against callers.

    Returns:
        The result of ``.sample()`` on the chosen distribution.
    """
    if self.loss_func != 'bernoulli':
        # Non-Bernoulli losses: turn the input into a rate, then sample.
        rate = self._count(concentrate=concentrate)
        return dist.Poisson(rate=rate).to_event(1).sample()

    # Bernoulli loss: the input already serves as the logits.
    return dist.Bernoulli(logits=concentrate).to_event(1).sample()
|
|
853
|
+
|
|
854
|
+
def get_counts(self, zs,
               batch_size: int = 1024,
               use_sampler: bool = False):
    """Compute output values for ``zs`` in mini-batches.

    ``zs`` is converted to a tensor on ``self.get_device()``, wrapped in a
    ``CustomDataset`` and iterated in order (no shuffling). Each batch is
    passed through ``self._expression`` and then through either
    ``self._count_sample`` or ``self._count``.

    Args:
        zs: Input accepted by ``convert_to_tensor``.
            NOTE(review): presumably latent representations -- confirm.
        batch_size: Number of rows per ``DataLoader`` batch.
        use_sampler: If true, use ``self._count_sample`` (random draws);
            otherwise use ``self._count``.

    Returns:
        The per-batch results, converted with ``tensor_to_numpy`` and joined
        with ``np.concatenate`` (default axis 0).
    """
    zs = convert_to_tensor(zs, device=self.get_device())
    loader = DataLoader(CustomDataset(zs), batch_size=batch_size, shuffle=False)

    # The choice of transform does not depend on the batch, so make it once.
    to_counts = self._count_sample if use_sampler else self._count

    chunks = []
    with tqdm(total=len(loader), desc='', unit='batch') as pbar:
        # The dataset yields pairs; only the first element is used here.
        for Z_batch, _ in loader:
            concentrate = self._expression(Z_batch)
            chunks.append(tensor_to_numpy(to_counts(concentrate)))
            pbar.update(1)

    return np.concatenate(chunks)
|
|
875
|
+
|
|
838
876
|
def preprocess(self, xs, threshold=0):
|
|
839
877
|
if self.loss_func == 'bernoulli':
|
|
840
878
|
ad = sc.AnnData(xs)
|
|
@@ -233,7 +233,7 @@ class SURE(nn.Module):
|
|
|
233
233
|
|
|
234
234
|
if self.cell_factor_size>0:
|
|
235
235
|
self.cell_factor_effect = MLP(
|
|
236
|
-
[self.
|
|
236
|
+
[self.latent_dim + self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
|
|
237
237
|
activation=activate_fct,
|
|
238
238
|
output_activation=None,
|
|
239
239
|
post_layer_fct=post_layer_fct,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|