SURE-tools 2.4.7__py3-none-any.whl → 2.4.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/DensityFlow.py CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
57
57
  class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
- codebook_size: int = 200,
60
+ codebook_size: int = 100,
61
61
  cell_factor_size: int = 0,
62
62
  turn_off_cell_specific: bool = False,
63
63
  supervised_mode: bool = False,
64
- z_dim: int = 10,
65
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
67
- inverse_dispersion: float = 10.0,
64
+ z_dim: int = 50,
65
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
67
+ dispersion: float = 8.0,
68
68
  use_zeroinflate: bool = False,
69
- hidden_layers: list = [500],
69
+ hidden_layers: list = [1024],
70
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
71
71
  nn_dropout: float = 0.1,
72
72
  post_layer_fct: list = ['layernorm'],
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
81
81
 
82
82
  self.input_size = input_size
83
83
  self.cell_factor_size = cell_factor_size
84
- self.inverse_dispersion = inverse_dispersion
84
+ self.dispersion = dispersion
85
85
  self.latent_dim = z_dim
86
86
  self.hidden_layers = hidden_layers
87
87
  self.decoder_hidden_layers = hidden_layers[::-1]
88
+ self.config_enum = config_enum
88
89
  self.allow_broadcast = config_enum == 'parallel'
89
90
  self.use_cuda = use_cuda
90
91
  self.loss_func = loss_func
@@ -107,8 +108,16 @@ class DensityFlow(nn.Module):
107
108
 
108
109
  self.codebook_weights = None
109
110
 
111
+ self.seed = seed
110
112
  set_random_seed(seed)
111
113
  self.setup_networks()
114
+
115
+ print(f"🧬 DensityFlow Initialized:")
116
+ print(f" - Latent Dimension: {self.latent_dim}")
117
+ print(f" - Gene Dimension: {self.input_size}")
118
+ print(f" - Hidden Dimensions: {self.hidden_layers}")
119
+ print(f" - Device: {self.get_device()}")
120
+ print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
112
121
 
113
122
  def setup_networks(self):
114
123
  latent_dim = self.latent_dim
@@ -251,7 +260,7 @@ class DensityFlow(nn.Module):
251
260
  )
252
261
  )
253
262
 
254
- self.decoder_concentrate = MLP(
263
+ self.decoder_log_mu = MLP(
255
264
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
256
265
  activation=activate_fct,
257
266
  output_activation=None,
@@ -341,8 +350,8 @@ class DensityFlow(nn.Module):
341
350
  self.options = dict(dtype=xs.dtype, device=xs.device)
342
351
 
343
352
  if self.loss_func=='negbinomial':
344
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
345
- xs.new_ones(self.input_size), constraint=constraints.positive)
353
+ dispersion = pyro.param("dispersion", self.dispersion *
354
+ xs.new_ones(self.input_size), constraint=constraints.positive)
346
355
 
347
356
  if self.use_zeroinflate:
348
357
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -376,28 +385,32 @@ class DensityFlow(nn.Module):
376
385
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
377
386
 
378
387
  zs = zns
379
- concentrate = self.decoder_concentrate(zs)
388
+ log_mu = self.decoder_log_mu(zs)
380
389
  if self.loss_func in ['bernoulli']:
381
- log_theta = concentrate
390
+ log_theta = log_mu
391
+ elif self.loss_func == 'negbinomial':
392
+ mu = log_mu.exp()
382
393
  else:
383
- rate = concentrate.exp()
394
+ rate = log_mu.exp()
384
395
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
385
396
  if self.loss_func == 'poisson':
386
397
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
387
398
 
388
399
  if self.loss_func == 'negbinomial':
400
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
389
401
  if self.use_zeroinflate:
390
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
402
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
403
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
391
404
  else:
392
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
405
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
406
+ logits=logits).to_event(1), obs=xs)
393
407
  elif self.loss_func == 'poisson':
394
408
  if self.use_zeroinflate:
395
409
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
396
410
  else:
397
411
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
398
412
  elif self.loss_func == 'multinomial':
399
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
400
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
413
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
401
414
  elif self.loss_func == 'bernoulli':
402
415
  if self.use_zeroinflate:
403
416
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -421,8 +434,8 @@ class DensityFlow(nn.Module):
421
434
  self.options = dict(dtype=xs.dtype, device=xs.device)
422
435
 
423
436
  if self.loss_func=='negbinomial':
424
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
425
- xs.new_ones(self.input_size), constraint=constraints.positive)
437
+ dispersion = pyro.param("dispersion", self.dispersion *
438
+ xs.new_ones(self.input_size), constraint=constraints.positive)
426
439
 
427
440
  if self.use_zeroinflate:
428
441
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -461,28 +474,30 @@ class DensityFlow(nn.Module):
461
474
  else:
462
475
  zs = zns
463
476
 
464
- concentrate = self.decoder_concentrate(zs)
477
+ log_mu = self.decoder_log_mu(zs)
465
478
  if self.loss_func in ['bernoulli']:
466
- log_theta = concentrate
479
+ log_theta = log_mu
480
+ elif self.loss_func == 'negbinomial':
481
+ mu = log_mu.exp()
467
482
  else:
468
- rate = concentrate.exp()
483
+ rate = log_mu.exp()
469
484
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
470
485
  if self.loss_func == 'poisson':
471
486
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
472
487
 
473
488
  if self.loss_func == 'negbinomial':
489
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
474
490
  if self.use_zeroinflate:
475
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
491
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
476
492
  else:
477
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
493
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
478
494
  elif self.loss_func == 'poisson':
479
495
  if self.use_zeroinflate:
480
496
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
481
497
  else:
482
498
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
483
499
  elif self.loss_func == 'multinomial':
484
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
485
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
500
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
486
501
  elif self.loss_func == 'bernoulli':
487
502
  if self.use_zeroinflate:
488
503
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -506,8 +521,8 @@ class DensityFlow(nn.Module):
506
521
  self.options = dict(dtype=xs.dtype, device=xs.device)
507
522
 
508
523
  if self.loss_func=='negbinomial':
509
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
510
- xs.new_ones(self.input_size), constraint=constraints.positive)
524
+ dispersion = pyro.param("dispersion", self.dispersion *
525
+ xs.new_ones(self.input_size), constraint=constraints.positive)
511
526
 
512
527
  if self.use_zeroinflate:
513
528
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -558,28 +573,31 @@ class DensityFlow(nn.Module):
558
573
 
559
574
  zs = zns
560
575
 
561
- concentrate = self.decoder_concentrate(zs)
576
+ log_mu = self.decoder_log_mu(zs)
562
577
  if self.loss_func in ['bernoulli']:
563
- log_theta = concentrate
578
+ log_theta = log_mu
579
+ elif self.loss_func in ['negbinomial']:
580
+ mu = log_mu.exp()
564
581
  else:
565
- rate = concentrate.exp()
582
+ rate = log_mu.exp()
566
583
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
567
584
  if self.loss_func == 'poisson':
568
585
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
569
586
 
570
587
  if self.loss_func == 'negbinomial':
588
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
571
589
  if self.use_zeroinflate:
572
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
590
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
591
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
573
592
  else:
574
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
593
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
575
594
  elif self.loss_func == 'poisson':
576
595
  if self.use_zeroinflate:
577
596
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
578
597
  else:
579
598
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
580
599
  elif self.loss_func == 'multinomial':
581
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
582
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
600
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
583
601
  elif self.loss_func == 'bernoulli':
584
602
  if self.use_zeroinflate:
585
603
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -603,8 +621,8 @@ class DensityFlow(nn.Module):
603
621
  self.options = dict(dtype=xs.dtype, device=xs.device)
604
622
 
605
623
  if self.loss_func=='negbinomial':
606
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
607
- xs.new_ones(self.input_size), constraint=constraints.positive)
624
+ dispersion = pyro.param("dispersion", self.dispersion *
625
+ xs.new_ones(self.input_size), constraint=constraints.positive)
608
626
 
609
627
  if self.use_zeroinflate:
610
628
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -665,28 +683,31 @@ class DensityFlow(nn.Module):
665
683
  else:
666
684
  zs = zns
667
685
 
668
- concentrate = self.decoder_concentrate(zs)
686
+ log_mu = self.decoder_log_mu(zs)
669
687
  if self.loss_func in ['bernoulli']:
670
- log_theta = concentrate
688
+ log_theta = log_mu
689
+ elif self.loss_func in ['negbinomial']:
690
+ mu = log_mu.exp()
671
691
  else:
672
- rate = concentrate.exp()
692
+ rate = log_mu.exp()
673
693
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
674
694
  if self.loss_func == 'poisson':
675
695
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
676
696
 
677
697
  if self.loss_func == 'negbinomial':
698
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
678
699
  if self.use_zeroinflate:
679
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
700
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
701
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
680
702
  else:
681
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
703
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
682
704
  elif self.loss_func == 'poisson':
683
705
  if self.use_zeroinflate:
684
706
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
685
707
  else:
686
708
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
687
709
  elif self.loss_func == 'multinomial':
688
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
689
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
710
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
690
711
  elif self.loss_func == 'bernoulli':
691
712
  if self.use_zeroinflate:
692
713
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -710,13 +731,13 @@ class DensityFlow(nn.Module):
710
731
  # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
711
732
  #else:
712
733
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
713
- zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
734
+ zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
714
735
  else:
715
736
  #if self.turn_off_cell_specific:
716
737
  # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
717
738
  #else:
718
739
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
719
- zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
740
+ zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
720
741
  return zus
721
742
 
722
743
  def _get_codebook_identity(self):
@@ -858,12 +879,12 @@ class DensityFlow(nn.Module):
858
879
  us_i = us[:,pert_idx].reshape(-1,1)
859
880
 
860
881
  # factor effect of xs
861
- dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
882
+ dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
862
883
 
863
884
  # perturbation effect
864
885
  ps = np.ones_like(us_i)
865
886
  if np.sum(np.abs(ps-us_i))>=1:
866
- dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
887
+ dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
867
888
  zs = zs + dzs0 + dzs
868
889
  else:
869
890
  zs = zs + dzs0
@@ -877,10 +898,11 @@ class DensityFlow(nn.Module):
877
898
  library_sizes = library_sizes.reshape(-1,1)
878
899
 
879
900
  counts = self.get_counts(zs, library_sizes=library_sizes)
901
+ log_mu = self.get_log_mu(zs)
880
902
 
881
- return counts, zs
903
+ return counts, log_mu
882
904
 
883
- def _cell_response(self, zs, perturb_idx, perturb):
905
+ def _cell_shift(self, zs, perturb_idx, perturb):
884
906
  #zns,_ = self.encoder_zn(xs)
885
907
  #zns,_ = self._get_basal_embedding(xs)
886
908
  zns = zs
@@ -897,7 +919,7 @@ class DensityFlow(nn.Module):
897
919
 
898
920
  return ms
899
921
 
900
- def get_cell_response(self,
922
+ def get_cell_shift(self,
901
923
  zs,
902
924
  perturb_idx,
903
925
  perturb_us,
@@ -915,46 +937,43 @@ class DensityFlow(nn.Module):
915
937
  Z = []
916
938
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
917
939
  for Z_batch, P_batch, _ in dataloader:
918
- zns = self._cell_response(Z_batch, perturb_idx, P_batch)
940
+ zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
919
941
  Z.append(tensor_to_numpy(zns))
920
942
  pbar.update(1)
921
943
 
922
944
  Z = np.concatenate(Z)
923
945
  return Z
924
946
 
925
- def _get_expression_response(self, delta_zs):
926
- return self.decoder_concentrate(delta_zs)
947
+ def _log_mu(self, zs):
948
+ return self.decoder_log_mu(zs)
927
949
 
928
- def get_expression_response(self,
929
- delta_zs,
930
- batch_size: int = 1024):
950
+ def get_log_mu(self, zs, batch_size: int = 1024):
931
951
  """
932
952
  Return cells' changes in the feature space induced by specific perturbation of a factor
933
953
 
934
954
  """
935
- delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
936
- dataset = CustomDataset(delta_zs)
955
+ zs = convert_to_tensor(zs, device=self.get_device())
956
+ dataset = CustomDataset(zs)
937
957
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
938
958
 
939
959
  R = []
940
960
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
941
- for delta_Z_batch, _ in dataloader:
942
- r = self._get_expression_response(delta_Z_batch)
961
+ for Z_batch, _ in dataloader:
962
+ r = self._log_mu(Z_batch)
943
963
  R.append(tensor_to_numpy(r))
944
964
  pbar.update(1)
945
965
 
946
966
  R = np.concatenate(R)
947
967
  return R
948
968
 
949
- def _count(self, concentrate, library_size=None):
969
+ def _count(self, log_mu, library_size=None):
950
970
  if self.loss_func == 'bernoulli':
951
- #counts = self.sigmoid(concentrate)
952
- counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
971
+ counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
953
972
  elif self.loss_func == 'multinomial':
954
- theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
973
+ theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
955
974
  counts = theta * library_size
956
975
  else:
957
- rate = concentrate.exp()
976
+ rate = log_mu.exp()
958
977
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
959
978
  counts = theta * library_size
960
979
  return counts
@@ -976,8 +995,8 @@ class DensityFlow(nn.Module):
976
995
  E = []
977
996
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
978
997
  for Z_batch, L_batch, _ in dataloader:
979
- concentrate = self._get_expression_response(Z_batch)
980
- counts = self._count(concentrate, L_batch)
998
+ log_mu = self._log_mu(Z_batch)
999
+ counts = self._count(log_mu, L_batch)
981
1000
  E.append(tensor_to_numpy(counts))
982
1001
  pbar.update(1)
983
1002
 
@@ -1123,7 +1142,7 @@ class DensityFlow(nn.Module):
1123
1142
  pbar.set_postfix({'loss': str_loss})
1124
1143
  pbar.update(1)
1125
1144
 
1126
- @classmethod
1145
+ '''@classmethod
1127
1146
  def save_model(cls, model, file_path, compression=False):
1128
1147
  """Save the model to the specified file path."""
1129
1148
  file_path = os.path.abspath(file_path)
@@ -1151,6 +1170,45 @@ class DensityFlow(nn.Module):
1151
1170
  with open(file_path, 'rb') as pickle_file:
1152
1171
  model = pickle.load(pickle_file)
1153
1172
 
1173
+ return model'''
1174
+
1175
+ def save_model(self, path):
1176
+ """Save model checkpoint"""
1177
+ torch.save({
1178
+ 'model_state_dict': self.state_dict(),
1179
+ 'model_config': {
1180
+ 'input_size': self.input_size,
1181
+ 'codebook_size': self.code_size,
1182
+ 'cell_factor_size': self.cell_factor_size,
1183
+ 'turn_off_cell_specific':self.turn_off_cell_specific,
1184
+ 'supervised_mode':self.supervised_mode,
1185
+ 'z_dim': self.latent_dim,
1186
+ 'z_dist': self.latent_dist,
1187
+ 'loss_func': self.loss_func,
1188
+ 'dispersion': self.dispersion,
1189
+ 'use_zeroinflate': self.use_zeroinflate,
1190
+ 'hidden_layers':self.hidden_layers,
1191
+ 'hidden_layer_activation':self.hidden_layer_activation,
1192
+ 'nn_dropout':self.nn_dropout,
1193
+ 'post_layer_fct':self.post_layer_fct,
1194
+ 'post_act_fct':self.post_act_fct,
1195
+ 'config_enum':self.config_enum,
1196
+ 'use_cuda':self.use_cuda,
1197
+ 'seed':self.seed,
1198
+ 'zero_bias':self.use_bias,
1199
+ 'dtype':self.dtype,
1200
+ }
1201
+ }, path)
1202
+
1203
+ @classmethod
1204
+ def load_model(cls, model_path: str):
1205
+ """Load pre-trained model"""
1206
+ checkpoint = torch.load(model_path)
1207
+ model = DensityFlow(**checkpoint.get('model_config'))
1208
+
1209
+ checkpoint = torch.load(model_path, map_location=model.get_device())
1210
+ model.load_state_dict(checkpoint['model_state_dict'])
1211
+
1154
1212
  return model
1155
1213
 
1156
1214
 
@@ -1350,7 +1408,7 @@ def main():
1350
1408
  df = DensityFlow(
1351
1409
  input_size=input_size,
1352
1410
  cell_factor_size=cell_factor_size,
1353
- inverse_dispersion=args.inverse_dispersion,
1411
+ dispersion=args.dispersion,
1354
1412
  z_dim=args.z_dim,
1355
1413
  hidden_layers=args.hidden_layers,
1356
1414
  hidden_layer_activation=args.hidden_layer_activation,