SURE-tools 2.4.25.tar.gz → 2.4.38.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of SURE-tools has been flagged as potentially problematic; details are available on the package's registry page.

Files changed (37)
  1. {sure_tools-2.4.25 → sure_tools-2.4.38}/PKG-INFO +1 -1
  2. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/DensityFlow.py +149 -69
  3. sure_tools-2.4.25/SURE/DensityFlow2.py → sure_tools-2.4.38/SURE/DensityFlowLinear.py +91 -99
  4. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/__init__.py +3 -3
  5. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/SOURCES.txt +1 -1
  7. {sure_tools-2.4.25 → sure_tools-2.4.38}/setup.py +1 -1
  8. {sure_tools-2.4.25 → sure_tools-2.4.38}/LICENSE +0 -0
  9. {sure_tools-2.4.25 → sure_tools-2.4.38}/README.md +0 -0
  10. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/EfficientTranscriptomeDecoder.py +0 -0
  11. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/PerturbE.py +0 -0
  12. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/PerturbationAwareDecoder.py +0 -0
  13. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/SURE.py +0 -0
  14. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/SimpleTranscriptomeDecoder.py +0 -0
  15. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/TranscriptomeDecoder.py +0 -0
  16. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/VirtualCellDecoder.py +0 -0
  17. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/assembly/__init__.py +0 -0
  18. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/assembly/assembly.py +0 -0
  19. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/assembly/atlas.py +0 -0
  20. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/atac/__init__.py +0 -0
  21. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/atac/utils.py +0 -0
  22. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/codebook/__init__.py +0 -0
  23. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/codebook/codebook.py +0 -0
  24. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/flow/__init__.py +0 -0
  25. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/flow/flow_stats.py +0 -0
  26. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/flow/plot_quiver.py +0 -0
  27. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/perturb/__init__.py +0 -0
  28. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/perturb/perturb.py +0 -0
  29. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/utils/__init__.py +0 -0
  30. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/utils/custom_mlp.py +0 -0
  31. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/utils/queue.py +0 -0
  32. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE/utils/utils.py +0 -0
  33. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/dependency_links.txt +0 -0
  34. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/entry_points.txt +0 -0
  35. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/requires.txt +0 -0
  36. {sure_tools-2.4.25 → sure_tools-2.4.38}/SURE_tools.egg-info/top_level.txt +0 -0
  37. {sure_tools-2.4.25 → sure_tools-2.4.38}/setup.cfg +0 -0
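
Before the per-file hunks below, a condensed view of what actually changes in this release: DensityFlow2 becomes DensityFlowLinear; the constructor argument inverse_dispersion becomes dispersion (default 10.0 → 8.0); decoder_concentrate becomes decoder_log_mu; _cell_response/get_cell_response become _cell_shift/get_cell_shift; get_expression_response (get_theta in the linear variant) becomes get_log_mu. DensityFlow's constructor defaults also change, the negative-binomial likelihood is reparameterized from probs to logits, and the multinomial likelihood reverts from logits to probs. A minimal migration sketch for downstream code (assuming only these public names are used):

    # Renames in 2.4.38 (old name -> new name):
    #   DensityFlow2                         -> DensityFlowLinear
    #   inverse_dispersion                   -> dispersion (default 10.0 -> 8.0)
    #   decoder_concentrate                  -> decoder_log_mu
    #   get_cell_response                    -> get_cell_shift
    #   get_expression_response / get_theta  -> get_log_mu
    from SURE import DensityFlowLinear  # formerly: from SURE import DensityFlow2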
--- sure_tools-2.4.25/PKG-INFO
+++ sure_tools-2.4.38/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.4.25
+Version: 2.4.38
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng

--- sure_tools-2.4.25/SURE/DensityFlow.py
+++ sure_tools-2.4.38/SURE/DensityFlow.py
@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int = 200,
+                 codebook_size: int = 100,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int = 10,
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
-                 inverse_dispersion: float = 10.0,
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [500],
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
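
With the new defaults above, a bare construction now implies the student-t latent prior and negative-binomial likelihood. A minimal sketch (the input_size value is an arbitrary illustration):

    from SURE import DensityFlow

    # Equivalent to spelling out the new 2.4.38 defaults:
    # codebook_size=100, z_dim=50, z_dist='studentt',
    # loss_func='negbinomial', dispersion=8.0, hidden_layers=[1024]
    df = DensityFlow(input_size=2000)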
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):

         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.inverse_dispersion = inverse_dispersion
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,6 +108,7 @@ class DensityFlow(nn.Module):

         self.codebook_weights = None

+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()

@@ -258,7 +260,7 @@ class DensityFlow(nn.Module):
             )
         )

-        self.decoder_concentrate = MLP(
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -348,8 +350,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -383,28 +385,32 @@ class DensityFlow(nn.Module):
                 zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

             zs = zns
-            concentrate = self.decoder_concentrate(zs)
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta = concentrate
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate = concentrate.exp()
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                         logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                                                           logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
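
The reparameterization in this hunk makes the learned per-gene dispersion the NB total_count while recovering the decoder mean exactly: an NB with logits l has mean total_count * exp(l), so l = log(mu) - log(dispersion) gives mean = dispersion * mu / dispersion = mu. A standalone check of that identity (not part of the package):

    import torch
    import pyro.distributions as dist

    mu = torch.tensor([5.0, 20.0, 100.0])     # decoder means, i.e. log_mu.exp()
    dispersion = torch.full((3,), 8.0)        # the new default dispersion value
    logits = (mu.log() - dispersion.log()).clamp(min=-15, max=15)
    nb = dist.NegativeBinomial(total_count=dispersion, logits=logits)
    print(nb.mean)                            # tensor([  5.,  20., 100.])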
@@ -428,8 +434,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +474,30 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns

-            concentrate = self.decoder_concentrate(zs)
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta = concentrate
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate = concentrate.exp()
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +521,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +573,31 @@ class DensityFlow(nn.Module):

             zs = zns

-            concentrate = self.decoder_concentrate(zs)
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta = concentrate
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate = concentrate.exp()
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                         logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +621,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +683,31 @@ class DensityFlow(nn.Module):
             else:
                 zs = zns

-            concentrate = self.decoder_concentrate(zs)
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta = concentrate
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate = concentrate.exp()
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                         logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +731,13 @@ class DensityFlow(nn.Module):
                 #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus

     def _get_codebook_identity(self):
@@ -744,6 +758,28 @@ class DensityFlow(nn.Module):
         cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
+
+    def _get_complete_embedding(self, xs, us):
+        basal,_ = self._get_basal_embedding(xs)
+        dzs = self._total_effects(basal, us)
+        return basal + dzs
+
+    def get_complete_embedding(self, xs, us, batch_size:int=1024):
+        xs = self.preprocess(xs)
+        xs = convert_to_tensor(xs, device=self.get_device())
+        us = convert_to_tensor(us, device=self.get_device())
+        dataset = CustomDataset2(xs, us)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        Z = []
+        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+            for X_batch, U_batch, _ in dataloader:
+                zns = self._get_basal_embedding(X_batch, U_batch)
+                Z.append(tensor_to_numpy(zns))
+                pbar.update(1)
+
+        Z = np.concatenate(Z)
+        return Z

     def _get_basal_embedding(self, xs):
         loc, scale = self.encoder_zn(xs)
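
The new get_complete_embedding is meant to return the basal embedding plus the summed cell-factor effects (basal + dzs in _get_complete_embedding). As published, though, the batched loop calls self._get_basal_embedding(X_batch, U_batch) with two arguments while that helper accepts only xs, so the composed helper appears to be the intended call. A hypothetical usage sketch (data and shapes are illustrative assumptions; df is a fitted DensityFlow):

    import numpy as np

    # Illustrative inputs (assumed shapes): counts and per-cell factor values
    xs = np.random.poisson(1.0, size=(500, 2000))    # cells x genes
    us = np.random.binomial(1, 0.5, size=(500, 3))   # cells x factors

    Z = df.get_complete_embedding(xs, us)            # basal embedding + factor shifts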
@@ -865,12 +901,12 @@ class DensityFlow(nn.Module):
             us_i = us[:,pert_idx].reshape(-1,1)

             # factor effect of xs
-            dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
+            dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)

             # perturbation effect
             ps = np.ones_like(us_i)
             if np.sum(np.abs(ps-us_i))>=1:
-                dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
+                dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
                 zs = zs + dzs0 + dzs
             else:
                 zs = zs + dzs0
@@ -884,10 +920,11 @@ class DensityFlow(nn.Module):
             library_sizes = library_sizes.reshape(-1,1)

         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)

-        return counts, zs
+        return counts, log_mu

-    def _cell_response(self, zs, perturb_idx, perturb):
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
@@ -904,7 +941,7 @@ class DensityFlow(nn.Module):

         return ms

-    def get_cell_response(self,
+    def get_cell_shift(self,
                           zs,
                           perturb_idx,
                           perturb_us,
@@ -922,46 +959,43 @@ class DensityFlow(nn.Module):
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, P_batch, _ in dataloader:
-                zns = self._cell_response(Z_batch, perturb_idx, P_batch)
+                zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)

         Z = np.concatenate(Z)
         return Z

-    def _get_expression_response(self, delta_zs):
-        return self.decoder_concentrate(delta_zs)
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)

-    def get_expression_response(self,
-                                delta_zs,
-                                batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor

         """
-        delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
-        dataset = CustomDataset(delta_zs)
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for delta_Z_batch, _ in dataloader:
-                r = self._get_expression_response(delta_Z_batch)
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)

         R = np.concatenate(R)
         return R

-    def _count(self, concentrate, library_size=None):
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-            #counts = self.sigmoid(concentrate)
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate = concentrate.exp()
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
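
After these renames, DensityFlow's decoding path reads: latent zs -> get_log_mu(zs) (the decoder's log mean) -> _count, which converts that mean into expected counts under a library size; get_counts still takes latents and runs both steps internally (next hunk). A hypothetical sketch:

    # Assumed: df is a fitted DensityFlow, Z a (cells x z_dim) latent array,
    # ls a (cells x 1) array of library sizes.
    log_mu = df.get_log_mu(Z)                      # per-cell log mean expression
    counts = df.get_counts(Z, library_sizes=ls)    # decodes Z and scales to ls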
@@ -983,8 +1017,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-                concentrate = self._get_expression_response(Z_batch)
-                counts = self._count(concentrate, L_batch)
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)

@@ -1157,8 +1191,54 @@ class DensityFlow(nn.Module):
         else:
             with open(file_path, 'rb') as pickle_file:
                 model = pickle.load(pickle_file)
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f" - Latent Dimension: {model.latent_dim}")
+        print(f" - Gene Dimension: {model.input_size}")
+        print(f" - Hidden Dimensions: {model.hidden_layers}")
+        print(f" - Device: {model.get_device()}")
+        print(f" - Parameters: {sum(p.numel() for p in model.parameters()):,}")

         return model
+
+    ''' def save(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+        return model'''


 EXAMPLE_RUN = (
@@ -1357,7 +1437,7 @@ def main():
     df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-        inverse_dispersion=args.inverse_dispersion,
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,
--- sure_tools-2.4.25/SURE/DensityFlow2.py
+++ sure_tools-2.4.38/SURE/DensityFlowLinear.py
@@ -54,7 +54,7 @@ def set_random_seed(seed):
     # Set seed for Pyro
     pyro.set_rng_seed(seed)

-class DensityFlow2(nn.Module):
+class DensityFlowLinear(nn.Module):
     def __init__(self,
                  input_size: int,
                  codebook_size: int = 200,
@@ -63,8 +63,8 @@ class DensityFlow2(nn.Module):
                  supervised_mode: bool = False,
                  z_dim: int = 10,
                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
-                 inverse_dispersion: float = 10.0,
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
                  hidden_layers: list = [500],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
@@ -81,7 +81,7 @@ class DensityFlow2(nn.Module):

         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.inverse_dispersion = inverse_dispersion
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
@@ -107,10 +107,11 @@ class DensityFlow2(nn.Module):

         self.codebook_weights = None

+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()

-        print(f"🧬 DensityFlow2 Initialized:")
+        print(f"🧬 DensityFlowLinear Initialized:")
         print(f" - Latent Dimension: {self.latent_dim}")
         print(f" - Gene Dimension: {self.input_size}")
         print(f" - Hidden Dimensions: {self.hidden_layers}")
@@ -208,17 +209,6 @@ class DensityFlow2(nn.Module):
             use_cuda=self.use_cuda,
         )

-        if self.loss_func == 'negbinomial':
-            self.encoder_inverse_dispersion = MLP(
-                [self.latent_dim] + hidden_sizes + [[self.input_size, self.input_size]],
-                activation=activate_fct,
-                output_activation=[Exp, Exp],
-                post_layer_fct=post_layer_fct,
-                post_act_fct=post_act_fct,
-                allow_broadcast=self.allow_broadcast,
-                use_cuda=self.use_cuda,
-            )
-
         if self.cell_factor_size>0:
             self.cell_factor_effect = nn.ModuleList()
             for i in np.arange(self.cell_factor_size):
@@ -269,7 +259,7 @@ class DensityFlow2(nn.Module):
             )
         )

-        self.decoder_concentrate = MLP(
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -352,15 +342,15 @@ class DensityFlow2(nn.Module):
         return xs

     def model1(self, xs):
-        pyro.module('DensityFlow2', self)
+        pyro.module('DensityFlowLinear', self)

         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -394,28 +384,31 @@ class DensityFlow2(nn.Module):
                 zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

             zs = zns
-            concentrate = self.decoder_concentrate(zs)
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta = concentrate
+                log_theta = log_mu
+            elif self.loss_func in ['negbinomial']:
+                mu = log_mu.exp()
             else:
-                rate = concentrate.exp()
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)

             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                         logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -432,17 +425,15 @@ class DensityFlow2(nn.Module):
         ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

     def model2(self, xs, us=None):
-        pyro.module('DensityFlow2', self)
+        pyro.module('DensityFlowLinear', self)

         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-            #total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-            #                         xs.new_ones(self.input_size), constraint=constraints.positive)
-            with pyro.plate("genes", self.input_size):
-                inverse_dispersion = pyro.sample("inverse_dispersion", dist.LogNormal(self.inverse_dispersion, 0.5).to_event(1))
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
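
model2 previously drew a per-gene inverse dispersion from a LogNormal prior inside pyro.plate('genes', ...); this hunk replaces that with the same learnable positive parameter the other model variants use. A standalone sketch of the pattern (input_size and xs are assumed values):

    import pyro
    import torch
    from torch.distributions import constraints

    input_size = 2000
    xs = torch.rand(16, input_size)

    # One positive dispersion value per gene, fit by the optimizer through
    # the ELBO instead of being sampled from a LogNormal prior per model call.
    dispersion = pyro.param("dispersion",
                            8.0 * xs.new_ones(input_size),
                            constraint=constraints.positive)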
@@ -482,28 +473,28 @@ class DensityFlow2(nn.Module):
            zs = zns'''

        zs = zns
-       concentrate = self.decoder_concentrate(zs)
+       log_mu = self.decoder_log_mu(zs)
        for i in np.arange(self.cell_factor_size):
            zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
-           concentrate += self.decoder_concentrate(zus)
+           log_mu += self.decoder_log_mu(zus)

        if self.loss_func in ['bernoulli']:
-           log_theta = concentrate
+           log_theta = log_mu
        elif self.loss_func in ['negbinomial']:
-           mu = concentrate.exp()
+           mu = log_mu.exp()
        else:
-           rate = concentrate.exp()
+           rate = log_mu.exp()
            theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
            if self.loss_func == 'poisson':
                rate = theta * torch.sum(xs, dim=1, keepdim=True)

        if self.loss_func == 'negbinomial':
-           logits = (mu.log()-inverse_dispersion.log()).clamp(min=-10, max=10)
+           logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
            if self.use_zeroinflate:
-               pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=inverse_dispersion,
+               pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                                     logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
            else:
-               pyro.sample('x', dist.NegativeBinomial(total_count=inverse_dispersion,
+               pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
                                                       logits=logits).to_event(1), obs=xs)
        elif self.loss_func == 'poisson':
            if self.use_zeroinflate:
@@ -511,8 +502,7 @@ class DensityFlow2(nn.Module):
            else:
                pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
        elif self.loss_func == 'multinomial':
-           #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-           pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+           pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
        elif self.loss_func == 'bernoulli':
            if self.use_zeroinflate:
                pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
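
This hunk is the "linear" in DensityFlowLinear: each cell factor's latent shift is decoded separately and the results are summed in log-mean space, log_mu = decoder(z) + sum_i decoder(dz_i), rather than decoding a single shifted latent. A toy sketch of that composition (nn.Linear stands in for the package's decoder_log_mu MLP; sizes are arbitrary):

    import torch
    import torch.nn as nn

    decoder = nn.Linear(10, 50)                       # stand-in for decoder_log_mu
    z = torch.randn(8, 10)                            # basal latent codes
    shifts = [torch.randn(8, 10) for _ in range(3)]   # per-factor latent shifts

    log_mu = decoder(z)
    for dz in shifts:
        log_mu = log_mu + decoder(dz)                 # additive effects in log-mean space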
@@ -528,21 +518,16 @@ class DensityFlow2(nn.Module):
         alpha = self.encoder_n(zns)
         ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

-        if self.loss_func == 'negbinomial':
-            id_loc,id_scale = self.encoder_inverse_dispersion(zns)
-            with pyro.plate("genes", self.input_size):
-                pyro.sample("inverse_dispersion", dist.LogNormal(id_loc, id_scale).to_event(1))
-
     def model3(self, xs, ys, embeds=None):
-        pyro.module('DensityFlow2', self)
+        pyro.module('DensityFlowLinear', self)

         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -593,28 +578,31 @@ class DensityFlow2(nn.Module):

        zs = zns

-       concentrate = self.decoder_concentrate(zs)
+       log_mu = self.decoder_log_mu(zs)
        if self.loss_func in ['bernoulli']:
-           log_theta = concentrate
+           log_theta = log_mu
+       elif self.loss_func in ['negbinomial']:
+           mu = log_mu.exp()
        else:
-           rate = concentrate.exp()
+           rate = log_mu.exp()
            theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
            if self.loss_func == 'poisson':
                rate = theta * torch.sum(xs, dim=1, keepdim=True)

        if self.loss_func == 'negbinomial':
+           logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
            if self.use_zeroinflate:
-               pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+               pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                    logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
            else:
-               pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+               pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
        elif self.loss_func == 'poisson':
            if self.use_zeroinflate:
                pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
            else:
                pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
        elif self.loss_func == 'multinomial':
-           #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-           pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+           pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
        elif self.loss_func == 'bernoulli':
            if self.use_zeroinflate:
                pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -631,15 +619,15 @@ class DensityFlow2(nn.Module):
         zns = embeds

     def model4(self, xs, us, ys, embeds=None):
-        pyro.module('DensityFlow2', self)
+        pyro.module('DensityFlowLinear', self)

         eps = torch.finfo(xs.dtype).eps
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -695,32 +683,34 @@ class DensityFlow2(nn.Module):
            zs = zns'''

        zs = zns
-       concentrate = self.decoder_concentrate(zs)
+       log_mu = self.decoder_log_mu(zs)
        for i in np.arange(self.cell_factor_size):
            zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
-           concentrate += self.decoder_concentrate(zus)
+           log_mu += self.decoder_log_mu(zus)

        if self.loss_func in ['bernoulli']:
-           log_theta = concentrate
+           log_theta = log_mu
+       elif self.loss_func in ['negbinomial']:
+           mu = log_mu.exp()
        else:
-           rate = concentrate.exp()
+           rate = log_mu.exp()
            theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
            if self.loss_func == 'poisson':
                rate = theta * torch.sum(xs, dim=1, keepdim=True)

        if self.loss_func == 'negbinomial':
+           logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
            if self.use_zeroinflate:
-               pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+               pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
            else:
-               pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+               pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
        elif self.loss_func == 'poisson':
            if self.use_zeroinflate:
                pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
            else:
                pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
        elif self.loss_func == 'multinomial':
-           #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-           pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+           pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
        elif self.loss_func == 'bernoulli':
            if self.use_zeroinflate:
                pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -887,6 +877,8 @@ class DensityFlow2(nn.Module):

         # basal embedding
         zs = self.get_basal_embedding(xs)
+        log_mu = self.get_log_mu(zs)
+
         for pert in perturbs_predict:
             pert_idx = int(np.where(perturbs_reference==pert)[0])
             us_i = us[:,pert_idx].reshape(-1,1)
@@ -898,10 +890,12 @@ class DensityFlow2(nn.Module):
             ps = np.ones_like(us_i)
             if np.sum(np.abs(ps-us_i))>=1:
                 dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
-                zs = zs + dzs0 + dzs
+                delta = dzs0 + dzs
             else:
-                zs = zs + dzs0
-
+                delta = dzs0
+
+            log_mu = log_mu + self.get_log_mu(delta)
+
         if library_sizes is None:
             library_sizes = np.sum(xs, axis=1, keepdims=True)
         elif type(library_sizes) == list:
@@ -910,9 +904,9 @@ class DensityFlow2(nn.Module):
         elif len(library_sizes.shape)==1:
             library_sizes = library_sizes.reshape(-1,1)

-        counts = self.get_counts(zs, library_sizes=library_sizes)
+        counts = self.get_counts(log_mu, library_sizes=library_sizes)

-        return counts, zs
+        return counts, log_mu

     def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
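
Perturbation prediction follows the same logic: the response is now composed in log-mean space (the basal log_mu gets get_log_mu(delta) added per perturbation), and get_counts here consumes log_mu directly, unlike DensityFlow's get_counts, which takes latents. A hypothetical sketch mirroring the hunk above:

    # Assumed from the surrounding hunk: df is a fitted DensityFlowLinear,
    # dzs0/dzs are latent shifts from get_cell_shift, library_sizes as above.
    zs = df.get_basal_embedding(xs)                  # basal latent codes
    log_mu = df.get_log_mu(zs)                       # basal log means
    delta = dzs0 + dzs                               # summed latent shifts
    log_mu = log_mu + df.get_log_mu(delta)           # additive response
    counts = df.get_counts(log_mu, library_sizes)    # expected counts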
@@ -956,48 +950,46 @@ class DensityFlow2(nn.Module):
         Z = np.concatenate(Z)
         return Z

-    def _get_theta(self, delta_zs):
-        return self.decoder_concentrate(delta_zs)
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)

-    def get_theta(self,
-                  delta_zs,
-                  batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor

         """
-        delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
-        dataset = CustomDataset(delta_zs)
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for delta_Z_batch, _ in dataloader:
-                r = self._get_theta(delta_Z_batch)
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)

         R = np.concatenate(R)
         return R

-    def _count(self, concentrate, library_size=None):
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
             #counts = self.sigmoid(concentrate)
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate = concentrate.exp()
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts

-    def get_counts(self, concentrate,
+    def get_counts(self, log_mu,
                    library_sizes,
                    batch_size: int = 1024):

-        concentrate = convert_to_tensor(concentrate, device=self.get_device())
+        log_mu = convert_to_tensor(log_mu, device=self.get_device())

         if type(library_sizes) == list:
             library_sizes = np.array(library_sizes).reshape(-1,1)
@@ -1005,13 +997,13 @@ class DensityFlow2(nn.Module):
             library_sizes = library_sizes.reshape(-1,1)
         ls = convert_to_tensor(library_sizes, device=self.get_device())

-        dataset = CustomDataset2(concentrate,ls)
+        dataset = CustomDataset2(log_mu,ls)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for C_batch, L_batch, _ in dataloader:
-                counts = self._count(C_batch, L_batch)
+            for Mu_batch, L_batch, _ in dataloader:
+                counts = self._count(Mu_batch, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)

@@ -1045,7 +1037,7 @@ class DensityFlow2(nn.Module):
              threshold: int = 0,
              use_jax: bool = True):
        """
-       Train the DensityFlow2 model.
+       Train the DensityFlowLinear model.

        Parameters
        ----------
@@ -1071,7 +1063,7 @@ class DensityFlow2(nn.Module):
            Parameter for optimization.
        use_jax
            If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
-           the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow2 in the shell command.
+           the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlowLinear in the shell command.
        """
        xs = self.preprocess(xs, threshold=threshold)
        xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1189,12 +1181,12 @@ class DensityFlow2(nn.Module):


 EXAMPLE_RUN = (
-    "example run: DensityFlow2 --help"
+    "example run: DensityFlowLinear --help"
 )

 def parse_args():
     parser = argparse.ArgumentParser(
-        description="DensityFlow2\n{}".format(EXAMPLE_RUN))
+        description="DensityFlowLinear\n{}".format(EXAMPLE_RUN))

     parser.add_argument(
         "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1381,10 +1373,10 @@ def main():
     cell_factor_size = 0 if us is None else us.shape[1]

     ###########################################
-    df = DensityFlow2(
+    df = DensityFlowLinear(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-        inverse_dispersion=args.inverse_dispersion,
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,
@@ -1412,9 +1404,9 @@ def main():

     if args.save_model is not None:
         if args.save_model.endswith('gz'):
-            DensityFlow2.save_model(df, args.save_model, compression=True)
+            DensityFlowLinear.save_model(df, args.save_model, compression=True)
         else:
-            DensityFlow2.save_model(df, args.save_model)
+            DensityFlowLinear.save_model(df, args.save_model)
--- sure_tools-2.4.25/SURE/__init__.py
+++ sure_tools-2.4.38/SURE/__init__.py
@@ -1,6 +1,6 @@
 from .SURE import SURE
 from .DensityFlow import DensityFlow
-from .DensityFlow2 import DensityFlow2
+from .DensityFlowLinear import DensityFlowLinear
 from .PerturbE import PerturbE
 from .TranscriptomeDecoder import TranscriptomeDecoder
 from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
@@ -12,7 +12,7 @@ from . import utils
 from . import codebook
 from . import SURE
 from . import DensityFlow
-from . import DensityFlow2
+from . import DensityFlowLinear
 from . import atac
 from . import flow
 from . import perturb
@@ -23,6 +23,6 @@ from . import EfficientTranscriptomeDecoder
 from . import VirtualCellDecoder
 from . import PerturbationAwareDecoder

-__all__ = ['SURE', 'DensityFlow', 'DensityFlow2', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
+__all__ = ['SURE', 'DensityFlow', 'DensityFlowLinear', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
            'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
            'flow', 'perturb', 'atac', 'utils', 'codebook']
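
With the __init__.py changes above, the new name is what the package exports. A quick sanity check (assumes 2.4.38 is installed):

    import SURE
    from SURE import DensityFlowLinear

    print('DensityFlowLinear' in SURE.__all__)   # True in 2.4.38
    print('DensityFlow2' in SURE.__all__)        # False after the rename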
--- sure_tools-2.4.25/SURE_tools.egg-info/PKG-INFO
+++ sure_tools-2.4.38/SURE_tools.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.4.25
+Version: 2.4.38
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng

--- sure_tools-2.4.25/SURE_tools.egg-info/SOURCES.txt
+++ sure_tools-2.4.38/SURE_tools.egg-info/SOURCES.txt
@@ -2,7 +2,7 @@ LICENSE
 README.md
 setup.py
 SURE/DensityFlow.py
-SURE/DensityFlow2.py
+SURE/DensityFlowLinear.py
 SURE/EfficientTranscriptomeDecoder.py
 SURE/PerturbE.py
 SURE/PerturbationAwareDecoder.py

--- sure_tools-2.4.25/setup.py
+++ sure_tools-2.4.38/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:

 setup(
     name='SURE-tools',
-    version='2.4.25',
+    version='2.4.38',
     description='Succinct Representation of Single Cells',
     long_description=long_description,
     long_description_content_type="text/markdown",