SURE-tools 2.1.77.tar.gz → 2.1.78.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- {sure_tools-2.1.77 → sure_tools-2.1.78}/PKG-INFO +1 -1
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/PerturbFlow.py +7 -26
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.77 → sure_tools-2.1.78}/setup.py +1 -1
- {sure_tools-2.1.77 → sure_tools-2.1.78}/LICENSE +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/README.md +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/SURE.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/__init__.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.77 → sure_tools-2.1.78}/setup.cfg +0 -0
|
@@ -881,25 +881,11 @@ class PerturbFlow(nn.Module):
|
|
|
881
881
|
if self.loss_func == 'bernoulli':
|
|
882
882
|
#counts = self.sigmoid(concentrate)
|
|
883
883
|
counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
|
|
884
|
-
|
|
885
|
-
rate = concentrate.exp()
|
|
886
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
887
|
-
|
|
888
|
-
total_count = self.total_count
|
|
889
|
-
#total_count = pyro.param("inverse_dispersion")
|
|
890
|
-
#store = pyro.get_param_store()
|
|
891
|
-
#total_count = store['inverse_dispersion']
|
|
892
|
-
counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
|
|
893
|
-
elif self.loss_func == 'poisson':
|
|
884
|
+
else:
|
|
894
885
|
rate = concentrate.exp()
|
|
895
886
|
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
896
887
|
counts = theta * library_size
|
|
897
888
|
#counts = dist.Poisson(rate=rate).to_event(1).mean
|
|
898
|
-
elif self.loss_func == 'multinomial':
|
|
899
|
-
rate = concentrate.exp()
|
|
900
|
-
theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
|
|
901
|
-
#counts = dist.Multinomial(total_count=int(1e8), probs=theta).mean
|
|
902
|
-
counts = theta * library_size
|
|
903
889
|
return counts
|
|
904
890
|
|
|
905
891
|
def _count_sample(self,concentrate):
|
|
@@ -911,22 +897,17 @@ class PerturbFlow(nn.Module):
|
|
|
911
897
|
counts = dist.Poisson(rate=counts).to_event(1).sample()
|
|
912
898
|
return counts
|
|
913
899
|
|
|
914
|
-
def get_counts(self, zs, library_sizes
|
|
900
|
+
def get_counts(self, zs, library_sizes,
|
|
915
901
|
batch_size: int = 1024,
|
|
916
902
|
use_sampler: bool = False):
|
|
917
903
|
|
|
918
904
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
919
905
|
|
|
920
|
-
if
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
elif len(library_sizes.shape)==1:
|
|
926
|
-
library_sizes = library_sizes.view(-1,1)
|
|
927
|
-
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
928
|
-
else:
|
|
929
|
-
ls = zs
|
|
906
|
+
if type(library_sizes) == list:
|
|
907
|
+
library_sizes = np.array(library_sizes).view(-1,1)
|
|
908
|
+
elif len(library_sizes.shape)==1:
|
|
909
|
+
library_sizes = library_sizes.view(-1,1)
|
|
910
|
+
ls = convert_to_tensor(library_sizes, device=self.get_device())
|
|
930
911
|
|
|
931
912
|
dataset = CustomDataset2(zs,ls)
|
|
932
913
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|