SURE-tools 2.4.17__py3-none-any.whl → 2.4.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. See the registry's advisory page for this release for more details.

SURE/DensityFlow.py CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
57
57
  class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
- codebook_size: int = 200,
60
+ codebook_size: int = 100,
61
61
  cell_factor_size: int = 0,
62
62
  turn_off_cell_specific: bool = False,
63
63
  supervised_mode: bool = False,
64
- z_dim: int = 10,
64
+ z_dim: int = 50,
65
65
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
67
- inverse_dispersion: float = 10.0,
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
67
+ dispersion: float = 8.0,
68
68
  use_zeroinflate: bool = False,
69
- hidden_layers: list = [500],
69
+ hidden_layers: list = [1024],
70
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
71
71
  nn_dropout: float = 0.1,
72
72
  post_layer_fct: list = ['layernorm'],
@@ -81,7 +81,7 @@ class DensityFlow(nn.Module):
81
81
 
82
82
  self.input_size = input_size
83
83
  self.cell_factor_size = cell_factor_size
84
- self.inverse_dispersion = inverse_dispersion
84
+ self.dispersion = dispersion
85
85
  self.latent_dim = z_dim
86
86
  self.hidden_layers = hidden_layers
87
87
  self.decoder_hidden_layers = hidden_layers[::-1]
@@ -109,6 +109,13 @@ class DensityFlow(nn.Module):
109
109
 
110
110
  set_random_seed(seed)
111
111
  self.setup_networks()
112
+
113
+ print(f"🧬 DensityFlow Initialized:")
114
+ print(f" - Latent Dimension: {self.latent_dim}")
115
+ print(f" - Gene Dimension: {self.input_size}")
116
+ print(f" - Hidden Dimensions: {self.hidden_layers}")
117
+ print(f" - Device: {self.get_device()}")
118
+ print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
112
119
 
113
120
  def setup_networks(self):
114
121
  latent_dim = self.latent_dim
@@ -251,7 +258,7 @@ class DensityFlow(nn.Module):
251
258
  )
252
259
  )
253
260
 
254
- self.decoder_concentrate = MLP(
261
+ self.decoder_log_mu = MLP(
255
262
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
256
263
  activation=activate_fct,
257
264
  output_activation=None,
@@ -341,8 +348,8 @@ class DensityFlow(nn.Module):
341
348
  self.options = dict(dtype=xs.dtype, device=xs.device)
342
349
 
343
350
  if self.loss_func=='negbinomial':
344
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
345
- xs.new_ones(self.input_size), constraint=constraints.positive)
351
+ dispersion = pyro.param("dispersion", self.dispersion *
352
+ xs.new_ones(self.input_size), constraint=constraints.positive)
346
353
 
347
354
  if self.use_zeroinflate:
348
355
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -376,28 +383,32 @@ class DensityFlow(nn.Module):
376
383
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
377
384
 
378
385
  zs = zns
379
- concentrate = self.decoder_concentrate(zs)
386
+ log_mu = self.decoder_log_mu(zs)
380
387
  if self.loss_func in ['bernoulli']:
381
- log_theta = concentrate
388
+ log_theta = log_mu
389
+ elif self.loss_func == 'negbinomial':
390
+ mu = log_mu.exp()
382
391
  else:
383
- rate = concentrate.exp()
392
+ rate = log_mu.exp()
384
393
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
385
394
  if self.loss_func == 'poisson':
386
395
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
387
396
 
388
397
  if self.loss_func == 'negbinomial':
398
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
389
399
  if self.use_zeroinflate:
390
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
400
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
401
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
391
402
  else:
392
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
403
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
404
+ logits=logits).to_event(1), obs=xs)
393
405
  elif self.loss_func == 'poisson':
394
406
  if self.use_zeroinflate:
395
407
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
396
408
  else:
397
409
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
398
410
  elif self.loss_func == 'multinomial':
399
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
400
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
411
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
401
412
  elif self.loss_func == 'bernoulli':
402
413
  if self.use_zeroinflate:
403
414
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -421,8 +432,8 @@ class DensityFlow(nn.Module):
421
432
  self.options = dict(dtype=xs.dtype, device=xs.device)
422
433
 
423
434
  if self.loss_func=='negbinomial':
424
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
425
- xs.new_ones(self.input_size), constraint=constraints.positive)
435
+ dispersion = pyro.param("dispersion", self.dispersion *
436
+ xs.new_ones(self.input_size), constraint=constraints.positive)
426
437
 
427
438
  if self.use_zeroinflate:
428
439
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -461,28 +472,30 @@ class DensityFlow(nn.Module):
461
472
  else:
462
473
  zs = zns
463
474
 
464
- concentrate = self.decoder_concentrate(zs)
475
+ log_mu = self.decoder_log_mu(zs)
465
476
  if self.loss_func in ['bernoulli']:
466
- log_theta = concentrate
477
+ log_theta = log_mu
478
+ elif self.loss_func == 'negbinomial':
479
+ mu = log_mu.exp()
467
480
  else:
468
- rate = concentrate.exp()
481
+ rate = log_mu.exp()
469
482
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
470
483
  if self.loss_func == 'poisson':
471
484
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
472
485
 
473
486
  if self.loss_func == 'negbinomial':
487
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
474
488
  if self.use_zeroinflate:
475
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
489
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
476
490
  else:
477
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
491
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
478
492
  elif self.loss_func == 'poisson':
479
493
  if self.use_zeroinflate:
480
494
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
481
495
  else:
482
496
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
483
497
  elif self.loss_func == 'multinomial':
484
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
485
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
498
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
486
499
  elif self.loss_func == 'bernoulli':
487
500
  if self.use_zeroinflate:
488
501
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -506,8 +519,8 @@ class DensityFlow(nn.Module):
506
519
  self.options = dict(dtype=xs.dtype, device=xs.device)
507
520
 
508
521
  if self.loss_func=='negbinomial':
509
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
510
- xs.new_ones(self.input_size), constraint=constraints.positive)
522
+ dispersion = pyro.param("dispersion", self.dispersion *
523
+ xs.new_ones(self.input_size), constraint=constraints.positive)
511
524
 
512
525
  if self.use_zeroinflate:
513
526
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -558,28 +571,31 @@ class DensityFlow(nn.Module):
558
571
 
559
572
  zs = zns
560
573
 
561
- concentrate = self.decoder_concentrate(zs)
574
+ log_mu = self.decoder_log_mu(zs)
562
575
  if self.loss_func in ['bernoulli']:
563
- log_theta = concentrate
576
+ log_theta = log_mu
577
+ elif self.loss_func in ['negbinomial']:
578
+ mu = log_mu.exp()
564
579
  else:
565
- rate = concentrate.exp()
580
+ rate = log_mu.exp()
566
581
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
567
582
  if self.loss_func == 'poisson':
568
583
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
569
584
 
570
585
  if self.loss_func == 'negbinomial':
586
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
571
587
  if self.use_zeroinflate:
572
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
588
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
589
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
573
590
  else:
574
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
591
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
575
592
  elif self.loss_func == 'poisson':
576
593
  if self.use_zeroinflate:
577
594
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
578
595
  else:
579
596
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
580
597
  elif self.loss_func == 'multinomial':
581
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
582
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
598
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
583
599
  elif self.loss_func == 'bernoulli':
584
600
  if self.use_zeroinflate:
585
601
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -603,8 +619,8 @@ class DensityFlow(nn.Module):
603
619
  self.options = dict(dtype=xs.dtype, device=xs.device)
604
620
 
605
621
  if self.loss_func=='negbinomial':
606
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
607
- xs.new_ones(self.input_size), constraint=constraints.positive)
622
+ dispersion = pyro.param("dispersion", self.dispersion *
623
+ xs.new_ones(self.input_size), constraint=constraints.positive)
608
624
 
609
625
  if self.use_zeroinflate:
610
626
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -665,28 +681,31 @@ class DensityFlow(nn.Module):
665
681
  else:
666
682
  zs = zns
667
683
 
668
- concentrate = self.decoder_concentrate(zs)
684
+ log_mu = self.decoder_log_mu(zs)
669
685
  if self.loss_func in ['bernoulli']:
670
- log_theta = concentrate
686
+ log_theta = log_mu
687
+ elif self.loss_func in ['negbinomial']:
688
+ mu = log_mu.exp()
671
689
  else:
672
- rate = concentrate.exp()
690
+ rate = log_mu.exp()
673
691
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
674
692
  if self.loss_func == 'poisson':
675
693
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
676
694
 
677
695
  if self.loss_func == 'negbinomial':
696
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
678
697
  if self.use_zeroinflate:
679
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
698
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
699
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
680
700
  else:
681
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
701
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
682
702
  elif self.loss_func == 'poisson':
683
703
  if self.use_zeroinflate:
684
704
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
685
705
  else:
686
706
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
687
707
  elif self.loss_func == 'multinomial':
688
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
689
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
708
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
690
709
  elif self.loss_func == 'bernoulli':
691
710
  if self.use_zeroinflate:
692
711
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -710,13 +729,13 @@ class DensityFlow(nn.Module):
710
729
  # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
711
730
  #else:
712
731
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
713
- zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
732
+ zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
714
733
  else:
715
734
  #if self.turn_off_cell_specific:
716
735
  # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
717
736
  #else:
718
737
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
719
- zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
738
+ zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
720
739
  return zus
721
740
 
722
741
  def _get_codebook_identity(self):
@@ -858,12 +877,12 @@ class DensityFlow(nn.Module):
858
877
  us_i = us[:,pert_idx].reshape(-1,1)
859
878
 
860
879
  # factor effect of xs
861
- dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
880
+ dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
862
881
 
863
882
  # perturbation effect
864
883
  ps = np.ones_like(us_i)
865
884
  if np.sum(np.abs(ps-us_i))>=1:
866
- dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
885
+ dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
867
886
  zs = zs + dzs0 + dzs
868
887
  else:
869
888
  zs = zs + dzs0
@@ -877,10 +896,11 @@ class DensityFlow(nn.Module):
877
896
  library_sizes = library_sizes.reshape(-1,1)
878
897
 
879
898
  counts = self.get_counts(zs, library_sizes=library_sizes)
899
+ log_mu = self.get_log_mu(zs)
880
900
 
881
- return counts, zs
901
+ return counts, log_mu
882
902
 
883
- def _cell_response(self, zs, perturb_idx, perturb):
903
+ def _cell_shift(self, zs, perturb_idx, perturb):
884
904
  #zns,_ = self.encoder_zn(xs)
885
905
  #zns,_ = self._get_basal_embedding(xs)
886
906
  zns = zs
@@ -897,7 +917,7 @@ class DensityFlow(nn.Module):
897
917
 
898
918
  return ms
899
919
 
900
- def get_cell_response(self,
920
+ def get_cell_shift(self,
901
921
  zs,
902
922
  perturb_idx,
903
923
  perturb_us,
@@ -915,46 +935,43 @@ class DensityFlow(nn.Module):
915
935
  Z = []
916
936
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
917
937
  for Z_batch, P_batch, _ in dataloader:
918
- zns = self._cell_response(Z_batch, perturb_idx, P_batch)
938
+ zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
919
939
  Z.append(tensor_to_numpy(zns))
920
940
  pbar.update(1)
921
941
 
922
942
  Z = np.concatenate(Z)
923
943
  return Z
924
944
 
925
- def _get_expression_response(self, delta_zs):
926
- return self.decoder_concentrate(delta_zs)
945
+ def _log_mu(self, zs):
946
+ return self.decoder_log_mu(zs)
927
947
 
928
- def get_expression_response(self,
929
- delta_zs,
930
- batch_size: int = 1024):
948
+ def get_log_mu(self, zs, batch_size: int = 1024):
931
949
  """
932
950
  Return cells' changes in the feature space induced by specific perturbation of a factor
933
951
 
934
952
  """
935
- delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
936
- dataset = CustomDataset(delta_zs)
953
+ zs = convert_to_tensor(zs, device=self.get_device())
954
+ dataset = CustomDataset(zs)
937
955
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
938
956
 
939
957
  R = []
940
958
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
941
- for delta_Z_batch, _ in dataloader:
942
- r = self._get_expression_response(delta_Z_batch)
959
+ for Z_batch, _ in dataloader:
960
+ r = self._log_mu(Z_batch)
943
961
  R.append(tensor_to_numpy(r))
944
962
  pbar.update(1)
945
963
 
946
964
  R = np.concatenate(R)
947
965
  return R
948
966
 
949
- def _count(self, concentrate, library_size=None):
967
+ def _count(self, log_mu, library_size=None):
950
968
  if self.loss_func == 'bernoulli':
951
- #counts = self.sigmoid(concentrate)
952
- counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
969
+ counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
953
970
  elif self.loss_func == 'multinomial':
954
- theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
971
+ theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
955
972
  counts = theta * library_size
956
973
  else:
957
- rate = concentrate.exp()
974
+ rate = log_mu.exp()
958
975
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
959
976
  counts = theta * library_size
960
977
  return counts
@@ -976,8 +993,8 @@ class DensityFlow(nn.Module):
976
993
  E = []
977
994
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
978
995
  for Z_batch, L_batch, _ in dataloader:
979
- concentrate = self._get_expression_response(Z_batch)
980
- counts = self._count(concentrate, L_batch)
996
+ log_mu = self._log_mu(Z_batch)
997
+ counts = self._count(log_mu, L_batch)
981
998
  E.append(tensor_to_numpy(counts))
982
999
  pbar.update(1)
983
1000
 
@@ -1350,7 +1367,7 @@ def main():
1350
1367
  df = DensityFlow(
1351
1368
  input_size=input_size,
1352
1369
  cell_factor_size=cell_factor_size,
1353
- inverse_dispersion=args.inverse_dispersion,
1370
+ dispersion=args.dispersion,
1354
1371
  z_dim=args.z_dim,
1355
1372
  hidden_layers=args.hidden_layers,
1356
1373
  hidden_layer_activation=args.hidden_layer_activation,