SURE-tools 2.1.57__tar.gz → 2.1.59__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {sure_tools-2.1.57 → sure_tools-2.1.59}/PKG-INFO +1 -1
  2. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/PerturbFlow.py +38 -134
  3. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE_tools.egg-info/PKG-INFO +1 -1
  4. {sure_tools-2.1.57 → sure_tools-2.1.59}/setup.py +1 -1
  5. {sure_tools-2.1.57 → sure_tools-2.1.59}/LICENSE +0 -0
  6. {sure_tools-2.1.57 → sure_tools-2.1.59}/README.md +0 -0
  7. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/SURE.py +0 -0
  8. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/__init__.py +0 -0
  9. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/assembly/__init__.py +0 -0
  10. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/assembly/assembly.py +0 -0
  11. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/assembly/atlas.py +0 -0
  12. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/atac/__init__.py +0 -0
  13. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/atac/utils.py +0 -0
  14. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/codebook/__init__.py +0 -0
  15. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/codebook/codebook.py +0 -0
  16. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/flow/__init__.py +0 -0
  17. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/flow/flow_stats.py +0 -0
  18. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/flow/plot_quiver.py +0 -0
  19. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/perturb/__init__.py +0 -0
  20. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/perturb/perturb.py +0 -0
  21. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/utils/__init__.py +0 -0
  22. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/utils/custom_mlp.py +0 -0
  23. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/utils/queue.py +0 -0
  24. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE/utils/utils.py +0 -0
  25. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE_tools.egg-info/SOURCES.txt +0 -0
  26. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.1.57 → sure_tools-2.1.59}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.1.57 → sure_tools-2.1.59}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.57
3
+ Version: 2.1.59
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -62,7 +62,7 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli','gamma-poisson'] = 'negbinomial',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
66
66
  inverse_dispersion: float = 10.0,
67
67
  use_zeroinflate: bool = False,
68
68
  hidden_layers: list = [300],
@@ -225,28 +225,7 @@ class PerturbFlow(nn.Module):
225
225
  )
226
226
  )
227
227
 
228
- if self.loss_func == 'gamma-poisson':
229
- self.decoder_concentrate = MLP(
230
- [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
231
- activation=activate_fct,
232
- output_activation=[Exp,Exp],
233
- post_layer_fct=post_layer_fct,
234
- post_act_fct=post_act_fct,
235
- allow_broadcast=self.allow_broadcast,
236
- use_cuda=self.use_cuda,
237
- )
238
- #self.encoder_concentrate = MLP(
239
- # [self.latent_dim] + self.decoder_hidden_layers + [[self.input_size,self.input_size]],
240
- # activation=activate_fct,
241
- # output_activation=[Exp,Exp],
242
- # post_layer_fct=post_layer_fct,
243
- # post_act_fct=post_act_fct,
244
- # allow_broadcast=self.allow_broadcast,
245
- # use_cuda=self.use_cuda,
246
- # )
247
- #self.encoder_concentrate = self.decoder_concentrate
248
- else:
249
- self.decoder_concentrate = MLP(
228
+ self.decoder_concentrate = MLP(
250
229
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
251
230
  activation=activate_fct,
252
231
  output_activation=None,
@@ -338,8 +317,6 @@ class PerturbFlow(nn.Module):
338
317
  if self.loss_func=='negbinomial':
339
318
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
340
319
  xs.new_ones(self.input_size), constraint=constraints.positive)
341
- elif self.loss_func == 'multinomial':
342
- total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
343
320
 
344
321
  if self.use_zeroinflate:
345
322
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -373,19 +350,14 @@ class PerturbFlow(nn.Module):
373
350
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
374
351
 
375
352
  zs = zns
376
- if self.loss_func == 'gamma-poisson':
377
- con_alpha,con_beta = self.decoder_concentrate(zs)
378
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
353
+ concentrate = self.decoder_concentrate(zs)
354
+ if self.loss_func in ['bernoulli','negbinomial']:
355
+ log_theta = concentrate
379
356
  else:
380
- concentrate = self.decoder_concentrate(zs)
381
- if self.loss_func == 'bernoulli':
382
- log_theta = concentrate
383
- elif self.loss_func == 'negbinomial':
384
- log_theta = concentrate
385
- else:
386
- rate = concentrate.exp()
387
- if self.loss_func != 'poisson':
388
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
357
+ rate = concentrate.exp()
358
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
359
+ if self.loss_func == 'poisson':
360
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
389
361
 
390
362
  if self.loss_func == 'negbinomial':
391
363
  if self.use_zeroinflate:
@@ -397,13 +369,8 @@ class PerturbFlow(nn.Module):
397
369
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
398
370
  else:
399
371
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
400
- elif self.loss_func == 'gamma-poisson':
401
- if self.use_zeroinflate:
402
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
403
- else:
404
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
405
372
  elif self.loss_func == 'multinomial':
406
- pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
373
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
407
374
  elif self.loss_func == 'bernoulli':
408
375
  if self.use_zeroinflate:
409
376
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -418,10 +385,6 @@ class PerturbFlow(nn.Module):
418
385
 
419
386
  alpha = self.encoder_n(zns)
420
387
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
421
-
422
- #if self.loss_func == 'gamma-poisson':
423
- # con_alpha,con_beta = self.encoder_concentrate(zns)
424
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
425
388
 
426
389
  def model2(self, xs, us=None):
427
390
  pyro.module('PerturbFlow', self)
@@ -433,8 +396,6 @@ class PerturbFlow(nn.Module):
433
396
  if self.loss_func=='negbinomial':
434
397
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
435
398
  xs.new_ones(self.input_size), constraint=constraints.positive)
436
- elif self.loss_func == 'multinomial':
437
- total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
438
399
 
439
400
  if self.use_zeroinflate:
440
401
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,30 +429,19 @@ class PerturbFlow(nn.Module):
468
429
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
469
430
 
470
431
  if self.cell_factor_size>0:
471
- #zus = None
472
- #for i in np.arange(self.cell_factor_size):
473
- # if i==0:
474
- # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
475
- # else:
476
- # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
477
432
  zus = self._total_effects(zns, us)
478
433
  zs = zns+zus
479
434
  else:
480
435
  zs = zns
481
436
 
482
- if self.loss_func == 'gamma-poisson':
483
- con_alpha,con_beta = self.decoder_concentrate(zs)
484
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
437
+ concentrate = self.decoder_concentrate(zs)
438
+ if self.loss_func in ['bernoulli','negbinomial']:
439
+ log_theta = concentrate
485
440
  else:
486
- concentrate = self.decoder_concentrate(zs)
487
- if self.loss_func == 'bernoulli':
488
- log_theta = concentrate
489
- elif self.loss_func == 'negbinomial':
490
- log_theta = concentrate
491
- else:
492
- rate = concentrate.exp()
493
- if self.loss_func != 'poisson':
494
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
441
+ rate = concentrate.exp()
442
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
443
+ if self.loss_func == 'poisson':
444
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
495
445
 
496
446
  if self.loss_func == 'negbinomial':
497
447
  if self.use_zeroinflate:
@@ -503,13 +453,8 @@ class PerturbFlow(nn.Module):
503
453
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
504
454
  else:
505
455
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
506
- elif self.loss_func == 'gamma-poisson':
507
- if self.use_zeroinflate:
508
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
509
- else:
510
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
511
456
  elif self.loss_func == 'multinomial':
512
- pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
457
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
513
458
  elif self.loss_func == 'bernoulli':
514
459
  if self.use_zeroinflate:
515
460
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -524,10 +469,6 @@ class PerturbFlow(nn.Module):
524
469
 
525
470
  alpha = self.encoder_n(zns)
526
471
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
527
-
528
- #if self.loss_func == 'gamma-poisson':
529
- # con_alpha,con_beta = self.encoder_concentrate(zns)
530
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
531
472
 
532
473
  def model3(self, xs, ys, embeds=None):
533
474
  pyro.module('PerturbFlow', self)
@@ -539,8 +480,6 @@ class PerturbFlow(nn.Module):
539
480
  if self.loss_func=='negbinomial':
540
481
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
541
482
  xs.new_ones(self.input_size), constraint=constraints.positive)
542
- elif self.loss_func == 'multinomial':
543
- total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
544
483
 
545
484
  if self.use_zeroinflate:
546
485
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -591,19 +530,14 @@ class PerturbFlow(nn.Module):
591
530
 
592
531
  zs = zns
593
532
 
594
- if self.loss_func == 'gamma-poisson':
595
- con_alpha,con_beta = self.decoder_concentrate(zs)
596
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
533
+ concentrate = self.decoder_concentrate(zs)
534
+ if self.loss_func in ['bernoulli','negbinomial']:
535
+ log_theta = concentrate
597
536
  else:
598
- concentrate = self.decoder_concentrate(zs)
599
- if self.loss_func == 'bernoulli':
600
- log_theta = concentrate
601
- elif self.loss_func == 'negbinomial':
602
- log_theta = concentrate
603
- else:
604
- rate = concentrate.exp()
605
- if self.loss_func != 'poisson':
606
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
537
+ rate = concentrate.exp()
538
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
539
+ if self.loss_func == 'poisson':
540
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
607
541
 
608
542
  if self.loss_func == 'negbinomial':
609
543
  if self.use_zeroinflate:
@@ -615,13 +549,8 @@ class PerturbFlow(nn.Module):
615
549
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
616
550
  else:
617
551
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
618
- elif self.loss_func == 'gamma-poisson':
619
- if self.use_zeroinflate:
620
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
621
- else:
622
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
623
552
  elif self.loss_func == 'multinomial':
624
- pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
553
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
625
554
  elif self.loss_func == 'bernoulli':
626
555
  if self.use_zeroinflate:
627
556
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -636,10 +565,6 @@ class PerturbFlow(nn.Module):
636
565
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
637
566
  else:
638
567
  zns = embeds
639
-
640
- #if self.loss_func == 'gamma-poisson':
641
- # con_alpha,con_beta = self.encoder_concentrate(zns)
642
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
643
568
 
644
569
  def model4(self, xs, us, ys, embeds=None):
645
570
  pyro.module('PerturbFlow', self)
@@ -651,8 +576,6 @@ class PerturbFlow(nn.Module):
651
576
  if self.loss_func=='negbinomial':
652
577
  total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
653
578
  xs.new_ones(self.input_size), constraint=constraints.positive)
654
- elif self.loss_func == 'multinomial':
655
- total_count = pyro.param('total_count', int(1e8), constraint=constraints.positive_integer)
656
579
 
657
580
  if self.use_zeroinflate:
658
581
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -713,19 +636,14 @@ class PerturbFlow(nn.Module):
713
636
  else:
714
637
  zs = zns
715
638
 
716
- if self.loss_func == 'gamma-poisson':
717
- con_alpha,con_beta = self.decoder_concentrate(zs)
718
- rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
639
+ concentrate = self.decoder_concentrate(zs)
640
+ if self.loss_func in ['bernoulli','negbinomial']:
641
+ log_theta = concentrate
719
642
  else:
720
- concentrate = self.decoder_concentrate(zs)
721
- if self.loss_func == 'bernoulli':
722
- log_theta = concentrate
723
- elif self.loss_func == 'negbinomial':
724
- log_theta = concentrate
725
- else:
726
- rate = concentrate.exp()
727
- if self.loss_func != 'poisson':
728
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
643
+ rate = concentrate.exp()
644
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
645
+ if self.loss_func == 'poisson':
646
+ rate = theta * torch.sum(xs, dim=1, keepdim=True)
729
647
 
730
648
  if self.loss_func == 'negbinomial':
731
649
  if self.use_zeroinflate:
@@ -737,13 +655,8 @@ class PerturbFlow(nn.Module):
737
655
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
738
656
  else:
739
657
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
740
- elif self.loss_func == 'gamma-poisson':
741
- if self.use_zeroinflate:
742
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
743
- else:
744
- pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
745
658
  elif self.loss_func == 'multinomial':
746
- pyro.sample('x', dist.Multinomial(total_count=total_count, probs=theta), obs=xs)
659
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
747
660
  elif self.loss_func == 'bernoulli':
748
661
  if self.use_zeroinflate:
749
662
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -758,10 +671,6 @@ class PerturbFlow(nn.Module):
758
671
  zns = pyro.sample('zn', dist.Normal(zn_loc, zn_scale).to_event(1))
759
672
  else:
760
673
  zns = embeds
761
-
762
- #if self.loss_func == 'gamma-poisson':
763
- # con_alpha,con_beta = self.encoder_concentrate(zns)
764
- # rate = pyro.sample('cs', dist.Gamma(con_alpha, con_beta).to_event(1))
765
674
 
766
675
  def _total_effects(self, zns, us):
767
676
  zus = None
@@ -940,12 +849,7 @@ class PerturbFlow(nn.Module):
940
849
  return tensor_to_numpy(ms)
941
850
 
942
851
  def _get_expression_response(self, delta_zs):
943
- if self.loss_func == 'gamma-poisson':
944
- alpha,beta = self.decoder_concentrate(delta_zs)
945
- xs = dist.Gamma(alpha,beta).to_event(1).mean
946
- else:
947
- xs = self.decoder_concentrate(delta_zs)
948
- return xs
852
+ return self.decoder_concentrate(delta_zs)
949
853
 
950
854
  def get_expression_response(self,
951
855
  delta_zs,
@@ -981,9 +885,9 @@ class PerturbFlow(nn.Module):
981
885
  counts = dist.NegativeBinomial(total_count=total_count, logits=concentrate).to_event(1).mean
982
886
  elif self.loss_func == 'poisson':
983
887
  rate = concentrate.exp()
888
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
889
+ rate = theta * library_size
984
890
  counts = dist.Poisson(rate=rate).to_event(1).mean
985
- elif self.loss_func == 'gamma-poisson':
986
- counts = dist.Poisson(rate=concentrate).to_event(1).mean
987
891
  elif self.loss_func == 'multinomial':
988
892
  rate = concentrate.exp()
989
893
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
@@ -1007,7 +911,7 @@ class PerturbFlow(nn.Module):
1007
911
  zs = convert_to_tensor(zs, device=self.get_device())
1008
912
  ls = zs
1009
913
 
1010
- if self.loss_func == 'multinomial':
914
+ if self.loss_func in ['multinomial','poisson']:
1011
915
  assert library_sizes!=None, 'Library sizes are required for multinomial!'
1012
916
 
1013
917
  if type(library_sizes) == list:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.57
3
+ Version: 2.1.59
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='2.1.57',
8
+ version='2.1.59',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes
File without changes