SURE-tools 2.4.7__py3-none-any.whl → 2.4.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/DensityFlow.py CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
57
57
  class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
- codebook_size: int = 200,
60
+ codebook_size: int = 30,
61
61
  cell_factor_size: int = 0,
62
62
  turn_off_cell_specific: bool = False,
63
63
  supervised_mode: bool = False,
64
- z_dim: int = 10,
65
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
67
- inverse_dispersion: float = 10.0,
64
+ z_dim: int = 50,
65
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
67
+ dispersion: float = 8.0,
68
68
  use_zeroinflate: bool = False,
69
- hidden_layers: list = [500],
69
+ hidden_layers: list = [1024],
70
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
71
71
  nn_dropout: float = 0.1,
72
72
  post_layer_fct: list = ['layernorm'],
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
81
81
 
82
82
  self.input_size = input_size
83
83
  self.cell_factor_size = cell_factor_size
84
- self.inverse_dispersion = inverse_dispersion
84
+ self.dispersion = dispersion
85
85
  self.latent_dim = z_dim
86
86
  self.hidden_layers = hidden_layers
87
87
  self.decoder_hidden_layers = hidden_layers[::-1]
88
+ self.config_enum = config_enum
88
89
  self.allow_broadcast = config_enum == 'parallel'
89
90
  self.use_cuda = use_cuda
90
91
  self.loss_func = loss_func
@@ -107,8 +108,17 @@ class DensityFlow(nn.Module):
107
108
 
108
109
  self.codebook_weights = None
109
110
 
111
+ self.seed = seed
110
112
  set_random_seed(seed)
111
113
  self.setup_networks()
114
+
115
+ print(f"🧬 DensityFlow Initialized:")
116
+ print(f" - Codebook size: {self.code_size}")
117
+ print(f" - Latent Dimension: {self.latent_dim}")
118
+ print(f" - Gene Dimension: {self.input_size}")
119
+ print(f" - Hidden Dimensions: {self.hidden_layers}")
120
+ print(f" - Device: {self.get_device()}")
121
+ print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
112
122
 
113
123
  def setup_networks(self):
114
124
  latent_dim = self.latent_dim
@@ -251,7 +261,7 @@ class DensityFlow(nn.Module):
251
261
  )
252
262
  )
253
263
 
254
- self.decoder_concentrate = MLP(
264
+ self.decoder_log_mu = MLP(
255
265
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
256
266
  activation=activate_fct,
257
267
  output_activation=None,
@@ -341,8 +351,8 @@ class DensityFlow(nn.Module):
341
351
  self.options = dict(dtype=xs.dtype, device=xs.device)
342
352
 
343
353
  if self.loss_func=='negbinomial':
344
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
345
- xs.new_ones(self.input_size), constraint=constraints.positive)
354
+ dispersion = pyro.param("dispersion", self.dispersion *
355
+ xs.new_ones(self.input_size), constraint=constraints.positive)
346
356
 
347
357
  if self.use_zeroinflate:
348
358
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -376,28 +386,32 @@ class DensityFlow(nn.Module):
376
386
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
377
387
 
378
388
  zs = zns
379
- concentrate = self.decoder_concentrate(zs)
389
+ log_mu = self.decoder_log_mu(zs)
380
390
  if self.loss_func in ['bernoulli']:
381
- log_theta = concentrate
391
+ log_theta = log_mu
392
+ elif self.loss_func == 'negbinomial':
393
+ mu = log_mu.exp()
382
394
  else:
383
- rate = concentrate.exp()
395
+ rate = log_mu.exp()
384
396
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
385
397
  if self.loss_func == 'poisson':
386
398
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
387
399
 
388
400
  if self.loss_func == 'negbinomial':
401
+ logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
389
402
  if self.use_zeroinflate:
390
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
403
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
404
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
391
405
  else:
392
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
406
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
407
+ logits=logits).to_event(1), obs=xs)
393
408
  elif self.loss_func == 'poisson':
394
409
  if self.use_zeroinflate:
395
410
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
396
411
  else:
397
412
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
398
413
  elif self.loss_func == 'multinomial':
399
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
400
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
414
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
401
415
  elif self.loss_func == 'bernoulli':
402
416
  if self.use_zeroinflate:
403
417
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -421,8 +435,8 @@ class DensityFlow(nn.Module):
421
435
  self.options = dict(dtype=xs.dtype, device=xs.device)
422
436
 
423
437
  if self.loss_func=='negbinomial':
424
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
425
- xs.new_ones(self.input_size), constraint=constraints.positive)
438
+ dispersion = pyro.param("dispersion", self.dispersion *
439
+ xs.new_ones(self.input_size), constraint=constraints.positive)
426
440
 
427
441
  if self.use_zeroinflate:
428
442
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -461,28 +475,30 @@ class DensityFlow(nn.Module):
461
475
  else:
462
476
  zs = zns
463
477
 
464
- concentrate = self.decoder_concentrate(zs)
478
+ log_mu = self.decoder_log_mu(zs)
465
479
  if self.loss_func in ['bernoulli']:
466
- log_theta = concentrate
480
+ log_theta = log_mu
481
+ elif self.loss_func == 'negbinomial':
482
+ mu = log_mu.exp()
467
483
  else:
468
- rate = concentrate.exp()
484
+ rate = log_mu.exp()
469
485
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
470
486
  if self.loss_func == 'poisson':
471
487
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
472
488
 
473
489
  if self.loss_func == 'negbinomial':
490
+ logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
474
491
  if self.use_zeroinflate:
475
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
492
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
476
493
  else:
477
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
494
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
478
495
  elif self.loss_func == 'poisson':
479
496
  if self.use_zeroinflate:
480
497
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
481
498
  else:
482
499
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
483
500
  elif self.loss_func == 'multinomial':
484
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
485
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
501
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
486
502
  elif self.loss_func == 'bernoulli':
487
503
  if self.use_zeroinflate:
488
504
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -506,8 +522,8 @@ class DensityFlow(nn.Module):
506
522
  self.options = dict(dtype=xs.dtype, device=xs.device)
507
523
 
508
524
  if self.loss_func=='negbinomial':
509
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
510
- xs.new_ones(self.input_size), constraint=constraints.positive)
525
+ dispersion = pyro.param("dispersion", self.dispersion *
526
+ xs.new_ones(self.input_size), constraint=constraints.positive)
511
527
 
512
528
  if self.use_zeroinflate:
513
529
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -558,28 +574,31 @@ class DensityFlow(nn.Module):
558
574
 
559
575
  zs = zns
560
576
 
561
- concentrate = self.decoder_concentrate(zs)
577
+ log_mu = self.decoder_log_mu(zs)
562
578
  if self.loss_func in ['bernoulli']:
563
- log_theta = concentrate
579
+ log_theta = log_mu
580
+ elif self.loss_func in ['negbinomial']:
581
+ mu = log_mu.exp()
564
582
  else:
565
- rate = concentrate.exp()
583
+ rate = log_mu.exp()
566
584
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
567
585
  if self.loss_func == 'poisson':
568
586
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
569
587
 
570
588
  if self.loss_func == 'negbinomial':
589
+ logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
571
590
  if self.use_zeroinflate:
572
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
591
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
592
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
573
593
  else:
574
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
594
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
575
595
  elif self.loss_func == 'poisson':
576
596
  if self.use_zeroinflate:
577
597
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
578
598
  else:
579
599
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
580
600
  elif self.loss_func == 'multinomial':
581
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
582
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
601
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
583
602
  elif self.loss_func == 'bernoulli':
584
603
  if self.use_zeroinflate:
585
604
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -603,8 +622,8 @@ class DensityFlow(nn.Module):
603
622
  self.options = dict(dtype=xs.dtype, device=xs.device)
604
623
 
605
624
  if self.loss_func=='negbinomial':
606
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
607
- xs.new_ones(self.input_size), constraint=constraints.positive)
625
+ dispersion = pyro.param("dispersion", self.dispersion *
626
+ xs.new_ones(self.input_size), constraint=constraints.positive)
608
627
 
609
628
  if self.use_zeroinflate:
610
629
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -665,28 +684,31 @@ class DensityFlow(nn.Module):
665
684
  else:
666
685
  zs = zns
667
686
 
668
- concentrate = self.decoder_concentrate(zs)
687
+ log_mu = self.decoder_log_mu(zs)
669
688
  if self.loss_func in ['bernoulli']:
670
- log_theta = concentrate
689
+ log_theta = log_mu
690
+ elif self.loss_func in ['negbinomial']:
691
+ mu = log_mu.exp()
671
692
  else:
672
- rate = concentrate.exp()
693
+ rate = log_mu.exp()
673
694
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
674
695
  if self.loss_func == 'poisson':
675
696
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
676
697
 
677
698
  if self.loss_func == 'negbinomial':
699
+ logits = (mu.log()-dispersion.log()).clamp(min=-10, max=10)
678
700
  if self.use_zeroinflate:
679
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
701
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
702
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
680
703
  else:
681
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
704
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
682
705
  elif self.loss_func == 'poisson':
683
706
  if self.use_zeroinflate:
684
707
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
685
708
  else:
686
709
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
687
710
  elif self.loss_func == 'multinomial':
688
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
689
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
711
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
690
712
  elif self.loss_func == 'bernoulli':
691
713
  if self.use_zeroinflate:
692
714
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -710,13 +732,13 @@ class DensityFlow(nn.Module):
710
732
  # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
711
733
  #else:
712
734
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
713
- zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
735
+ zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
714
736
  else:
715
737
  #if self.turn_off_cell_specific:
716
738
  # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
717
739
  #else:
718
740
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
719
- zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
741
+ zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
720
742
  return zus
721
743
 
722
744
  def _get_codebook_identity(self):
@@ -737,6 +759,28 @@ class DensityFlow(nn.Module):
737
759
  cb = self._get_codebook()
738
760
  cb = tensor_to_numpy(cb)
739
761
  return cb
762
+
763
+ def _get_complete_embedding(self, xs, us):
764
+ basal,_ = self._get_basal_embedding(xs)
765
+ dzs = self._total_effects(basal, us)
766
+ return basal + dzs
767
+
768
+ def get_complete_embedding(self, xs, us, batch_size:int=1024):
769
+ xs = self.preprocess(xs)
770
+ xs = convert_to_tensor(xs, device=self.get_device())
771
+ us = convert_to_tensor(us, device=self.get_device())
772
+ dataset = CustomDataset2(xs, us)
773
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
774
+
775
+ Z = []
776
+ with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
777
+ for X_batch, U_batch, _ in dataloader:
778
+ zns = self._get_complete_embedding(X_batch, U_batch)
779
+ Z.append(tensor_to_numpy(zns))
780
+ pbar.update(1)
781
+
782
+ Z = np.concatenate(Z)
783
+ return Z
740
784
 
741
785
  def _get_basal_embedding(self, xs):
742
786
  loc, scale = self.encoder_zn(xs)
@@ -858,12 +902,12 @@ class DensityFlow(nn.Module):
858
902
  us_i = us[:,pert_idx].reshape(-1,1)
859
903
 
860
904
  # factor effect of xs
861
- dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
905
+ dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
862
906
 
863
907
  # perturbation effect
864
908
  ps = np.ones_like(us_i)
865
909
  if np.sum(np.abs(ps-us_i))>=1:
866
- dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
910
+ dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
867
911
  zs = zs + dzs0 + dzs
868
912
  else:
869
913
  zs = zs + dzs0
@@ -877,10 +921,11 @@ class DensityFlow(nn.Module):
877
921
  library_sizes = library_sizes.reshape(-1,1)
878
922
 
879
923
  counts = self.get_counts(zs, library_sizes=library_sizes)
924
+ log_mu = self.get_log_mu(zs)
880
925
 
881
- return counts, zs
926
+ return counts, log_mu
882
927
 
883
- def _cell_response(self, zs, perturb_idx, perturb):
928
+ def _cell_shift(self, zs, perturb_idx, perturb):
884
929
  #zns,_ = self.encoder_zn(xs)
885
930
  #zns,_ = self._get_basal_embedding(xs)
886
931
  zns = zs
@@ -897,7 +942,7 @@ class DensityFlow(nn.Module):
897
942
 
898
943
  return ms
899
944
 
900
- def get_cell_response(self,
945
+ def get_cell_shift(self,
901
946
  zs,
902
947
  perturb_idx,
903
948
  perturb_us,
@@ -915,46 +960,43 @@ class DensityFlow(nn.Module):
915
960
  Z = []
916
961
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
917
962
  for Z_batch, P_batch, _ in dataloader:
918
- zns = self._cell_response(Z_batch, perturb_idx, P_batch)
963
+ zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
919
964
  Z.append(tensor_to_numpy(zns))
920
965
  pbar.update(1)
921
966
 
922
967
  Z = np.concatenate(Z)
923
968
  return Z
924
969
 
925
- def _get_expression_response(self, delta_zs):
926
- return self.decoder_concentrate(delta_zs)
970
+ def _log_mu(self, zs):
971
+ return self.decoder_log_mu(zs)
927
972
 
928
- def get_expression_response(self,
929
- delta_zs,
930
- batch_size: int = 1024):
973
+ def get_log_mu(self, zs, batch_size: int = 1024):
931
974
  """
932
975
  Return cells' changes in the feature space induced by specific perturbation of a factor
933
976
 
934
977
  """
935
- delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
936
- dataset = CustomDataset(delta_zs)
978
+ zs = convert_to_tensor(zs, device=self.get_device())
979
+ dataset = CustomDataset(zs)
937
980
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
938
981
 
939
982
  R = []
940
983
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
941
- for delta_Z_batch, _ in dataloader:
942
- r = self._get_expression_response(delta_Z_batch)
984
+ for Z_batch, _ in dataloader:
985
+ r = self._log_mu(Z_batch)
943
986
  R.append(tensor_to_numpy(r))
944
987
  pbar.update(1)
945
988
 
946
989
  R = np.concatenate(R)
947
990
  return R
948
991
 
949
- def _count(self, concentrate, library_size=None):
992
+ def _count(self, log_mu, library_size=None):
950
993
  if self.loss_func == 'bernoulli':
951
- #counts = self.sigmoid(concentrate)
952
- counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
994
+ counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
953
995
  elif self.loss_func == 'multinomial':
954
- theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
996
+ theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
955
997
  counts = theta * library_size
956
998
  else:
957
- rate = concentrate.exp()
999
+ rate = log_mu.exp()
958
1000
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
959
1001
  counts = theta * library_size
960
1002
  return counts
@@ -976,8 +1018,8 @@ class DensityFlow(nn.Module):
976
1018
  E = []
977
1019
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
978
1020
  for Z_batch, L_batch, _ in dataloader:
979
- concentrate = self._get_expression_response(Z_batch)
980
- counts = self._count(concentrate, L_batch)
1021
+ log_mu = self._log_mu(Z_batch)
1022
+ counts = self._count(log_mu, L_batch)
981
1023
  E.append(tensor_to_numpy(counts))
982
1024
  pbar.update(1)
983
1025
 
@@ -989,7 +1031,7 @@ class DensityFlow(nn.Module):
989
1031
  ad = sc.AnnData(xs)
990
1032
  binarize(ad, threshold=threshold)
991
1033
  xs = ad.X.copy()
992
- else:
1034
+ elif self.loss_func == 'poisson':
993
1035
  xs = np.round(xs)
994
1036
 
995
1037
  if sparse.issparse(xs):
@@ -1150,8 +1192,55 @@ class DensityFlow(nn.Module):
1150
1192
  else:
1151
1193
  with open(file_path, 'rb') as pickle_file:
1152
1194
  model = pickle.load(pickle_file)
1195
+
1196
+ print(f"🧬 DensityFlow Initialized:")
1197
+ print(f" - Codebook size: {model.code_size}")
1198
+ print(f" - Latent Dimension: {model.latent_dim}")
1199
+ print(f" - Gene Dimension: {model.input_size}")
1200
+ print(f" - Hidden Dimensions: {model.hidden_layers}")
1201
+ print(f" - Device: {model.get_device()}")
1202
+ print(f" - Parameters: {sum(p.numel() for p in model.parameters()):,}")
1153
1203
 
1154
1204
  return model
1205
+
1206
+ ''' def save(self, path):
1207
+ """Save model checkpoint"""
1208
+ torch.save({
1209
+ 'model_state_dict': self.state_dict(),
1210
+ 'model_config': {
1211
+ 'input_size': self.input_size,
1212
+ 'codebook_size': self.code_size,
1213
+ 'cell_factor_size': self.cell_factor_size,
1214
+ 'turn_off_cell_specific':self.turn_off_cell_specific,
1215
+ 'supervised_mode':self.supervised_mode,
1216
+ 'z_dim': self.latent_dim,
1217
+ 'z_dist': self.latent_dist,
1218
+ 'loss_func': self.loss_func,
1219
+ 'dispersion': self.dispersion,
1220
+ 'use_zeroinflate': self.use_zeroinflate,
1221
+ 'hidden_layers':self.hidden_layers,
1222
+ 'hidden_layer_activation':self.hidden_layer_activation,
1223
+ 'nn_dropout':self.nn_dropout,
1224
+ 'post_layer_fct':self.post_layer_fct,
1225
+ 'post_act_fct':self.post_act_fct,
1226
+ 'config_enum':self.config_enum,
1227
+ 'use_cuda':self.use_cuda,
1228
+ 'seed':self.seed,
1229
+ 'zero_bias':self.use_bias,
1230
+ 'dtype':self.dtype,
1231
+ }
1232
+ }, path)
1233
+
1234
+ @classmethod
1235
+ def load_model(cls, model_path: str):
1236
+ """Load pre-trained model"""
1237
+ checkpoint = torch.load(model_path)
1238
+ model = DensityFlow(**checkpoint.get('model_config'))
1239
+
1240
+ checkpoint = torch.load(model_path, map_location=model.get_device())
1241
+ model.load_state_dict(checkpoint['model_state_dict'])
1242
+
1243
+ return model'''
1155
1244
 
1156
1245
 
1157
1246
  EXAMPLE_RUN = (
@@ -1350,7 +1439,7 @@ def main():
1350
1439
  df = DensityFlow(
1351
1440
  input_size=input_size,
1352
1441
  cell_factor_size=cell_factor_size,
1353
- inverse_dispersion=args.inverse_dispersion,
1442
+ dispersion=args.dispersion,
1354
1443
  z_dim=args.z_dim,
1355
1444
  hidden_layers=args.hidden_layers,
1356
1445
  hidden_layer_activation=args.hidden_layer_activation,