SURE-tools 2.1.55__tar.gz → 2.1.57__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. See the registry's advisory page for this release for more details.

Files changed (30)
  1. {sure_tools-2.1.55 → sure_tools-2.1.57}/PKG-INFO +1 -1
  2. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/PerturbFlow.py +33 -25
  3. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/PKG-INFO +1 -1
  4. {sure_tools-2.1.55 → sure_tools-2.1.57}/setup.py +1 -1
  5. {sure_tools-2.1.55 → sure_tools-2.1.57}/LICENSE +0 -0
  6. {sure_tools-2.1.55 → sure_tools-2.1.57}/README.md +0 -0
  7. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/SURE.py +0 -0
  8. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/__init__.py +0 -0
  9. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/assembly/__init__.py +0 -0
  10. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/assembly/assembly.py +0 -0
  11. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/assembly/atlas.py +0 -0
  12. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/atac/__init__.py +0 -0
  13. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/atac/utils.py +0 -0
  14. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/codebook/__init__.py +0 -0
  15. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/codebook/codebook.py +0 -0
  16. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/flow/__init__.py +0 -0
  17. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/flow/flow_stats.py +0 -0
  18. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/flow/plot_quiver.py +0 -0
  19. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/perturb/__init__.py +0 -0
  20. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/perturb/perturb.py +0 -0
  21. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/utils/__init__.py +0 -0
  22. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/utils/custom_mlp.py +0 -0
  23. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/utils/queue.py +0 -0
  24. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE/utils/utils.py +0 -0
  25. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/SOURCES.txt +0 -0
  26. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.1.55 → sure_tools-2.1.57}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.1.55 → sure_tools-2.1.57}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.55
3
+ Version: 2.1.57
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -244,7 +244,7 @@ class PerturbFlow(nn.Module):
244
244
  # allow_broadcast=self.allow_broadcast,
245
245
  # use_cuda=self.use_cuda,
246
246
  # )
247
- self.encoder_concentrate = self.decoder_concentrate
247
+ #self.encoder_concentrate = self.decoder_concentrate
248
248
  else:
249
249
  self.decoder_concentrate = MLP(
250
250
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
@@ -380,6 +380,8 @@ class PerturbFlow(nn.Module):
380
380
  concentrate = self.decoder_concentrate(zs)
381
381
  if self.loss_func == 'bernoulli':
382
382
  log_theta = concentrate
383
+ elif self.loss_func == 'negbinomial':
384
+ log_theta = concentrate
383
385
  else:
384
386
  rate = concentrate.exp()
385
387
  if self.loss_func != 'poisson':
@@ -387,9 +389,9 @@ class PerturbFlow(nn.Module):
387
389
 
388
390
  if self.loss_func == 'negbinomial':
389
391
  if self.use_zeroinflate:
390
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
392
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
391
393
  else:
392
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
394
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
393
395
  elif self.loss_func == 'poisson':
394
396
  if self.use_zeroinflate:
395
397
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -417,9 +419,9 @@ class PerturbFlow(nn.Module):
417
419
  alpha = self.encoder_n(zns)
418
420
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
419
421
 
420
- if self.loss_func == 'gamma-poisson':
421
- con_alpha,con_beta = self.encoder_concentrate(zns)
422
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
422
+ #if self.loss_func == 'gamma-poisson':
423
+ # con_alpha,con_beta = self.encoder_concentrate(zns)
424
+ # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
423
425
 
424
426
  def model2(self, xs, us=None):
425
427
  pyro.module('PerturbFlow', self)
@@ -484,6 +486,8 @@ class PerturbFlow(nn.Module):
484
486
  concentrate = self.decoder_concentrate(zs)
485
487
  if self.loss_func == 'bernoulli':
486
488
  log_theta = concentrate
489
+ elif self.loss_func == 'negbinomial':
490
+ log_theta = concentrate
487
491
  else:
488
492
  rate = concentrate.exp()
489
493
  if self.loss_func != 'poisson':
@@ -491,9 +495,9 @@ class PerturbFlow(nn.Module):
491
495
 
492
496
  if self.loss_func == 'negbinomial':
493
497
  if self.use_zeroinflate:
494
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
498
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
495
499
  else:
496
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
500
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
497
501
  elif self.loss_func == 'poisson':
498
502
  if self.use_zeroinflate:
499
503
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -521,9 +525,9 @@ class PerturbFlow(nn.Module):
521
525
  alpha = self.encoder_n(zns)
522
526
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
523
527
 
524
- if self.loss_func == 'gamma-poisson':
525
- con_alpha,con_beta = self.encoder_concentrate(zns)
526
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
528
+ #if self.loss_func == 'gamma-poisson':
529
+ # con_alpha,con_beta = self.encoder_concentrate(zns)
530
+ # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
527
531
 
528
532
  def model3(self, xs, ys, embeds=None):
529
533
  pyro.module('PerturbFlow', self)
@@ -594,6 +598,8 @@ class PerturbFlow(nn.Module):
594
598
  concentrate = self.decoder_concentrate(zs)
595
599
  if self.loss_func == 'bernoulli':
596
600
  log_theta = concentrate
601
+ elif self.loss_func == 'negbinomial':
602
+ log_theta = concentrate
597
603
  else:
598
604
  rate = concentrate.exp()
599
605
  if self.loss_func != 'poisson':
@@ -601,9 +607,9 @@ class PerturbFlow(nn.Module):
601
607
 
602
608
  if self.loss_func == 'negbinomial':
603
609
  if self.use_zeroinflate:
604
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
610
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
605
611
  else:
606
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
612
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
607
613
  elif self.loss_func == 'poisson':
608
614
  if self.use_zeroinflate:
609
615
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -631,9 +637,9 @@ class PerturbFlow(nn.Module):
631
637
  else:
632
638
  zns = embeds
633
639
 
634
- if self.loss_func == 'gamma-poisson':
635
- con_alpha,con_beta = self.encoder_concentrate(zns)
636
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
640
+ #if self.loss_func == 'gamma-poisson':
641
+ # con_alpha,con_beta = self.encoder_concentrate(zns)
642
+ # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
637
643
 
638
644
  def model4(self, xs, us, ys, embeds=None):
639
645
  pyro.module('PerturbFlow', self)
@@ -714,6 +720,8 @@ class PerturbFlow(nn.Module):
714
720
  concentrate = self.decoder_concentrate(zs)
715
721
  if self.loss_func == 'bernoulli':
716
722
  log_theta = concentrate
723
+ elif self.loss_func == 'negbinomial':
724
+ log_theta = concentrate
717
725
  else:
718
726
  rate = concentrate.exp()
719
727
  if self.loss_func != 'poisson':
@@ -721,9 +729,9 @@ class PerturbFlow(nn.Module):
721
729
 
722
730
  if self.loss_func == 'negbinomial':
723
731
  if self.use_zeroinflate:
724
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
732
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
725
733
  else:
726
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
734
+ pyro.sample('x', dist.NegativeBinomial(total_count=total_count, logits=log_theta).to_event(1), obs=xs)
727
735
  elif self.loss_func == 'poisson':
728
736
  if self.use_zeroinflate:
729
737
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -751,9 +759,9 @@ class PerturbFlow(nn.Module):
751
759
  else:
752
760
  zns = embeds
753
761
 
754
- if self.loss_func == 'gamma-poisson':
755
- con_alpha,con_beta = self.encoder_concentrate(zns)
756
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
762
+ #if self.loss_func == 'gamma-poisson':
763
+ # con_alpha,con_beta = self.encoder_concentrate(zns)
764
+ # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
757
765
 
758
766
  def _total_effects(self, zns, us):
759
767
  zus = None
@@ -933,7 +941,7 @@ class PerturbFlow(nn.Module):
933
941
 
934
942
  def _get_expression_response(self, delta_zs):
935
943
  if self.loss_func == 'gamma-poisson':
936
- alpha,beta = self.encoder_concentrate(delta_zs)
944
+ alpha,beta = self.decoder_concentrate(delta_zs)
937
945
  xs = dist.Gamma(alpha,beta).to_event(1).mean
938
946
  else:
939
947
  xs = self.decoder_concentrate(delta_zs)
@@ -966,11 +974,11 @@ class PerturbFlow(nn.Module):
966
974
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
967
975
  elif self.loss_func == 'negbinomial':
968
976
  #counts = concentrate.exp()
969
- rate = concentrate.exp()
970
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
977
+ #rate = concentrate.exp()
978
+ #theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
971
979
 
972
980
  total_count = pyro.param("inverse_dispersion")
973
- counts = dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1).mean
981
+ counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
974
982
  elif self.loss_func == 'poisson':
975
983
  rate = concentrate.exp()
976
984
  counts = dist.Poisson(rate=rate).to_event(1).mean
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.55
3
+ Version: 2.1.57
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='2.1.55',
8
+ version='2.1.57',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes
File without changes