SURE-tools 2.1.54__tar.gz → 2.1.55__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. See the registry's advisory for this release for more details.
- {sure_tools-2.1.54 → sure_tools-2.1.55}/PKG-INFO +1 -1
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/PerturbFlow.py +23 -5
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.54 → sure_tools-2.1.55}/setup.py +1 -1
- {sure_tools-2.1.54 → sure_tools-2.1.55}/LICENSE +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/README.md +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/SURE.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/__init__.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.54 → sure_tools-2.1.55}/setup.cfg +0 -0
|
@@ -960,7 +960,7 @@ class PerturbFlow(nn.Module):
|
|
|
960
960
|
R = np.concatenate(R)
|
|
961
961
|
return R
|
|
962
962
|
|
|
963
|
-
def _count(self,concentrate):
|
|
963
|
+
def _count(self,concentrate, library_size=None):
|
|
964
964
|
if self.loss_func == 'bernoulli':
|
|
965
965
|
#counts = self.sigmoid(concentrate)
|
|
966
966
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
@@ -976,6 +976,11 @@ class PerturbFlow(nn.Module):
|
|
|
976
976
|
counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
977
977
|
elif self.loss_func == 'gamma-poisson':
|
|
978
978
|
counts = dist.Poisson(rate=concentrate).to_event(1).mean
|
|
979
|
+
elif self.loss_func == 'multinomial':
|
|
980
|
+
rate = concentrate.exp()
|
|
981
|
+
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
982
|
+
counts = dist.Multinomial(total_count=int(1e8), probs=theta).mean
|
|
983
|
+
counts = counts * library_size
|
|
979
984
|
return counts
|
|
980
985
|
|
|
981
986
|
def _count_sample(self,concentrate):
|
|
@@ -987,22 +992,35 @@ class PerturbFlow(nn.Module):
|
|
|
987
992
|
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
988
993
|
return counts
|
|
989
994
|
|
|
990
|
-
def get_counts(self, zs,
|
|
995
|
+
def get_counts(self, zs, library_sizes = None,
|
|
991
996
|
batch_size: int = 1024,
|
|
992
997
|
use_sampler: bool = False):
|
|
993
998
|
|
|
994
999
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
995
|
-
|
|
1000
|
+
ls = zs
|
|
1001
|
+
|
|
1002
|
+
if self.loss_func == 'multinomial':
|
|
1003
|
+
assert library_sizes!=None, 'Library sizes are required for multinomial!'
|
|
1004
|
+
|
|
1005
|
+
if type(library_sizes) == list:
|
|
1006
|
+
library_sizes = np.array(library_sizes).view(-1,1)
|
|
1007
|
+
elif len(library_sizes.shape)==1:
|
|
1008
|
+
library_sizes = library_sizes.view(-1,1)
|
|
1009
|
+
ls = convert_to_tensor(library_sizes, device=self.get_device)
|
|
1010
|
+
|
|
1011
|
+
dataset = CustomDataset2(zs,ls)
|
|
996
1012
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
997
1013
|
|
|
998
1014
|
E = []
|
|
999
1015
|
with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
|
|
1000
|
-
for Z_batch, _ in dataloader:
|
|
1016
|
+
for Z_batch, L_batch, _ in dataloader:
|
|
1017
|
+
if self.loss_func != 'multinomial':
|
|
1018
|
+
L_batch = None
|
|
1001
1019
|
concentrate = self._get_expression_response(Z_batch)
|
|
1002
1020
|
if use_sampler:
|
|
1003
1021
|
counts = self._count_sample(concentrate)
|
|
1004
1022
|
else:
|
|
1005
|
-
counts = self._count(concentrate)
|
|
1023
|
+
counts = self._count(concentrate, L_batch)
|
|
1006
1024
|
E.append(tensor_to_numpy(counts))
|
|
1007
1025
|
pbar.update(1)
|
|
1008
1026
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|