SURE-tools 2.4.25.tar.gz → 2.4.42.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (37)
  1. {sure_tools-2.4.25 → sure_tools-2.4.42}/PKG-INFO +1 -1
  2. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/DensityFlow.py +152 -70
  3. sure_tools-2.4.25/SURE/DensityFlow2.py → sure_tools-2.4.42/SURE/DensityFlowLinear.py +91 -99
  4. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/__init__.py +3 -3
  5. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/SOURCES.txt +1 -1
  7. {sure_tools-2.4.25 → sure_tools-2.4.42}/setup.py +1 -1
  8. {sure_tools-2.4.25 → sure_tools-2.4.42}/LICENSE +0 -0
  9. {sure_tools-2.4.25 → sure_tools-2.4.42}/README.md +0 -0
  10. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/EfficientTranscriptomeDecoder.py +0 -0
  11. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/PerturbE.py +0 -0
  12. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/PerturbationAwareDecoder.py +0 -0
  13. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/SURE.py +0 -0
  14. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/SimpleTranscriptomeDecoder.py +0 -0
  15. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/TranscriptomeDecoder.py +0 -0
  16. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/VirtualCellDecoder.py +0 -0
  17. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/assembly/__init__.py +0 -0
  18. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/assembly/assembly.py +0 -0
  19. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/assembly/atlas.py +0 -0
  20. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/atac/__init__.py +0 -0
  21. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/atac/utils.py +0 -0
  22. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/codebook/__init__.py +0 -0
  23. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/codebook/codebook.py +0 -0
  24. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/flow/__init__.py +0 -0
  25. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/flow/flow_stats.py +0 -0
  26. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/flow/plot_quiver.py +0 -0
  27. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/perturb/__init__.py +0 -0
  28. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/perturb/perturb.py +0 -0
  29. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/utils/__init__.py +0 -0
  30. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/utils/custom_mlp.py +0 -0
  31. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/utils/queue.py +0 -0
  32. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/utils/utils.py +0 -0
  33. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/dependency_links.txt +0 -0
  34. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/entry_points.txt +0 -0
  35. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/requires.txt +0 -0
  36. {sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/top_level.txt +0 -0
  37. {sure_tools-2.4.25 → sure_tools-2.4.42}/setup.cfg +0 -0
{sure_tools-2.4.25 → sure_tools-2.4.42}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.4.25
+ Version: 2.4.42
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
{sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/DensityFlow.py
@@ -57,16 +57,16 @@ def set_random_seed(seed):
  class DensityFlow(nn.Module):
      def __init__(self,
                   input_size: int,
-                  codebook_size: int = 200,
+                  codebook_size: int = 30,
                   cell_factor_size: int = 0,
                   turn_off_cell_specific: bool = False,
                   supervised_mode: bool = False,
-                  z_dim: int = 10,
-                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                  loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
-                  inverse_dispersion: float = 10.0,
+                  z_dim: int = 50,
+                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                  loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                  dispersion: float = 8.0,
                   use_zeroinflate: bool = False,
-                  hidden_layers: list = [500],
+                  hidden_layers: list = [1024],
                   hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                   nn_dropout: float = 0.1,
                   post_layer_fct: list = ['layernorm'],
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):

          self.input_size = input_size
          self.cell_factor_size = cell_factor_size
-         self.inverse_dispersion = inverse_dispersion
+         self.dispersion = dispersion
          self.latent_dim = z_dim
          self.hidden_layers = hidden_layers
          self.decoder_hidden_layers = hidden_layers[::-1]
+         self.config_enum = config_enum
          self.allow_broadcast = config_enum == 'parallel'
          self.use_cuda = use_cuda
          self.loss_func = loss_func
@@ -107,10 +108,12 @@ class DensityFlow(nn.Module):

          self.codebook_weights = None

+         self.seed = seed
          set_random_seed(seed)
          self.setup_networks()

          print(f"🧬 DensityFlow Initialized:")
+         print(f" - Codebook size: {self.code_size}")
          print(f" - Latent Dimension: {self.latent_dim}")
          print(f" - Gene Dimension: {self.input_size}")
          print(f" - Hidden Dimensions: {self.hidden_layers}")
@@ -258,7 +261,7 @@ class DensityFlow(nn.Module):
              )
          )

-         self.decoder_concentrate = MLP(
+         self.decoder_log_mu = MLP(
              [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
              activation=activate_fct,
              output_activation=None,
@@ -348,8 +351,8 @@ class DensityFlow(nn.Module):
          self.options = dict(dtype=xs.dtype, device=xs.device)

          if self.loss_func=='negbinomial':
-             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+             dispersion = pyro.param("dispersion", self.dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

          if self.use_zeroinflate:
              gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -383,28 +386,32 @@ class DensityFlow(nn.Module):
              zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

          zs = zns
-         concentrate = self.decoder_concentrate(zs)
+         log_mu = self.decoder_log_mu(zs)
          if self.loss_func in ['bernoulli']:
-             log_theta = concentrate
+             log_theta = log_mu
+         elif self.loss_func == 'negbinomial':
+             mu = log_mu.exp()
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              if self.loss_func == 'poisson':
                  rate = theta * torch.sum(xs, dim=1, keepdim=True)

          if self.loss_func == 'negbinomial':
+             logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
              if self.use_zeroinflate:
-                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                      logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
              else:
-                 pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                 pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                                                        logits=logits).to_event(1), obs=xs)
          elif self.loss_func == 'poisson':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
              else:
                  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
          elif self.loss_func == 'multinomial':
-             #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-             pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
          elif self.loss_func == 'bernoulli':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
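
Note on the recurring likelihood change above: the negative-binomial branch moves from probs=theta (a Dirichlet-multinomial normalization of the decoder output) to a mean/dispersion parameterization with logits = log(mu) - log(dispersion). Under the PyTorch/Pyro convention, NegativeBinomial(total_count=r, logits=l) has mean r * exp(l), so this choice makes the decoded mu the likelihood mean, with variance mu + mu**2/dispersion. A minimal sketch checking that identity (tensor values are illustrative):

import torch
from torch.distributions import NegativeBinomial

mu = torch.tensor([5.0, 50.0, 500.0])       # decoded mean, i.e. exp(log_mu)
dispersion = torch.tensor([8.0, 8.0, 8.0])  # per-gene dispersion parameter

# same construction as the diff, including the numerical-stability clamp
logits = (mu.log() - dispersion.log()).clamp(min=-10, max=10)
nb = NegativeBinomial(total_count=dispersion, logits=logits)

print(nb.mean)  # tensor([  5.,  50., 500.]), recovering mu
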
@@ -428,8 +435,8 @@ class DensityFlow(nn.Module):
          self.options = dict(dtype=xs.dtype, device=xs.device)

          if self.loss_func=='negbinomial':
-             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+             dispersion = pyro.param("dispersion", self.dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

          if self.use_zeroinflate:
              gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +475,30 @@ class DensityFlow(nn.Module):
          else:
              zs = zns

-         concentrate = self.decoder_concentrate(zs)
+         log_mu = self.decoder_log_mu(zs)
          if self.loss_func in ['bernoulli']:
-             log_theta = concentrate
+             log_theta = log_mu
+         elif self.loss_func == 'negbinomial':
+             mu = log_mu.exp()
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              if self.loss_func == 'poisson':
                  rate = theta * torch.sum(xs, dim=1, keepdim=True)

          if self.loss_func == 'negbinomial':
+             logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
              if self.use_zeroinflate:
-                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
              else:
-                 pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                 pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
          elif self.loss_func == 'poisson':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
              else:
                  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
          elif self.loss_func == 'multinomial':
-             #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-             pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
          elif self.loss_func == 'bernoulli':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +522,8 @@ class DensityFlow(nn.Module):
          self.options = dict(dtype=xs.dtype, device=xs.device)

          if self.loss_func=='negbinomial':
-             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+             dispersion = pyro.param("dispersion", self.dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

          if self.use_zeroinflate:
              gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +574,31 @@ class DensityFlow(nn.Module):

          zs = zns

-         concentrate = self.decoder_concentrate(zs)
+         log_mu = self.decoder_log_mu(zs)
          if self.loss_func in ['bernoulli']:
-             log_theta = concentrate
+             log_theta = log_mu
+         elif self.loss_func in ['negbinomial']:
+             mu = log_mu.exp()
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              if self.loss_func == 'poisson':
                  rate = theta * torch.sum(xs, dim=1, keepdim=True)

          if self.loss_func == 'negbinomial':
+             logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
              if self.use_zeroinflate:
-                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                      logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
              else:
-                 pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                 pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
          elif self.loss_func == 'poisson':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
              else:
                  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
          elif self.loss_func == 'multinomial':
-             #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-             pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
          elif self.loss_func == 'bernoulli':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +622,8 @@ class DensityFlow(nn.Module):
          self.options = dict(dtype=xs.dtype, device=xs.device)

          if self.loss_func=='negbinomial':
-             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+             dispersion = pyro.param("dispersion", self.dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

          if self.use_zeroinflate:
              gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +684,31 @@ class DensityFlow(nn.Module):
          else:
              zs = zns

-         concentrate = self.decoder_concentrate(zs)
+         log_mu = self.decoder_log_mu(zs)
          if self.loss_func in ['bernoulli']:
-             log_theta = concentrate
+             log_theta = log_mu
+         elif self.loss_func in ['negbinomial']:
+             mu = log_mu.exp()
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              if self.loss_func == 'poisson':
                  rate = theta * torch.sum(xs, dim=1, keepdim=True)

          if self.loss_func == 'negbinomial':
+             logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
              if self.use_zeroinflate:
-                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                      logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
              else:
-                 pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                 pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
          elif self.loss_func == 'poisson':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
              else:
                  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
          elif self.loss_func == 'multinomial':
-             #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-             pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
          elif self.loss_func == 'bernoulli':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +732,13 @@ class DensityFlow(nn.Module):
                  #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                  #else:
                  #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                 zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
+                 zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
              else:
                  #if self.turn_off_cell_specific:
                  #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                  #else:
                  #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                 zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
+                 zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
          return zus

      def _get_codebook_identity(self):
@@ -744,6 +759,28 @@ class DensityFlow(nn.Module):
          cb = self._get_codebook()
          cb = tensor_to_numpy(cb)
          return cb
+
+     def _get_complete_embedding(self, xs, us):
+         basal,_ = self._get_basal_embedding(xs)
+         dzs = self._total_effects(basal, us)
+         return basal + dzs
+
+     def get_complete_embedding(self, xs, us, batch_size:int=1024):
+         xs = self.preprocess(xs)
+         xs = convert_to_tensor(xs, device=self.get_device())
+         us = convert_to_tensor(us, device=self.get_device())
+         dataset = CustomDataset2(xs, us)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+         Z = []
+         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+             for X_batch, U_batch, _ in dataloader:
+                 zns = self._get_complete_embedding(X_batch, U_batch)
+                 Z.append(tensor_to_numpy(zns))
+                 pbar.update(1)
+
+         Z = np.concatenate(Z)
+         return Z

      def _get_basal_embedding(self, xs):
          loc, scale = self.encoder_zn(xs)
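
The new get_complete_embedding/_get_complete_embedding pair added above returns the basal embedding plus the accumulated cell-factor shifts (_total_effects), batched through a DataLoader. A hypothetical call, assuming a trained model, a cells-by-genes count matrix xs, and a matching cell-factor matrix us:

import numpy as np
from SURE import DensityFlow

# illustrative shapes: 1,000 cells, 2,000 genes, 3 cell factors
xs = np.random.poisson(1.0, size=(1000, 2000)).astype(np.float32)
us = np.random.binomial(1, 0.5, size=(1000, 3)).astype(np.float32)

model = DensityFlow(input_size=2000, cell_factor_size=3)
# (training step omitted; untrained weights give arbitrary embeddings)

Z = model.get_complete_embedding(xs, us, batch_size=1024)
print(Z.shape)  # (1000, z_dim); z_dim defaults to 50 in this release
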
@@ -865,12 +902,12 @@ class DensityFlow(nn.Module):
              us_i = us[:,pert_idx].reshape(-1,1)

              # factor effect of xs
-             dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
+             dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)

              # perturbation effect
              ps = np.ones_like(us_i)
              if np.sum(np.abs(ps-us_i))>=1:
-                 dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
+                 dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
                  zs = zs + dzs0 + dzs
              else:
                  zs = zs + dzs0
@@ -884,10 +921,11 @@ class DensityFlow(nn.Module):
              library_sizes = library_sizes.reshape(-1,1)

          counts = self.get_counts(zs, library_sizes=library_sizes)
+         log_mu = self.get_log_mu(zs)

-         return counts, zs
+         return counts, log_mu

-     def _cell_response(self, zs, perturb_idx, perturb):
+     def _cell_shift(self, zs, perturb_idx, perturb):
          #zns,_ = self.encoder_zn(xs)
          #zns,_ = self._get_basal_embedding(xs)
          zns = zs
@@ -904,7 +942,7 @@ class DensityFlow(nn.Module):

          return ms

-     def get_cell_response(self,
+     def get_cell_shift(self,
                            zs,
                            perturb_idx,
                            perturb_us,
@@ -922,46 +960,43 @@ class DensityFlow(nn.Module):
          Z = []
          with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
              for Z_batch, P_batch, _ in dataloader:
-                 zns = self._cell_response(Z_batch, perturb_idx, P_batch)
+                 zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                  Z.append(tensor_to_numpy(zns))
                  pbar.update(1)

          Z = np.concatenate(Z)
          return Z

-     def _get_expression_response(self, delta_zs):
-         return self.decoder_concentrate(delta_zs)
+     def _log_mu(self, zs):
+         return self.decoder_log_mu(zs)

-     def get_expression_response(self,
-                                 delta_zs,
-                                 batch_size: int = 1024):
+     def get_log_mu(self, zs, batch_size: int = 1024):
          """
          Return cells' changes in the feature space induced by specific perturbation of a factor

          """
-         delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
-         dataset = CustomDataset(delta_zs)
+         zs = convert_to_tensor(zs, device=self.get_device())
+         dataset = CustomDataset(zs)
          dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

          R = []
          with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-             for delta_Z_batch, _ in dataloader:
-                 r = self._get_expression_response(delta_Z_batch)
+             for Z_batch, _ in dataloader:
+                 r = self._log_mu(Z_batch)
                  R.append(tensor_to_numpy(r))
                  pbar.update(1)

          R = np.concatenate(R)
          return R

-     def _count(self, concentrate, library_size=None):
+     def _count(self, log_mu, library_size=None):
          if self.loss_func == 'bernoulli':
-             #counts = self.sigmoid(concentrate)
-             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+             counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
          elif self.loss_func == 'multinomial':
-             theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+             theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
              counts = theta * library_size
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              counts = theta * library_size
          return counts
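
For the count losses, _count turns log_mu into expected counts by normalizing through a Dirichlet-multinomial mean and scaling by library size. Since DirichletMultinomial(total_count=1, concentration=rate).mean equals rate / rate.sum(-1, keepdim=True), the branch is a simplex normalization of exp(log_mu); a small sketch of that equivalence (values are illustrative):

import torch
import pyro.distributions as dist

rate = torch.tensor([[1.0, 3.0, 6.0]])   # exp(log_mu) for one cell
library_size = torch.tensor([[1000.0]])

theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
assert torch.allclose(theta, rate / rate.sum(-1, keepdim=True))

counts = theta * library_size
print(counts)  # tensor([[100., 300., 600.]])
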
@@ -983,8 +1018,8 @@ class DensityFlow(nn.Module):
          E = []
          with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
              for Z_batch, L_batch, _ in dataloader:
-                 concentrate = self._get_expression_response(Z_batch)
-                 counts = self._count(concentrate, L_batch)
+                 log_mu = self._log_mu(Z_batch)
+                 counts = self._count(log_mu, L_batch)
                  E.append(tensor_to_numpy(counts))
                  pbar.update(1)

@@ -996,7 +1031,7 @@ class DensityFlow(nn.Module):
              ad = sc.AnnData(xs)
              binarize(ad, threshold=threshold)
              xs = ad.X.copy()
-         else:
+         elif self.loss_func == 'poisson':
              xs = np.round(xs)

          if sparse.issparse(xs):
@@ -1157,8 +1192,55 @@ class DensityFlow(nn.Module):
          else:
              with open(file_path, 'rb') as pickle_file:
                  model = pickle.load(pickle_file)
+
+         print(f"🧬 DensityFlow Initialized:")
+         print(f" - Codebook size: {model.code_size}")
+         print(f" - Latent Dimension: {model.latent_dim}")
+         print(f" - Gene Dimension: {model.input_size}")
+         print(f" - Hidden Dimensions: {model.hidden_layers}")
+         print(f" - Device: {model.get_device()}")
+         print(f" - Parameters: {sum(p.numel() for p in model.parameters()):,}")

          return model
+
+     ''' def save(self, path):
+         """Save model checkpoint"""
+         torch.save({
+             'model_state_dict': self.state_dict(),
+             'model_config': {
+                 'input_size': self.input_size,
+                 'codebook_size': self.code_size,
+                 'cell_factor_size': self.cell_factor_size,
+                 'turn_off_cell_specific': self.turn_off_cell_specific,
+                 'supervised_mode': self.supervised_mode,
+                 'z_dim': self.latent_dim,
+                 'z_dist': self.latent_dist,
+                 'loss_func': self.loss_func,
+                 'dispersion': self.dispersion,
+                 'use_zeroinflate': self.use_zeroinflate,
+                 'hidden_layers': self.hidden_layers,
+                 'hidden_layer_activation': self.hidden_layer_activation,
+                 'nn_dropout': self.nn_dropout,
+                 'post_layer_fct': self.post_layer_fct,
+                 'post_act_fct': self.post_act_fct,
+                 'config_enum': self.config_enum,
+                 'use_cuda': self.use_cuda,
+                 'seed': self.seed,
+                 'zero_bias': self.use_bias,
+                 'dtype': self.dtype,
+             }
+         }, path)
+
+     @classmethod
+     def load_model(cls, model_path: str):
+         """Load pre-trained model"""
+         checkpoint = torch.load(model_path)
+         model = DensityFlow(**checkpoint.get('model_config'))
+
+         checkpoint = torch.load(model_path, map_location=model.get_device())
+         model.load_state_dict(checkpoint['model_state_dict'])
+
+         return model'''


  EXAMPLE_RUN = (
@@ -1357,7 +1439,7 @@ def main():
      df = DensityFlow(
          input_size=input_size,
          cell_factor_size=cell_factor_size,
-         inverse_dispersion=args.inverse_dispersion,
+         dispersion=args.dispersion,
          z_dim=args.z_dim,
          hidden_layers=args.hidden_layers,
          hidden_layer_activation=args.hidden_layer_activation,
sure_tools-2.4.25/SURE/DensityFlow2.py → sure_tools-2.4.42/SURE/DensityFlowLinear.py
@@ -54,7 +54,7 @@ def set_random_seed(seed):
      # Set seed for Pyro
      pyro.set_rng_seed(seed)

- class DensityFlow2(nn.Module):
+ class DensityFlowLinear(nn.Module):
      def __init__(self,
                   input_size: int,
                   codebook_size: int = 200,
@@ -63,8 +63,8 @@ class DensityFlow2(nn.Module):
                   supervised_mode: bool = False,
                   z_dim: int = 10,
                   z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                  loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
-                  inverse_dispersion: float = 10.0,
+                  loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                  dispersion: float = 8.0,
                   use_zeroinflate: bool = False,
                   hidden_layers: list = [500],
                   hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
@@ -81,7 +81,7 @@ class DensityFlow2(nn.Module):

          self.input_size = input_size
          self.cell_factor_size = cell_factor_size
-         self.inverse_dispersion = inverse_dispersion
+         self.dispersion = dispersion
          self.latent_dim = z_dim
          self.hidden_layers = hidden_layers
          self.decoder_hidden_layers = hidden_layers[::-1]
@@ -107,10 +107,11 @@ class DensityFlow2(nn.Module):

          self.codebook_weights = None

+         self.seed = seed
          set_random_seed(seed)
          self.setup_networks()

-         print(f"🧬 DensityFlow2 Initialized:")
+         print(f"🧬 DensityFlowLinear Initialized:")
          print(f" - Latent Dimension: {self.latent_dim}")
          print(f" - Gene Dimension: {self.input_size}")
          print(f" - Hidden Dimensions: {self.hidden_layers}")
@@ -208,17 +209,6 @@ class DensityFlow2(nn.Module):
              use_cuda=self.use_cuda,
          )

-         if self.loss_func == 'negbinomial':
-             self.encoder_inverse_dispersion = MLP(
-                 [self.latent_dim] + hidden_sizes + [[self.input_size, self.input_size]],
-                 activation=activate_fct,
-                 output_activation=[Exp, Exp],
-                 post_layer_fct=post_layer_fct,
-                 post_act_fct=post_act_fct,
-                 allow_broadcast=self.allow_broadcast,
-                 use_cuda=self.use_cuda,
-             )
-
          if self.cell_factor_size>0:
              self.cell_factor_effect = nn.ModuleList()
              for i in np.arange(self.cell_factor_size):
@@ -269,7 +259,7 @@ class DensityFlow2(nn.Module):
              )
          )

-         self.decoder_concentrate = MLP(
+         self.decoder_log_mu = MLP(
              [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
              activation=activate_fct,
              output_activation=None,
@@ -352,15 +342,15 @@ class DensityFlow2(nn.Module):
          return xs

      def model1(self, xs):
-         pyro.module('DensityFlow2', self)
+         pyro.module('DensityFlowLinear', self)

          eps = torch.finfo(xs.dtype).eps
          batch_size = xs.size(0)
          self.options = dict(dtype=xs.dtype, device=xs.device)

          if self.loss_func=='negbinomial':
-             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+             dispersion = pyro.param("dispersion", self.dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

          if self.use_zeroinflate:
              gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -394,28 +384,31 @@ class DensityFlow2(nn.Module):
              zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

          zs = zns
-         concentrate = self.decoder_concentrate(zs)
+         log_mu = self.decoder_log_mu(zs)
          if self.loss_func in ['bernoulli']:
-             log_theta = concentrate
+             log_theta = log_mu
+         elif self.loss_func in ['negbinomial']:
+             mu = log_mu.exp()
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              if self.loss_func == 'poisson':
                  rate = theta * torch.sum(xs, dim=1, keepdim=True)

          if self.loss_func == 'negbinomial':
+             logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
              if self.use_zeroinflate:
-                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                      logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
              else:
-                 pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                 pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
          elif self.loss_func == 'poisson':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
              else:
                  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
          elif self.loss_func == 'multinomial':
-             #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-             pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
          elif self.loss_func == 'bernoulli':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -432,17 +425,15 @@ class DensityFlow2(nn.Module):
              ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

      def model2(self, xs, us=None):
-         pyro.module('DensityFlow2', self)
+         pyro.module('DensityFlowLinear', self)

          eps = torch.finfo(xs.dtype).eps
          batch_size = xs.size(0)
          self.options = dict(dtype=xs.dtype, device=xs.device)

          if self.loss_func=='negbinomial':
-             #total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-             #                         xs.new_ones(self.input_size), constraint=constraints.positive)
-             with pyro.plate("genes", self.input_size):
-                 inverse_dispersion = pyro.sample("inverse_dispersion", dist.LogNormal(self.inverse_dispersion, 0.5).to_event(1))
+             dispersion = pyro.param("dispersion", self.dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

          if self.use_zeroinflate:
              gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
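
Here DensityFlowLinear's model2 drops the LogNormal prior over a latent per-gene inverse dispersion (the matching encoder network and guide-side sampling are removed in the surrounding hunks) in favor of the same point-estimated pyro.param used by the other model variants. A minimal sketch of that pattern, assuming an illustrative gene count and the new 8.0 default:

import torch
import pyro
from torch.distributions import constraints

input_size = 2000  # illustrative

def model(xs):
    # one positive, learnable dispersion per gene, initialized at 8.0;
    # pyro.param registers it in the param store so SVI optimizes it directly
    dispersion = pyro.param("dispersion",
                            8.0 * xs.new_ones(input_size),
                            constraint=constraints.positive)
    return dispersion

print(model(torch.rand(16, input_size)).shape)  # torch.Size([2000])
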
@@ -482,28 +473,28 @@ class DensityFlow2(nn.Module):
              zs = zns'''

          zs = zns
-         concentrate = self.decoder_concentrate(zs)
+         log_mu = self.decoder_log_mu(zs)
          for i in np.arange(self.cell_factor_size):
              zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
-             concentrate += self.decoder_concentrate(zus)
+             log_mu += self.decoder_log_mu(zus)

          if self.loss_func in ['bernoulli']:
-             log_theta = concentrate
+             log_theta = log_mu
          elif self.loss_func in ['negbinomial']:
-             mu = concentrate.exp()
+             mu = log_mu.exp()
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              if self.loss_func == 'poisson':
                  rate = theta * torch.sum(xs, dim=1, keepdim=True)

          if self.loss_func == 'negbinomial':
-             logits = (mu.log()-inverse_dispersion.log()).clamp(min=-10, max=10)
+             logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
              if self.use_zeroinflate:
-                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=inverse_dispersion,
+                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
                                                                                       logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
              else:
-                 pyro.sample('x', dist.NegativeBinomial(total_count=inverse_dispersion,
+                 pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
                                                         logits=logits).to_event(1), obs=xs)
          elif self.loss_func == 'poisson':
              if self.use_zeroinflate:
@@ -511,8 +502,7 @@ class DensityFlow2(nn.Module):
              else:
                  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
          elif self.loss_func == 'multinomial':
-             #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-             pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
          elif self.loss_func == 'bernoulli':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -528,21 +518,16 @@ class DensityFlow2(nn.Module):
              alpha = self.encoder_n(zns)
              ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

-         if self.loss_func == 'negbinomial':
-             id_loc,id_scale = self.encoder_inverse_dispersion(zns)
-             with pyro.plate("genes", self.input_size):
-                 pyro.sample("inverse_dispersion", dist.LogNormal(id_loc, id_scale).to_event(1))
-
      def model3(self, xs, ys, embeds=None):
-         pyro.module('DensityFlow2', self)
+         pyro.module('DensityFlowLinear', self)

          eps = torch.finfo(xs.dtype).eps
          batch_size = xs.size(0)
          self.options = dict(dtype=xs.dtype, device=xs.device)

          if self.loss_func=='negbinomial':
-             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+             dispersion = pyro.param("dispersion", self.dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

          if self.use_zeroinflate:
              gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -593,28 +578,31 @@ class DensityFlow2(nn.Module):

          zs = zns

-         concentrate = self.decoder_concentrate(zs)
+         log_mu = self.decoder_log_mu(zs)
          if self.loss_func in ['bernoulli']:
-             log_theta = concentrate
+             log_theta = log_mu
+         elif self.loss_func in ['negbinomial']:
+             mu = log_mu.exp()
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              if self.loss_func == 'poisson':
                  rate = theta * torch.sum(xs, dim=1, keepdim=True)

          if self.loss_func == 'negbinomial':
+             logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
              if self.use_zeroinflate:
-                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                      logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
              else:
-                 pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                 pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
          elif self.loss_func == 'poisson':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
              else:
                  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
          elif self.loss_func == 'multinomial':
-             #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-             pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
          elif self.loss_func == 'bernoulli':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -631,15 +619,15 @@ class DensityFlow2(nn.Module):
              zns = embeds

      def model4(self, xs, us, ys, embeds=None):
-         pyro.module('DensityFlow2', self)
+         pyro.module('DensityFlowLinear', self)

          eps = torch.finfo(xs.dtype).eps
          batch_size = xs.size(0)
          self.options = dict(dtype=xs.dtype, device=xs.device)

          if self.loss_func=='negbinomial':
-             total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                      xs.new_ones(self.input_size), constraint=constraints.positive)
+             dispersion = pyro.param("dispersion", self.dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

          if self.use_zeroinflate:
              gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -695,32 +683,34 @@ class DensityFlow2(nn.Module):
              zs = zns'''

          zs = zns
-         concentrate = self.decoder_concentrate(zs)
+         log_mu = self.decoder_log_mu(zs)
          for i in np.arange(self.cell_factor_size):
              zus = self._cell_shift(zs, i, us[:,i].reshape(-1,1))
-             concentrate += self.decoder_concentrate(zus)
+             log_mu += self.decoder_log_mu(zus)

          if self.loss_func in ['bernoulli']:
-             log_theta = concentrate
+             log_theta = log_mu
+         elif self.loss_func in ['negbinomial']:
+             mu = log_mu.exp()
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              if self.loss_func == 'poisson':
                  rate = theta * torch.sum(xs, dim=1, keepdim=True)

          if self.loss_func == 'negbinomial':
+             logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
              if self.use_zeroinflate:
-                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
              else:
-                 pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                 pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
          elif self.loss_func == 'poisson':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
              else:
                  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
          elif self.loss_func == 'multinomial':
-             #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-             pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+             pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
          elif self.loss_func == 'bernoulli':
              if self.use_zeroinflate:
                  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -887,6 +877,8 @@ class DensityFlow2(nn.Module):

          # basal embedding
          zs = self.get_basal_embedding(xs)
+         log_mu = self.get_log_mu(zs)
+
          for pert in perturbs_predict:
              pert_idx = int(np.where(perturbs_reference==pert)[0])
              us_i = us[:,pert_idx].reshape(-1,1)
@@ -898,10 +890,12 @@ class DensityFlow2(nn.Module):
              ps = np.ones_like(us_i)
              if np.sum(np.abs(ps-us_i))>=1:
                  dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
-                 zs = zs + dzs0 + dzs
+                 delta = dzs0 + dzs
              else:
-                 zs = zs + dzs0
-
+                 delta = dzs0
+
+             log_mu = log_mu + self.get_log_mu(delta)
+
          if library_sizes is None:
              library_sizes = np.sum(xs, axis=1, keepdims=True)
          elif type(library_sizes) == list:
@@ -910,9 +904,9 @@ class DensityFlow2(nn.Module):
          elif len(library_sizes.shape)==1:
              library_sizes = library_sizes.reshape(-1,1)

-         counts = self.get_counts(zs, library_sizes=library_sizes)
+         counts = self.get_counts(log_mu, library_sizes=library_sizes)

-         return counts, zs
+         return counts, log_mu

      def _cell_shift(self, zs, perturb_idx, perturb):
          #zns,_ = self.encoder_zn(xs)
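
Worth noting in the two prediction hunks above: DensityFlowLinear now composes perturbation effects in decoder space rather than latent space, adding get_log_mu(delta) to the basal log_mu before converting to counts. Because the composition happens in log-mean space, factor effects act multiplicatively on the expected counts; a toy check of that identity:

import torch

log_mu_basal = torch.tensor([2.0])  # illustrative decoded log-mean
log_mu_shift = torch.tensor([0.7])  # decoded contribution of a factor shift

combined = (log_mu_basal + log_mu_shift).exp()
assert torch.allclose(combined, log_mu_basal.exp() * log_mu_shift.exp())
print(combined)  # tensor([14.8797])
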
@@ -956,48 +950,46 @@ class DensityFlow2(nn.Module):
          Z = np.concatenate(Z)
          return Z

-     def _get_theta(self, delta_zs):
-         return self.decoder_concentrate(delta_zs)
+     def _log_mu(self, zs):
+         return self.decoder_log_mu(zs)

-     def get_theta(self,
-                   delta_zs,
-                   batch_size: int = 1024):
+     def get_log_mu(self, zs, batch_size: int = 1024):
          """
          Return cells' changes in the feature space induced by specific perturbation of a factor

          """
-         delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
-         dataset = CustomDataset(delta_zs)
+         zs = convert_to_tensor(zs, device=self.get_device())
+         dataset = CustomDataset(zs)
          dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

          R = []
          with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-             for delta_Z_batch, _ in dataloader:
-                 r = self._get_theta(delta_Z_batch)
+             for Z_batch, _ in dataloader:
+                 r = self._log_mu(Z_batch)
                  R.append(tensor_to_numpy(r))
                  pbar.update(1)

          R = np.concatenate(R)
          return R

-     def _count(self, concentrate, library_size=None):
+     def _count(self, log_mu, library_size=None):
          if self.loss_func == 'bernoulli':
              #counts = self.sigmoid(concentrate)
-             counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+             counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
          elif self.loss_func == 'multinomial':
-             theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+             theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
              counts = theta * library_size
          else:
-             rate = concentrate.exp()
+             rate = log_mu.exp()
              theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
              counts = theta * library_size
          return counts

-     def get_counts(self, concentrate,
+     def get_counts(self, log_mu,
                     library_sizes,
                     batch_size: int = 1024):

-         concentrate = convert_to_tensor(concentrate, device=self.get_device())
+         log_mu = convert_to_tensor(log_mu, device=self.get_device())

          if type(library_sizes) == list:
              library_sizes = np.array(library_sizes).reshape(-1,1)
@@ -1005,13 +997,13 @@ class DensityFlow2(nn.Module):
              library_sizes = library_sizes.reshape(-1,1)
          ls = convert_to_tensor(library_sizes, device=self.get_device())

-         dataset = CustomDataset2(concentrate,ls)
+         dataset = CustomDataset2(log_mu,ls)
          dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

          E = []
          with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-             for C_batch, L_batch, _ in dataloader:
-                 counts = self._count(C_batch, L_batch)
+             for Mu_batch, L_batch, _ in dataloader:
+                 counts = self._count(Mu_batch, L_batch)
                  E.append(tensor_to_numpy(counts))
                  pbar.update(1)

@@ -1045,7 +1037,7 @@ class DensityFlow2(nn.Module):
                threshold: int = 0,
                use_jax: bool = True):
          """
-         Train the DensityFlow2 model.
+         Train the DensityFlowLinear model.

          Parameters
          ----------
@@ -1071,7 +1063,7 @@ class DensityFlow2(nn.Module):
              Parameter for optimization.
          use_jax
              If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
-             the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow2 in the shell command.
+             the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlowLinear in the shell command.
          """
          xs = self.preprocess(xs, threshold=threshold)
          xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1189,12 +1181,12 @@ class DensityFlow2(nn.Module):


  EXAMPLE_RUN = (
-     "example run: DensityFlow2 --help"
+     "example run: DensityFlowLinear --help"
  )

  def parse_args():
      parser = argparse.ArgumentParser(
-         description="DensityFlow2\n{}".format(EXAMPLE_RUN))
+         description="DensityFlowLinear\n{}".format(EXAMPLE_RUN))

      parser.add_argument(
          "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1381,10 +1373,10 @@ def main():
      cell_factor_size = 0 if us is None else us.shape[1]

      ###########################################
-     df = DensityFlow2(
+     df = DensityFlowLinear(
          input_size=input_size,
          cell_factor_size=cell_factor_size,
-         inverse_dispersion=args.inverse_dispersion,
+         dispersion=args.dispersion,
          z_dim=args.z_dim,
          hidden_layers=args.hidden_layers,
          hidden_layer_activation=args.hidden_layer_activation,
@@ -1412,9 +1404,9 @@ def main():

      if args.save_model is not None:
          if args.save_model.endswith('gz'):
-             DensityFlow2.save_model(df, args.save_model, compression=True)
+             DensityFlowLinear.save_model(df, args.save_model, compression=True)
          else:
-             DensityFlow2.save_model(df, args.save_model)
+             DensityFlowLinear.save_model(df, args.save_model)
{sure_tools-2.4.25 → sure_tools-2.4.42}/SURE/__init__.py
@@ -1,6 +1,6 @@
  from .SURE import SURE
  from .DensityFlow import DensityFlow
- from .DensityFlow2 import DensityFlow2
+ from .DensityFlowLinear import DensityFlowLinear
  from .PerturbE import PerturbE
  from .TranscriptomeDecoder import TranscriptomeDecoder
  from .SimpleTranscriptomeDecoder import SimpleTranscriptomeDecoder
@@ -12,7 +12,7 @@ from . import utils
  from . import codebook
  from . import SURE
  from . import DensityFlow
- from . import DensityFlow2
+ from . import DensityFlowLinear
  from . import atac
  from . import flow
  from . import perturb
@@ -23,6 +23,6 @@ from . import EfficientTranscriptomeDecoder
  from . import VirtualCellDecoder
  from . import PerturbationAwareDecoder

- __all__ = ['SURE', 'DensityFlow', 'DensityFlow2', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
+ __all__ = ['SURE', 'DensityFlow', 'DensityFlowLinear', 'PerturbE', 'TranscriptomeDecoder', 'SimpleTranscriptomeDecoder',
             'EfficientTranscriptomeDecoder', 'VirtualCellDecoder', 'PerturbationAwareDecoder',
             'flow', 'perturb', 'atac', 'utils', 'codebook']
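
With the rename propagated through the package exports above, downstream imports need updating; a hypothetical migration snippet based on the new __init__.py:

# before, on 2.4.25:
# from SURE import DensityFlow2
# model = DensityFlow2(input_size=2000)

# after, on 2.4.42:
from SURE import DensityFlowLinear

model = DensityFlowLinear(input_size=2000)  # illustrative gene count
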
{sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.4.25
+ Version: 2.4.42
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
{sure_tools-2.4.25 → sure_tools-2.4.42}/SURE_tools.egg-info/SOURCES.txt
@@ -2,7 +2,7 @@ LICENSE
  README.md
  setup.py
  SURE/DensityFlow.py
- SURE/DensityFlow2.py
+ SURE/DensityFlowLinear.py
  SURE/EfficientTranscriptomeDecoder.py
  SURE/PerturbE.py
  SURE/PerturbationAwareDecoder.py
{sure_tools-2.4.25 → sure_tools-2.4.42}/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:

  setup(
      name='SURE-tools',
-     version='2.4.25',
+     version='2.4.42',
      description='Succinct Representation of Single Cells',
      long_description=long_description,
      long_description_content_type="text/markdown",