SURE-tools 2.4.20__tar.gz → 2.4.35__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. {sure_tools-2.4.20 → sure_tools-2.4.35}/PKG-INFO +1 -1
  2. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/DensityFlow.py +122 -71
  3. sure_tools-2.4.35/SURE/DensityFlowLinear.py +1414 -0
  4. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/PerturbationAwareDecoder.py +217 -250
  5. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/__init__.py +3 -1
  6. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/PKG-INFO +1 -1
  7. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/SOURCES.txt +1 -0
  8. {sure_tools-2.4.20 → sure_tools-2.4.35}/setup.py +1 -1
  9. {sure_tools-2.4.20 → sure_tools-2.4.35}/LICENSE +0 -0
  10. {sure_tools-2.4.20 → sure_tools-2.4.35}/README.md +0 -0
  11. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/EfficientTranscriptomeDecoder.py +0 -0
  12. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/PerturbE.py +0 -0
  13. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/SURE.py +0 -0
  14. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/SimpleTranscriptomeDecoder.py +0 -0
  15. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/TranscriptomeDecoder.py +0 -0
  16. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/VirtualCellDecoder.py +0 -0
  17. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/assembly/__init__.py +0 -0
  18. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/assembly/assembly.py +0 -0
  19. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/assembly/atlas.py +0 -0
  20. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/atac/__init__.py +0 -0
  21. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/atac/utils.py +0 -0
  22. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/codebook/__init__.py +0 -0
  23. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/codebook/codebook.py +0 -0
  24. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/flow/__init__.py +0 -0
  25. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/flow/flow_stats.py +0 -0
  26. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/flow/plot_quiver.py +0 -0
  27. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/perturb/__init__.py +0 -0
  28. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/perturb/perturb.py +0 -0
  29. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/utils/__init__.py +0 -0
  30. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/utils/custom_mlp.py +0 -0
  31. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/utils/queue.py +0 -0
  32. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE/utils/utils.py +0 -0
  33. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/dependency_links.txt +0 -0
  34. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/entry_points.txt +0 -0
  35. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/requires.txt +0 -0
  36. {sure_tools-2.4.20 → sure_tools-2.4.35}/SURE_tools.egg-info/top_level.txt +0 -0
  37. {sure_tools-2.4.20 → sure_tools-2.4.35}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.4.20
3
+ Version: 2.4.35
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -57,16 +57,16 @@ def set_random_seed(seed):
57
57
  class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
- codebook_size: int = 200,
60
+ codebook_size: int = 100,
61
61
  cell_factor_size: int = 0,
62
62
  turn_off_cell_specific: bool = False,
63
63
  supervised_mode: bool = False,
64
- z_dim: int = 10,
65
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
67
- inverse_dispersion: float = 10.0,
64
+ z_dim: int = 50,
65
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
67
+ dispersion: float = 8.0,
68
68
  use_zeroinflate: bool = False,
69
- hidden_layers: list = [500],
69
+ hidden_layers: list = [1024],
70
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
71
71
  nn_dropout: float = 0.1,
72
72
  post_layer_fct: list = ['layernorm'],
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
81
81
 
82
82
  self.input_size = input_size
83
83
  self.cell_factor_size = cell_factor_size
84
- self.inverse_dispersion = inverse_dispersion
84
+ self.dispersion = dispersion
85
85
  self.latent_dim = z_dim
86
86
  self.hidden_layers = hidden_layers
87
87
  self.decoder_hidden_layers = hidden_layers[::-1]
88
+ self.config_enum = config_enum
88
89
  self.allow_broadcast = config_enum == 'parallel'
89
90
  self.use_cuda = use_cuda
90
91
  self.loss_func = loss_func
@@ -107,6 +108,7 @@ class DensityFlow(nn.Module):
107
108
 
108
109
  self.codebook_weights = None
109
110
 
111
+ self.seed = seed
110
112
  set_random_seed(seed)
111
113
  self.setup_networks()
112
114
 
@@ -115,7 +117,7 @@ class DensityFlow(nn.Module):
115
117
  print(f" - Gene Dimension: {self.input_size}")
116
118
  print(f" - Hidden Dimensions: {self.hidden_layers}")
117
119
  print(f" - Device: {self.get_device()}")
118
- print(f" - Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
120
+ print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
119
121
 
120
122
  def setup_networks(self):
121
123
  latent_dim = self.latent_dim
@@ -258,7 +260,7 @@ class DensityFlow(nn.Module):
258
260
  )
259
261
  )
260
262
 
261
- self.decoder_concentrate = MLP(
263
+ self.decoder_log_mu = MLP(
262
264
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
263
265
  activation=activate_fct,
264
266
  output_activation=None,
@@ -348,8 +350,8 @@ class DensityFlow(nn.Module):
348
350
  self.options = dict(dtype=xs.dtype, device=xs.device)
349
351
 
350
352
  if self.loss_func=='negbinomial':
351
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
352
- xs.new_ones(self.input_size), constraint=constraints.positive)
353
+ dispersion = pyro.param("dispersion", self.dispersion *
354
+ xs.new_ones(self.input_size), constraint=constraints.positive)
353
355
 
354
356
  if self.use_zeroinflate:
355
357
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -383,28 +385,32 @@ class DensityFlow(nn.Module):
383
385
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
384
386
 
385
387
  zs = zns
386
- concentrate = self.decoder_concentrate(zs)
388
+ log_mu = self.decoder_log_mu(zs)
387
389
  if self.loss_func in ['bernoulli']:
388
- log_theta = concentrate
390
+ log_theta = log_mu
391
+ elif self.loss_func == 'negbinomial':
392
+ mu = log_mu.exp()
389
393
  else:
390
- rate = concentrate.exp()
394
+ rate = log_mu.exp()
391
395
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
392
396
  if self.loss_func == 'poisson':
393
397
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
394
398
 
395
399
  if self.loss_func == 'negbinomial':
400
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
396
401
  if self.use_zeroinflate:
397
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
402
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
403
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
398
404
  else:
399
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
405
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
406
+ logits=logits).to_event(1), obs=xs)
400
407
  elif self.loss_func == 'poisson':
401
408
  if self.use_zeroinflate:
402
409
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
403
410
  else:
404
411
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
405
412
  elif self.loss_func == 'multinomial':
406
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
407
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
413
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
408
414
  elif self.loss_func == 'bernoulli':
409
415
  if self.use_zeroinflate:
410
416
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -428,8 +434,8 @@ class DensityFlow(nn.Module):
428
434
  self.options = dict(dtype=xs.dtype, device=xs.device)
429
435
 
430
436
  if self.loss_func=='negbinomial':
431
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
432
- xs.new_ones(self.input_size), constraint=constraints.positive)
437
+ dispersion = pyro.param("dispersion", self.dispersion *
438
+ xs.new_ones(self.input_size), constraint=constraints.positive)
433
439
 
434
440
  if self.use_zeroinflate:
435
441
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +474,30 @@ class DensityFlow(nn.Module):
468
474
  else:
469
475
  zs = zns
470
476
 
471
- concentrate = self.decoder_concentrate(zs)
477
+ log_mu = self.decoder_log_mu(zs)
472
478
  if self.loss_func in ['bernoulli']:
473
- log_theta = concentrate
479
+ log_theta = log_mu
480
+ elif self.loss_func == 'negbinomial':
481
+ mu = log_mu.exp()
474
482
  else:
475
- rate = concentrate.exp()
483
+ rate = log_mu.exp()
476
484
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
477
485
  if self.loss_func == 'poisson':
478
486
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
479
487
 
480
488
  if self.loss_func == 'negbinomial':
489
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
481
490
  if self.use_zeroinflate:
482
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
491
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
483
492
  else:
484
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
493
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
485
494
  elif self.loss_func == 'poisson':
486
495
  if self.use_zeroinflate:
487
496
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
488
497
  else:
489
498
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
490
499
  elif self.loss_func == 'multinomial':
491
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
492
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
500
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
493
501
  elif self.loss_func == 'bernoulli':
494
502
  if self.use_zeroinflate:
495
503
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +521,8 @@ class DensityFlow(nn.Module):
513
521
  self.options = dict(dtype=xs.dtype, device=xs.device)
514
522
 
515
523
  if self.loss_func=='negbinomial':
516
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
517
- xs.new_ones(self.input_size), constraint=constraints.positive)
524
+ dispersion = pyro.param("dispersion", self.dispersion *
525
+ xs.new_ones(self.input_size), constraint=constraints.positive)
518
526
 
519
527
  if self.use_zeroinflate:
520
528
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +573,31 @@ class DensityFlow(nn.Module):
565
573
 
566
574
  zs = zns
567
575
 
568
- concentrate = self.decoder_concentrate(zs)
576
+ log_mu = self.decoder_log_mu(zs)
569
577
  if self.loss_func in ['bernoulli']:
570
- log_theta = concentrate
578
+ log_theta = log_mu
579
+ elif self.loss_func in ['negbinomial']:
580
+ mu = log_mu.exp()
571
581
  else:
572
- rate = concentrate.exp()
582
+ rate = log_mu.exp()
573
583
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
574
584
  if self.loss_func == 'poisson':
575
585
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
576
586
 
577
587
  if self.loss_func == 'negbinomial':
588
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
578
589
  if self.use_zeroinflate:
579
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
590
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
591
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
580
592
  else:
581
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
593
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
582
594
  elif self.loss_func == 'poisson':
583
595
  if self.use_zeroinflate:
584
596
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
585
597
  else:
586
598
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
587
599
  elif self.loss_func == 'multinomial':
588
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
589
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
600
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
590
601
  elif self.loss_func == 'bernoulli':
591
602
  if self.use_zeroinflate:
592
603
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +621,8 @@ class DensityFlow(nn.Module):
610
621
  self.options = dict(dtype=xs.dtype, device=xs.device)
611
622
 
612
623
  if self.loss_func=='negbinomial':
613
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
614
- xs.new_ones(self.input_size), constraint=constraints.positive)
624
+ dispersion = pyro.param("dispersion", self.dispersion *
625
+ xs.new_ones(self.input_size), constraint=constraints.positive)
615
626
 
616
627
  if self.use_zeroinflate:
617
628
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +683,31 @@ class DensityFlow(nn.Module):
672
683
  else:
673
684
  zs = zns
674
685
 
675
- concentrate = self.decoder_concentrate(zs)
686
+ log_mu = self.decoder_log_mu(zs)
676
687
  if self.loss_func in ['bernoulli']:
677
- log_theta = concentrate
688
+ log_theta = log_mu
689
+ elif self.loss_func in ['negbinomial']:
690
+ mu = log_mu.exp()
678
691
  else:
679
- rate = concentrate.exp()
692
+ rate = log_mu.exp()
680
693
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
681
694
  if self.loss_func == 'poisson':
682
695
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
683
696
 
684
697
  if self.loss_func == 'negbinomial':
698
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
685
699
  if self.use_zeroinflate:
686
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
700
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
701
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
687
702
  else:
688
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
703
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
689
704
  elif self.loss_func == 'poisson':
690
705
  if self.use_zeroinflate:
691
706
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
692
707
  else:
693
708
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
694
709
  elif self.loss_func == 'multinomial':
695
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
696
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
710
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
697
711
  elif self.loss_func == 'bernoulli':
698
712
  if self.use_zeroinflate:
699
713
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +731,13 @@ class DensityFlow(nn.Module):
717
731
  # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
718
732
  #else:
719
733
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
720
- zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
734
+ zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
721
735
  else:
722
736
  #if self.turn_off_cell_specific:
723
737
  # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
724
738
  #else:
725
739
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
726
- zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
740
+ zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
727
741
  return zus
728
742
 
729
743
  def _get_codebook_identity(self):
@@ -865,12 +879,12 @@ class DensityFlow(nn.Module):
865
879
  us_i = us[:,pert_idx].reshape(-1,1)
866
880
 
867
881
  # factor effect of xs
868
- dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
882
+ dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
869
883
 
870
884
  # perturbation effect
871
885
  ps = np.ones_like(us_i)
872
886
  if np.sum(np.abs(ps-us_i))>=1:
873
- dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
887
+ dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
874
888
  zs = zs + dzs0 + dzs
875
889
  else:
876
890
  zs = zs + dzs0
@@ -884,10 +898,11 @@ class DensityFlow(nn.Module):
884
898
  library_sizes = library_sizes.reshape(-1,1)
885
899
 
886
900
  counts = self.get_counts(zs, library_sizes=library_sizes)
901
+ log_mu = self.get_log_mu(zs)
887
902
 
888
- return counts, zs
903
+ return counts, log_mu
889
904
 
890
- def _cell_response(self, zs, perturb_idx, perturb):
905
+ def _cell_shift(self, zs, perturb_idx, perturb):
891
906
  #zns,_ = self.encoder_zn(xs)
892
907
  #zns,_ = self._get_basal_embedding(xs)
893
908
  zns = zs
@@ -904,7 +919,7 @@ class DensityFlow(nn.Module):
904
919
 
905
920
  return ms
906
921
 
907
- def get_cell_response(self,
922
+ def get_cell_shift(self,
908
923
  zs,
909
924
  perturb_idx,
910
925
  perturb_us,
@@ -922,46 +937,43 @@ class DensityFlow(nn.Module):
922
937
  Z = []
923
938
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
924
939
  for Z_batch, P_batch, _ in dataloader:
925
- zns = self._cell_response(Z_batch, perturb_idx, P_batch)
940
+ zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
926
941
  Z.append(tensor_to_numpy(zns))
927
942
  pbar.update(1)
928
943
 
929
944
  Z = np.concatenate(Z)
930
945
  return Z
931
946
 
932
- def _get_expression_response(self, delta_zs):
933
- return self.decoder_concentrate(delta_zs)
947
+ def _log_mu(self, zs):
948
+ return self.decoder_log_mu(zs)
934
949
 
935
- def get_expression_response(self,
936
- delta_zs,
937
- batch_size: int = 1024):
950
+ def get_log_mu(self, zs, batch_size: int = 1024):
938
951
  """
939
952
  Return cells' changes in the feature space induced by specific perturbation of a factor
940
953
 
941
954
  """
942
- delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
943
- dataset = CustomDataset(delta_zs)
955
+ zs = convert_to_tensor(zs, device=self.get_device())
956
+ dataset = CustomDataset(zs)
944
957
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
945
958
 
946
959
  R = []
947
960
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
948
- for delta_Z_batch, _ in dataloader:
949
- r = self._get_expression_response(delta_Z_batch)
961
+ for Z_batch, _ in dataloader:
962
+ r = self._log_mu(Z_batch)
950
963
  R.append(tensor_to_numpy(r))
951
964
  pbar.update(1)
952
965
 
953
966
  R = np.concatenate(R)
954
967
  return R
955
968
 
956
- def _count(self, concentrate, library_size=None):
969
+ def _count(self, log_mu, library_size=None):
957
970
  if self.loss_func == 'bernoulli':
958
- #counts = self.sigmoid(concentrate)
959
- counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
971
+ counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
960
972
  elif self.loss_func == 'multinomial':
961
- theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
973
+ theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
962
974
  counts = theta * library_size
963
975
  else:
964
- rate = concentrate.exp()
976
+ rate = log_mu.exp()
965
977
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
966
978
  counts = theta * library_size
967
979
  return counts
@@ -983,8 +995,8 @@ class DensityFlow(nn.Module):
983
995
  E = []
984
996
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
985
997
  for Z_batch, L_batch, _ in dataloader:
986
- concentrate = self._get_expression_response(Z_batch)
987
- counts = self._count(concentrate, L_batch)
998
+ log_mu = self._log_mu(Z_batch)
999
+ counts = self._count(log_mu, L_batch)
988
1000
  E.append(tensor_to_numpy(counts))
989
1001
  pbar.update(1)
990
1002
 
@@ -1130,7 +1142,7 @@ class DensityFlow(nn.Module):
1130
1142
  pbar.set_postfix({'loss': str_loss})
1131
1143
  pbar.update(1)
1132
1144
 
1133
- @classmethod
1145
+ '''@classmethod
1134
1146
  def save_model(cls, model, file_path, compression=False):
1135
1147
  """Save the model to the specified file path."""
1136
1148
  file_path = os.path.abspath(file_path)
@@ -1158,6 +1170,45 @@ class DensityFlow(nn.Module):
1158
1170
  with open(file_path, 'rb') as pickle_file:
1159
1171
  model = pickle.load(pickle_file)
1160
1172
 
1173
+ return model'''
1174
+
1175
+ def save(self, path):
1176
+ """Save model checkpoint"""
1177
+ torch.save({
1178
+ 'model_state_dict': self.state_dict(),
1179
+ 'model_config': {
1180
+ 'input_size': self.input_size,
1181
+ 'codebook_size': self.code_size,
1182
+ 'cell_factor_size': self.cell_factor_size,
1183
+ 'turn_off_cell_specific':self.turn_off_cell_specific,
1184
+ 'supervised_mode':self.supervised_mode,
1185
+ 'z_dim': self.latent_dim,
1186
+ 'z_dist': self.latent_dist,
1187
+ 'loss_func': self.loss_func,
1188
+ 'dispersion': self.dispersion,
1189
+ 'use_zeroinflate': self.use_zeroinflate,
1190
+ 'hidden_layers':self.hidden_layers,
1191
+ 'hidden_layer_activation':self.hidden_layer_activation,
1192
+ 'nn_dropout':self.nn_dropout,
1193
+ 'post_layer_fct':self.post_layer_fct,
1194
+ 'post_act_fct':self.post_act_fct,
1195
+ 'config_enum':self.config_enum,
1196
+ 'use_cuda':self.use_cuda,
1197
+ 'seed':self.seed,
1198
+ 'zero_bias':self.use_bias,
1199
+ 'dtype':self.dtype,
1200
+ }
1201
+ }, path)
1202
+
1203
+ @classmethod
1204
+ def load_model(cls, model_path: str):
1205
+ """Load pre-trained model"""
1206
+ checkpoint = torch.load(model_path)
1207
+ model = DensityFlow(**checkpoint.get('model_config'))
1208
+
1209
+ checkpoint = torch.load(model_path, map_location=model.get_device())
1210
+ model.load_state_dict(checkpoint['model_state_dict'])
1211
+
1161
1212
  return model
1162
1213
 
1163
1214
 
@@ -1357,7 +1408,7 @@ def main():
1357
1408
  df = DensityFlow(
1358
1409
  input_size=input_size,
1359
1410
  cell_factor_size=cell_factor_size,
1360
- inverse_dispersion=args.inverse_dispersion,
1411
+ dispersion=args.dispersion,
1361
1412
  z_dim=args.z_dim,
1362
1413
  hidden_layers=args.hidden_layers,
1363
1414
  hidden_layer_activation=args.hidden_layer_activation,