SURE-tools 2.4.22__py3-none-any.whl → 2.4.43__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.


SURE/DensityFlow.py CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
 class DensityFlow(nn.Module):
     def __init__(self,
                  input_size: int,
-                 codebook_size: int = 200,
+                 codebook_size: int = 30,
                  cell_factor_size: int = 0,
                  turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
-                 z_dim: int = 10,
-                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
-                 inverse_dispersion: float = 10.0,
+                 z_dim: int = 50,
+                 z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 dispersion: float = 8.0,
                  use_zeroinflate: bool = False,
-                 hidden_layers: list = [500],
+                 hidden_layers: list = [1024],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
                  post_layer_fct: list = ['layernorm'],
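This release changes the constructor defaults substantially: the likelihood moves from multinomial to negative binomial, the latent prior from a 10-dimensional Gumbel to a 50-dimensional Student-t, the codebook shrinks from 200 to 30 entries, the hidden layers widen from [500] to [1024], and `inverse_dispersion` is renamed to `dispersion`. A minimal construction sketch against the new signature (the import path is inferred from the file shown here; all data sizes are illustrative):

# Sketch: constructing DensityFlow with the 2.4.43 defaults (sizes illustrative).
from SURE.DensityFlow import DensityFlow  # import path assumed from the file shown

df = DensityFlow(
    input_size=2000,          # number of genes
    codebook_size=30,         # was 200 in 2.4.22
    cell_factor_size=3,       # illustrative; enables the factor examples below
    z_dim=50,                 # was 10
    z_dist='studentt',        # was 'gumbel'
    loss_func='negbinomial',  # was 'multinomial'
    dispersion=8.0,           # renamed from inverse_dispersion (was 10.0)
    hidden_layers=[1024],     # was [500]
)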
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
 
         self.input_size = input_size
         self.cell_factor_size = cell_factor_size
-        self.inverse_dispersion = inverse_dispersion
+        self.dispersion = dispersion
         self.latent_dim = z_dim
         self.hidden_layers = hidden_layers
         self.decoder_hidden_layers = hidden_layers[::-1]
+        self.config_enum = config_enum
         self.allow_broadcast = config_enum == 'parallel'
         self.use_cuda = use_cuda
         self.loss_func = loss_func
@@ -107,10 +108,12 @@ class DensityFlow(nn.Module):
 
         self.codebook_weights = None
 
+        self.seed = seed
         set_random_seed(seed)
         self.setup_networks()
 
         print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Codebook size: {self.code_size}")
         print(f"   - Latent Dimension: {self.latent_dim}")
         print(f"   - Gene Dimension: {self.input_size}")
         print(f"   - Hidden Dimensions: {self.hidden_layers}")
@@ -258,7 +261,7 @@ class DensityFlow(nn.Module):
             )
         )
 
-        self.decoder_concentrate = MLP(
+        self.decoder_log_mu = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
             activation=activate_fct,
             output_activation=None,
@@ -348,8 +351,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -383,28 +386,32 @@ class DensityFlow(nn.Module):
             zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
 
             zs = zns
-            concentrate = self.decoder_concentrate(zs)
+            log_mu = self.decoder_log_mu(zs)
             if self.loss_func in ['bernoulli']:
-                log_theta = concentrate
+                log_theta = log_mu
+            elif self.loss_func == 'negbinomial':
+                mu = log_mu.exp()
             else:
-                rate = concentrate.exp()
+                rate = log_mu.exp()
                 theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
                 if self.loss_func == 'poisson':
                     rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
             if self.loss_func == 'negbinomial':
+                logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
                 if self.use_zeroinflate:
-                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                         logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
                 else:
-                    pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                    pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
+                                                           logits=logits).to_event(1), obs=xs)
             elif self.loss_func == 'poisson':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
                 else:
                     pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
-                #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-                pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+                pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
             elif self.loss_func == 'bernoulli':
                 if self.use_zeroinflate:
                     pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
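The negative-binomial observation model is reparameterized in this hunk (and repeated identically in the model variants below): instead of `probs=theta` derived from a Dirichlet-multinomial mean, the likelihood now takes `total_count=dispersion` and `logits = log(mu) - log(dispersion)`. Since a negative binomial with total count r and logits l has mean r·exp(l), the decoder's `log_mu` output is, up to the clamp, the log of the distribution's mean. A self-contained check of that identity (values are arbitrary):

import torch
import pyro.distributions as dist

mu = torch.tensor([0.5, 5.0, 120.0])    # decoder means (arbitrary)
dispersion = torch.full((3,), 8.0)      # per-gene total_count

logits = (mu.log() - dispersion.log()).clamp(min=-10, max=10)
nb = dist.NegativeBinomial(total_count=dispersion, logits=logits)

print(nb.mean)      # ≈ mu, since mean = r * exp(logits) = r * (mu / r)
print(nb.variance)  # = mu + mu**2 / r, overdispersed relative to Poisson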
@@ -428,8 +435,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +475,30 @@ class DensityFlow(nn.Module):
         else:
             zs = zns
 
-        concentrate = self.decoder_concentrate(zs)
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta = concentrate
+            log_theta = log_mu
+        elif self.loss_func == 'negbinomial':
+            mu = log_mu.exp()
         else:
-            rate = concentrate.exp()
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
        elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +522,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +574,31 @@ class DensityFlow(nn.Module):
 
         zs = zns
 
-        concentrate = self.decoder_concentrate(zs)
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta = concentrate
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate = concentrate.exp()
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                     logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +622,8 @@ class DensityFlow(nn.Module):
         self.options = dict(dtype=xs.dtype, device=xs.device)
 
         if self.loss_func=='negbinomial':
-            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-                                     xs.new_ones(self.input_size), constraint=constraints.positive)
+            dispersion = pyro.param("dispersion", self.dispersion *
+                                    xs.new_ones(self.input_size), constraint=constraints.positive)
 
         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +684,31 @@ class DensityFlow(nn.Module):
         else:
             zs = zns
 
-        concentrate = self.decoder_concentrate(zs)
+        log_mu = self.decoder_log_mu(zs)
         if self.loss_func in ['bernoulli']:
-            log_theta = concentrate
+            log_theta = log_mu
+        elif self.loss_func in ['negbinomial']:
+            mu = log_mu.exp()
         else:
-            rate = concentrate.exp()
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             if self.loss_func == 'poisson':
                 rate = theta * torch.sum(xs, dim=1, keepdim=True)
 
         if self.loss_func == 'negbinomial':
+            logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
             if self.use_zeroinflate:
-                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
+                                                                                     logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
-                pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
+                pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
         elif self.loss_func == 'poisson':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
             else:
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
         elif self.loss_func == 'multinomial':
-            #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
-            pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
+            pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
         elif self.loss_func == 'bernoulli':
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +732,13 @@ class DensityFlow(nn.Module):
                 #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
+                zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
             else:
                 #if self.turn_off_cell_specific:
                 #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
                 #else:
                 #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
-                zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
+                zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
         return zus
 
     def _get_codebook_identity(self):
@@ -744,6 +759,28 @@ class DensityFlow(nn.Module):
         cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb
+
+    def _get_complete_embedding(self, xs, us):
+        basal,_ = self._get_basal_embedding(xs)
+        dzs = self._total_effects(basal, us)
+        return basal + dzs
+
+    def get_complete_embedding(self, xs, us, batch_size:int=1024):
+        xs = self.preprocess(xs)
+        xs = convert_to_tensor(xs, device=self.get_device())
+        us = convert_to_tensor(us, device=self.get_device())
+        dataset = CustomDataset2(xs, us)
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        Z = []
+        with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
+            for X_batch, U_batch, _ in dataloader:
+                zns = self._get_complete_embedding(X_batch, U_batch)
+                Z.append(tensor_to_numpy(zns))
+                pbar.update(1)
+
+        Z = np.concatenate(Z)
+        return Z
 
     def _get_basal_embedding(self, xs):
         loc, scale = self.encoder_zn(xs)
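`get_complete_embedding` is new in this release: it encodes the expression matrix to the basal embedding and adds the summed cell-factor shifts from `_total_effects`. A usage sketch, continuing the `df` construction example above (array shapes are illustrative and assume the model was built with cell_factor_size=3):

import numpy as np

xs = np.random.poisson(1.0, size=(256, 2000)).astype(np.float32)   # cells x genes
us = np.random.binomial(1, 0.5, size=(256, 3)).astype(np.float32)  # cells x factors

Z = df.get_complete_embedding(xs, us, batch_size=128)  # basal + factor shifts
print(Z.shape)  # (256, 50) with the new z_dim default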
@@ -865,12 +902,12 @@ class DensityFlow(nn.Module):
         us_i = us[:,pert_idx].reshape(-1,1)
 
         # factor effect of xs
-        dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
+        dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
 
         # perturbation effect
         ps = np.ones_like(us_i)
         if np.sum(np.abs(ps-us_i))>=1:
-            dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
+            dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
             zs = zs + dzs0 + dzs
         else:
             zs = zs + dzs0
@@ -884,10 +921,11 @@ class DensityFlow(nn.Module):
             library_sizes = library_sizes.reshape(-1,1)
 
         counts = self.get_counts(zs, library_sizes=library_sizes)
+        log_mu = self.get_log_mu(zs)
 
-        return counts, zs
+        return counts, log_mu
 
-    def _cell_response(self, zs, perturb_idx, perturb):
+    def _cell_shift(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
         #zns,_ = self._get_basal_embedding(xs)
         zns = zs
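Two related renames land in this hunk and its neighbors: the perturbation helpers `_cell_response`/`get_cell_response` become `_cell_shift`/`get_cell_shift`, and the prediction path now returns `(counts, log_mu)` rather than `(counts, zs)`, so callers that unpacked the latent code from the second element need updating. A sketch of the new call pattern, continuing the earlier examples:

# Latent shift from setting factor 0 to 1 for every cell (illustrative).
ps = np.ones((Z.shape[0], 1), dtype=np.float32)
dzs = df.get_cell_shift(Z, perturb_idx=0, perturb_us=ps)

# Decode the shifted latent codes back to per-gene log-means.
log_mu = df.get_log_mu(Z + dzs)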
@@ -904,7 +942,7 @@ class DensityFlow(nn.Module):
 
         return ms
 
-    def get_cell_response(self,
+    def get_cell_shift(self,
                           zs,
                           perturb_idx,
                           perturb_us,
@@ -922,46 +960,43 @@ class DensityFlow(nn.Module):
         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, P_batch, _ in dataloader:
-                zns = self._cell_response(Z_batch, perturb_idx, P_batch)
+                zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)
 
         Z = np.concatenate(Z)
         return Z
 
-    def _get_expression_response(self, delta_zs):
-        return self.decoder_concentrate(delta_zs)
+    def _log_mu(self, zs):
+        return self.decoder_log_mu(zs)
 
-    def get_expression_response(self,
-                                delta_zs,
-                                batch_size: int = 1024):
+    def get_log_mu(self, zs, batch_size: int = 1024):
         """
         Return cells' changes in the feature space induced by specific perturbation of a factor
 
         """
-        delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
-        dataset = CustomDataset(delta_zs)
+        zs = convert_to_tensor(zs, device=self.get_device())
+        dataset = CustomDataset(zs)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
 
         R = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for delta_Z_batch, _ in dataloader:
-                r = self._get_expression_response(delta_Z_batch)
+            for Z_batch, _ in dataloader:
+                r = self._log_mu(Z_batch)
                 R.append(tensor_to_numpy(r))
                 pbar.update(1)
 
         R = np.concatenate(R)
         return R
 
-    def _count(self, concentrate, library_size=None):
+    def _count(self, log_mu, library_size=None):
         if self.loss_func == 'bernoulli':
-            #counts = self.sigmoid(concentrate)
-            counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
+            counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
         elif self.loss_func == 'multinomial':
-            theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
+            theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
             counts = theta * library_size
         else:
-            rate = concentrate.exp()
+            rate = log_mu.exp()
             theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
             counts = theta * library_size
         return counts
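In `_count`, the Dirichlet-multinomial mean with `total_count=1` is simply the concentration vector normalized to sum to one, so for the count likelihoods the expected counts are `exp(log_mu)` renormalized to each cell's library size. A quick equivalence check (arbitrary values):

import torch
import pyro.distributions as dist

log_mu = torch.randn(4, 10)  # 4 cells x 10 genes (arbitrary)
rate = log_mu.exp()

theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
assert torch.allclose(theta, rate / rate.sum(-1, keepdim=True))

library_size = torch.full((4, 1), 10000.0)
counts = theta * library_size  # expected counts per cell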
@@ -983,8 +1018,8 @@ class DensityFlow(nn.Module):
         E = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
             for Z_batch, L_batch, _ in dataloader:
-                concentrate = self._get_expression_response(Z_batch)
-                counts = self._count(concentrate, L_batch)
+                log_mu = self._log_mu(Z_batch)
+                counts = self._count(log_mu, L_batch)
                 E.append(tensor_to_numpy(counts))
                 pbar.update(1)
 
@@ -1157,8 +1192,55 @@ class DensityFlow(nn.Module):
         else:
             with open(file_path, 'rb') as pickle_file:
                 model = pickle.load(pickle_file)
+
+        print(f"🧬 DensityFlow Initialized:")
+        print(f"   - Codebook size: {model.code_size}")
+        print(f"   - Latent Dimension: {model.latent_dim}")
+        print(f"   - Gene Dimension: {model.input_size}")
+        print(f"   - Hidden Dimensions: {model.hidden_layers}")
+        print(f"   - Device: {model.get_device()}")
+        print(f"   - Parameters: {sum(p.numel() for p in model.parameters()):,}")
 
         return model
+
+    ''' def save(self, path):
+        """Save model checkpoint"""
+        torch.save({
+            'model_state_dict': self.state_dict(),
+            'model_config': {
+                'input_size': self.input_size,
+                'codebook_size': self.code_size,
+                'cell_factor_size': self.cell_factor_size,
+                'turn_off_cell_specific':self.turn_off_cell_specific,
+                'supervised_mode':self.supervised_mode,
+                'z_dim': self.latent_dim,
+                'z_dist': self.latent_dist,
+                'loss_func': self.loss_func,
+                'dispersion': self.dispersion,
+                'use_zeroinflate': self.use_zeroinflate,
+                'hidden_layers':self.hidden_layers,
+                'hidden_layer_activation':self.hidden_layer_activation,
+                'nn_dropout':self.nn_dropout,
+                'post_layer_fct':self.post_layer_fct,
+                'post_act_fct':self.post_act_fct,
+                'config_enum':self.config_enum,
+                'use_cuda':self.use_cuda,
+                'seed':self.seed,
+                'zero_bias':self.use_bias,
+                'dtype':self.dtype,
+            }
+        }, path)
+
+    @classmethod
+    def load_model(cls, model_path: str):
+        """Load pre-trained model"""
+        checkpoint = torch.load(model_path)
+        model = DensityFlow(**checkpoint.get('model_config'))
+
+        checkpoint = torch.load(model_path, map_location=model.get_device())
+        model.load_state_dict(checkpoint['model_state_dict'])
+
+        return model'''
 
 
 EXAMPLE_RUN = (
@@ -1357,7 +1439,7 @@ def main():
     df = DensityFlow(
         input_size=input_size,
         cell_factor_size=cell_factor_size,
-        inverse_dispersion=args.inverse_dispersion,
+        dispersion=args.dispersion,
         z_dim=args.z_dim,
         hidden_layers=args.hidden_layers,
         hidden_layer_activation=args.hidden_layer_activation,
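The CLI in `main()` follows the rename and now reads `args.dispersion`; wrappers still passing the old `inverse_dispersion` name will break. The hunk does not show the argparse declaration, so the flag spelling in this sketch is an assumption:

import argparse

parser = argparse.ArgumentParser()
# Assumed flag spelling; the diff only shows main() reading args.dispersion.
parser.add_argument('--dispersion', type=float, default=8.0)
args = parser.parse_args(['--dispersion', '8.0'])
print(args.dispersion)  # 8.0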