SURE-tools 2.1.59 (.tar.gz) → 2.1.64 (.tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic; see the registry's advisory page for this release for more details.
- {sure_tools-2.1.59 → sure_tools-2.1.64}/PKG-INFO +1 -1
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/PerturbFlow.py +15 -14
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE_tools.egg-info/PKG-INFO +1 -1
- {sure_tools-2.1.59 → sure_tools-2.1.64}/setup.py +1 -1
- {sure_tools-2.1.59 → sure_tools-2.1.64}/LICENSE +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/README.md +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/SURE.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/__init__.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/assembly/__init__.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/assembly/assembly.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/assembly/atlas.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/atac/__init__.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/atac/utils.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/codebook/__init__.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/codebook/codebook.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/flow/__init__.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/flow/flow_stats.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/flow/plot_quiver.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/perturb/__init__.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/perturb/perturb.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/utils/__init__.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/utils/custom_mlp.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/utils/queue.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE/utils/utils.py +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE_tools.egg-info/SOURCES.txt +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE_tools.egg-info/dependency_links.txt +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE_tools.egg-info/entry_points.txt +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE_tools.egg-info/requires.txt +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/SURE_tools.egg-info/top_level.txt +0 -0
- {sure_tools-2.1.59 → sure_tools-2.1.64}/setup.cfg +0 -0
|
@@ -351,7 +351,7 @@ class PerturbFlow(nn.Module):
|
|
|
351
351
|
|
|
352
352
|
zs = zns
|
|
353
353
|
concentrate = self.decoder_concentrate(zs)
|
|
354
|
-
if self.loss_func in ['bernoulli'
|
|
354
|
+
if self.loss_func in ['bernoulli']:
|
|
355
355
|
log_theta = concentrate
|
|
356
356
|
else:
|
|
357
357
|
rate = concentrate.exp()
|
|
@@ -361,9 +361,9 @@ class PerturbFlow(nn.Module):
|
|
|
361
361
|
|
|
362
362
|
if self.loss_func == 'negbinomial':
|
|
363
363
|
if self.use_zeroinflate:
|
|
364
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
364
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
365
365
|
else:
|
|
366
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
366
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
|
|
367
367
|
elif self.loss_func == 'poisson':
|
|
368
368
|
if self.use_zeroinflate:
|
|
369
369
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -435,7 +435,7 @@ class PerturbFlow(nn.Module):
|
|
|
435
435
|
zs = zns
|
|
436
436
|
|
|
437
437
|
concentrate = self.decoder_concentrate(zs)
|
|
438
|
-
if self.loss_func in ['bernoulli'
|
|
438
|
+
if self.loss_func in ['bernoulli']:
|
|
439
439
|
log_theta = concentrate
|
|
440
440
|
else:
|
|
441
441
|
rate = concentrate.exp()
|
|
@@ -445,9 +445,9 @@ class PerturbFlow(nn.Module):
|
|
|
445
445
|
|
|
446
446
|
if self.loss_func == 'negbinomial':
|
|
447
447
|
if self.use_zeroinflate:
|
|
448
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
448
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
449
449
|
else:
|
|
450
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
450
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
|
|
451
451
|
elif self.loss_func == 'poisson':
|
|
452
452
|
if self.use_zeroinflate:
|
|
453
453
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -531,7 +531,7 @@ class PerturbFlow(nn.Module):
|
|
|
531
531
|
zs = zns
|
|
532
532
|
|
|
533
533
|
concentrate = self.decoder_concentrate(zs)
|
|
534
|
-
if self.loss_func in ['bernoulli'
|
|
534
|
+
if self.loss_func in ['bernoulli']:
|
|
535
535
|
log_theta = concentrate
|
|
536
536
|
else:
|
|
537
537
|
rate = concentrate.exp()
|
|
@@ -541,9 +541,9 @@ class PerturbFlow(nn.Module):
|
|
|
541
541
|
|
|
542
542
|
if self.loss_func == 'negbinomial':
|
|
543
543
|
if self.use_zeroinflate:
|
|
544
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
544
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
545
545
|
else:
|
|
546
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
546
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
|
|
547
547
|
elif self.loss_func == 'poisson':
|
|
548
548
|
if self.use_zeroinflate:
|
|
549
549
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -637,7 +637,7 @@ class PerturbFlow(nn.Module):
|
|
|
637
637
|
zs = zns
|
|
638
638
|
|
|
639
639
|
concentrate = self.decoder_concentrate(zs)
|
|
640
|
-
if self.loss_func in ['bernoulli'
|
|
640
|
+
if self.loss_func in ['bernoulli']:
|
|
641
641
|
log_theta = concentrate
|
|
642
642
|
else:
|
|
643
643
|
rate = concentrate.exp()
|
|
@@ -647,9 +647,9 @@ class PerturbFlow(nn.Module):
|
|
|
647
647
|
|
|
648
648
|
if self.loss_func == 'negbinomial':
|
|
649
649
|
if self.use_zeroinflate:
|
|
650
|
-
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count,
|
|
650
|
+
pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
|
|
651
651
|
else:
|
|
652
|
-
pyro.sample('x', dist.NegativeBinomial(total_count=total_count,
|
|
652
|
+
pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
|
|
653
653
|
elif self.loss_func == 'poisson':
|
|
654
654
|
if self.use_zeroinflate:
|
|
655
655
|
pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
|
|
@@ -909,16 +909,17 @@ class PerturbFlow(nn.Module):
|
|
|
909
909
|
use_sampler: bool = False):
|
|
910
910
|
|
|
911
911
|
zs = convert_to_tensor(zs, device=self.get_device())
|
|
912
|
-
ls = zs
|
|
913
912
|
|
|
914
913
|
if self.loss_func in ['multinomial','poisson']:
|
|
915
|
-
assert library_sizes
|
|
914
|
+
assert library_sizes is not None, 'Library sizes are required for multinomial!'
|
|
916
915
|
|
|
917
916
|
if type(library_sizes) == list:
|
|
918
917
|
library_sizes = np.array(library_sizes).view(-1,1)
|
|
919
918
|
elif len(library_sizes.shape)==1:
|
|
920
919
|
library_sizes = library_sizes.view(-1,1)
|
|
921
920
|
ls = convert_to_tensor(library_sizes, device=self.get_device)
|
|
921
|
+
else:
|
|
922
|
+
ls = zs
|
|
922
923
|
|
|
923
924
|
dataset = CustomDataset2(zs,ls)
|
|
924
925
|
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|