SURE-tools 2.1.14__tar.gz → 2.1.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools may be problematic; see the registry's advisory page for this release for more details.
- {sure_tools-2.1.14 → sure_tools-2.1.15}/PKG-INFO +1 -1
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/PerturbFlow.py +38 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.14 → sure_tools-2.1.15}/setup.py +1 -1
- {sure_tools-2.1.14 → sure_tools-2.1.15}/LICENSE +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/README.md +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/SURE.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/__init__.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.14 → sure_tools-2.1.15}/setup.cfg +0 -0
|
@@ -835,6 +835,44 @@ class PerturbFlow(nn.Module):
|
|
|
835
835
|
R = np.concatenate(R)
|
|
836
836
|
return R
|
|
837
837
|
|
|
838
|
+
def _count(self,concentrate):
|
|
839
|
+
if self.loss_func == 'bernoulli':
|
|
840
|
+
counts = self.sigmoid(concentrate)
|
|
841
|
+
else:
|
|
842
|
+
counts = concentrate.exp()
|
|
843
|
+
return counts
|
|
844
|
+
|
|
845
|
+
def _count_sample(self, concentrate):
    """Draw one random sample of counts from ``concentrate``.

    For ``self.loss_func == 'bernoulli'`` the values are treated as Bernoulli
    logits. Otherwise they are first mapped through ``self._count`` and the
    result is used as the rate of a Poisson distribution. In both cases the
    distribution is wrapped with ``.to_event(1)`` before sampling.

    Args:
        concentrate: Tensor of raw values; in ``get_counts`` this is the
            output of ``self._expression``.

    Returns:
        A tensor holding a single sample from the chosen distribution.
    """
    if self.loss_func == 'bernoulli':
        sampler = dist.Bernoulli(logits=concentrate)
    else:
        sampler = dist.Poisson(rate=self._count(concentrate=concentrate))
    return sampler.to_event(1).sample()
|
|
853
|
+
|
|
854
|
+
def get_counts(self, zs,
               batch_size: int = 1024,
               use_sampler: bool = False) -> np.ndarray:
    """Decode ``zs`` into count-scale values, one mini-batch at a time.

    Each batch is passed through ``self._expression``. The result then goes
    through either ``self._count`` (deterministic: sigmoid for the
    ``'bernoulli'`` loss, ``exp`` otherwise) or, when ``use_sampler`` is
    True, ``self._count_sample`` (one random draw). Per-batch results are
    converted with ``tensor_to_numpy`` and concatenated along axis 0.

    Args:
        zs: Anything accepted by ``convert_to_tensor``; presumably a
            (n_samples, latent_dim) array of latent codes -- TODO confirm.
        batch_size: Number of rows decoded per forward pass.
        use_sampler: If True, return sampled counts instead of the
            deterministic transform.

    Returns:
        NumPy array of the per-batch results concatenated along axis 0. Row
        order follows ``zs`` because the loader does not shuffle.

    Raises:
        ValueError: If ``zs`` produces no batches (e.g. it is empty).
    """
    import torch  # local import: only needed for the no_grad context below

    zs = convert_to_tensor(zs, device=self.get_device())
    dataloader = DataLoader(CustomDataset(zs), batch_size=batch_size, shuffle=False)

    batches = []
    # Pure inference: disable autograd so no computation graph is recorded
    # for every batch's forward pass.
    with torch.no_grad(), tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
        # CustomDataset items are unpacked as pairs; only the first element
        # is used here.
        for Z_batch, _ in dataloader:
            concentrate = self._expression(Z_batch)
            if use_sampler:
                counts = self._count_sample(concentrate)
            else:
                counts = self._count(concentrate)
            batches.append(tensor_to_numpy(counts))
            pbar.update(1)

    if not batches:
        # np.concatenate([]) would otherwise fail with an opaque message.
        raise ValueError('get_counts received empty input `zs`; nothing to decode.')
    return np.concatenate(batches)
|
|
875
|
+
|
|
838
876
|
def preprocess(self, xs, threshold=0):
|
|
839
877
|
if self.loss_func == 'bernoulli':
|
|
840
878
|
ad = sc.AnnData(xs)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|