SURE-tools 2.4.22__py3-none-any.whl → 2.4.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic; see the registry's advisory page for this release for more details.

SURE/DensityFlow.py CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
57
57
  class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
- codebook_size: int = 200,
60
+ codebook_size: int = 100,
61
61
  cell_factor_size: int = 0,
62
62
  turn_off_cell_specific: bool = False,
63
63
  supervised_mode: bool = False,
64
- z_dim: int = 10,
64
+ z_dim: int = 50,
65
65
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
67
- inverse_dispersion: float = 10.0,
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
67
+ dispersion: float = 8.0,
68
68
  use_zeroinflate: bool = False,
69
- hidden_layers: list = [500],
69
+ hidden_layers: list = [1024],
70
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
71
71
  nn_dropout: float = 0.1,
72
72
  post_layer_fct: list = ['layernorm'],
@@ -81,7 +81,7 @@ class DensityFlow(nn.Module):
81
81
 
82
82
  self.input_size = input_size
83
83
  self.cell_factor_size = cell_factor_size
84
- self.inverse_dispersion = inverse_dispersion
84
+ self.dispersion = dispersion
85
85
  self.latent_dim = z_dim
86
86
  self.hidden_layers = hidden_layers
87
87
  self.decoder_hidden_layers = hidden_layers[::-1]
@@ -258,7 +258,7 @@ class DensityFlow(nn.Module):
258
258
  )
259
259
  )
260
260
 
261
- self.decoder_concentrate = MLP(
261
+ self.decoder_log_mu = MLP(
262
262
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
263
263
  activation=activate_fct,
264
264
  output_activation=None,
@@ -348,8 +348,8 @@ class DensityFlow(nn.Module):
348
348
  self.options = dict(dtype=xs.dtype, device=xs.device)
349
349
 
350
350
  if self.loss_func=='negbinomial':
351
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
352
- xs.new_ones(self.input_size), constraint=constraints.positive)
351
+ dispersion = pyro.param("dispersion", self.dispersion *
352
+ xs.new_ones(self.input_size), constraint=constraints.positive)
353
353
 
354
354
  if self.use_zeroinflate:
355
355
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -383,28 +383,32 @@ class DensityFlow(nn.Module):
383
383
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
384
384
 
385
385
  zs = zns
386
- concentrate = self.decoder_concentrate(zs)
386
+ log_mu = self.decoder_log_mu(zs)
387
387
  if self.loss_func in ['bernoulli']:
388
- log_theta = concentrate
388
+ log_theta = log_mu
389
+ elif self.loss_func == 'negbinomial':
390
+ mu = log_mu.exp()
389
391
  else:
390
- rate = concentrate.exp()
392
+ rate = log_mu.exp()
391
393
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
392
394
  if self.loss_func == 'poisson':
393
395
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
394
396
 
395
397
  if self.loss_func == 'negbinomial':
398
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
396
399
  if self.use_zeroinflate:
397
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
400
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
401
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
398
402
  else:
399
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
403
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
404
+ logits=logits).to_event(1), obs=xs)
400
405
  elif self.loss_func == 'poisson':
401
406
  if self.use_zeroinflate:
402
407
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
403
408
  else:
404
409
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
405
410
  elif self.loss_func == 'multinomial':
406
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
407
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
411
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
408
412
  elif self.loss_func == 'bernoulli':
409
413
  if self.use_zeroinflate:
410
414
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -428,8 +432,8 @@ class DensityFlow(nn.Module):
428
432
  self.options = dict(dtype=xs.dtype, device=xs.device)
429
433
 
430
434
  if self.loss_func=='negbinomial':
431
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
432
- xs.new_ones(self.input_size), constraint=constraints.positive)
435
+ dispersion = pyro.param("dispersion", self.dispersion *
436
+ xs.new_ones(self.input_size), constraint=constraints.positive)
433
437
 
434
438
  if self.use_zeroinflate:
435
439
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -468,28 +472,30 @@ class DensityFlow(nn.Module):
468
472
  else:
469
473
  zs = zns
470
474
 
471
- concentrate = self.decoder_concentrate(zs)
475
+ log_mu = self.decoder_log_mu(zs)
472
476
  if self.loss_func in ['bernoulli']:
473
- log_theta = concentrate
477
+ log_theta = log_mu
478
+ elif self.loss_func == 'negbinomial':
479
+ mu = log_mu.exp()
474
480
  else:
475
- rate = concentrate.exp()
481
+ rate = log_mu.exp()
476
482
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
477
483
  if self.loss_func == 'poisson':
478
484
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
479
485
 
480
486
  if self.loss_func == 'negbinomial':
487
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
481
488
  if self.use_zeroinflate:
482
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
489
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
483
490
  else:
484
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
491
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
485
492
  elif self.loss_func == 'poisson':
486
493
  if self.use_zeroinflate:
487
494
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
488
495
  else:
489
496
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
490
497
  elif self.loss_func == 'multinomial':
491
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
492
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
498
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
493
499
  elif self.loss_func == 'bernoulli':
494
500
  if self.use_zeroinflate:
495
501
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -513,8 +519,8 @@ class DensityFlow(nn.Module):
513
519
  self.options = dict(dtype=xs.dtype, device=xs.device)
514
520
 
515
521
  if self.loss_func=='negbinomial':
516
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
517
- xs.new_ones(self.input_size), constraint=constraints.positive)
522
+ dispersion = pyro.param("dispersion", self.dispersion *
523
+ xs.new_ones(self.input_size), constraint=constraints.positive)
518
524
 
519
525
  if self.use_zeroinflate:
520
526
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -565,28 +571,31 @@ class DensityFlow(nn.Module):
565
571
 
566
572
  zs = zns
567
573
 
568
- concentrate = self.decoder_concentrate(zs)
574
+ log_mu = self.decoder_log_mu(zs)
569
575
  if self.loss_func in ['bernoulli']:
570
- log_theta = concentrate
576
+ log_theta = log_mu
577
+ elif self.loss_func in ['negbinomial']:
578
+ mu = log_mu.exp()
571
579
  else:
572
- rate = concentrate.exp()
580
+ rate = log_mu.exp()
573
581
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
574
582
  if self.loss_func == 'poisson':
575
583
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
576
584
 
577
585
  if self.loss_func == 'negbinomial':
586
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
578
587
  if self.use_zeroinflate:
579
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
588
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
589
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
580
590
  else:
581
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
591
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
582
592
  elif self.loss_func == 'poisson':
583
593
  if self.use_zeroinflate:
584
594
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
585
595
  else:
586
596
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
587
597
  elif self.loss_func == 'multinomial':
588
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
589
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
598
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
590
599
  elif self.loss_func == 'bernoulli':
591
600
  if self.use_zeroinflate:
592
601
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -610,8 +619,8 @@ class DensityFlow(nn.Module):
610
619
  self.options = dict(dtype=xs.dtype, device=xs.device)
611
620
 
612
621
  if self.loss_func=='negbinomial':
613
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
614
- xs.new_ones(self.input_size), constraint=constraints.positive)
622
+ dispersion = pyro.param("dispersion", self.dispersion *
623
+ xs.new_ones(self.input_size), constraint=constraints.positive)
615
624
 
616
625
  if self.use_zeroinflate:
617
626
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -672,28 +681,31 @@ class DensityFlow(nn.Module):
672
681
  else:
673
682
  zs = zns
674
683
 
675
- concentrate = self.decoder_concentrate(zs)
684
+ log_mu = self.decoder_log_mu(zs)
676
685
  if self.loss_func in ['bernoulli']:
677
- log_theta = concentrate
686
+ log_theta = log_mu
687
+ elif self.loss_func in ['negbinomial']:
688
+ mu = log_mu.exp()
678
689
  else:
679
- rate = concentrate.exp()
690
+ rate = log_mu.exp()
680
691
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
681
692
  if self.loss_func == 'poisson':
682
693
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
683
694
 
684
695
  if self.loss_func == 'negbinomial':
696
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
685
697
  if self.use_zeroinflate:
686
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
698
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
699
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
687
700
  else:
688
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
701
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
689
702
  elif self.loss_func == 'poisson':
690
703
  if self.use_zeroinflate:
691
704
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
692
705
  else:
693
706
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
694
707
  elif self.loss_func == 'multinomial':
695
- #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
696
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
708
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
697
709
  elif self.loss_func == 'bernoulli':
698
710
  if self.use_zeroinflate:
699
711
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -717,13 +729,13 @@ class DensityFlow(nn.Module):
717
729
  # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
718
730
  #else:
719
731
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
720
- zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
732
+ zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
721
733
  else:
722
734
  #if self.turn_off_cell_specific:
723
735
  # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
724
736
  #else:
725
737
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
726
- zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
738
+ zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
727
739
  return zus
728
740
 
729
741
  def _get_codebook_identity(self):
@@ -865,12 +877,12 @@ class DensityFlow(nn.Module):
865
877
  us_i = us[:,pert_idx].reshape(-1,1)
866
878
 
867
879
  # factor effect of xs
868
- dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
880
+ dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
869
881
 
870
882
  # perturbation effect
871
883
  ps = np.ones_like(us_i)
872
884
  if np.sum(np.abs(ps-us_i))>=1:
873
- dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
885
+ dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
874
886
  zs = zs + dzs0 + dzs
875
887
  else:
876
888
  zs = zs + dzs0
@@ -884,10 +896,11 @@ class DensityFlow(nn.Module):
884
896
  library_sizes = library_sizes.reshape(-1,1)
885
897
 
886
898
  counts = self.get_counts(zs, library_sizes=library_sizes)
899
+ log_mu = self.get_log_mu(zs)
887
900
 
888
- return counts, zs
901
+ return counts, log_mu
889
902
 
890
- def _cell_response(self, zs, perturb_idx, perturb):
903
+ def _cell_shift(self, zs, perturb_idx, perturb):
891
904
  #zns,_ = self.encoder_zn(xs)
892
905
  #zns,_ = self._get_basal_embedding(xs)
893
906
  zns = zs
@@ -904,7 +917,7 @@ class DensityFlow(nn.Module):
904
917
 
905
918
  return ms
906
919
 
907
- def get_cell_response(self,
920
+ def get_cell_shift(self,
908
921
  zs,
909
922
  perturb_idx,
910
923
  perturb_us,
@@ -922,46 +935,43 @@ class DensityFlow(nn.Module):
922
935
  Z = []
923
936
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
924
937
  for Z_batch, P_batch, _ in dataloader:
925
- zns = self._cell_response(Z_batch, perturb_idx, P_batch)
938
+ zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
926
939
  Z.append(tensor_to_numpy(zns))
927
940
  pbar.update(1)
928
941
 
929
942
  Z = np.concatenate(Z)
930
943
  return Z
931
944
 
932
- def _get_expression_response(self, delta_zs):
933
- return self.decoder_concentrate(delta_zs)
945
+ def _log_mu(self, zs):
946
+ return self.decoder_log_mu(zs)
934
947
 
935
- def get_expression_response(self,
936
- delta_zs,
937
- batch_size: int = 1024):
948
+ def get_log_mu(self, zs, batch_size: int = 1024):
938
949
  """
939
950
  Return cells' changes in the feature space induced by specific perturbation of a factor
940
951
 
941
952
  """
942
- delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
943
- dataset = CustomDataset(delta_zs)
953
+ zs = convert_to_tensor(zs, device=self.get_device())
954
+ dataset = CustomDataset(zs)
944
955
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
945
956
 
946
957
  R = []
947
958
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
948
- for delta_Z_batch, _ in dataloader:
949
- r = self._get_expression_response(delta_Z_batch)
959
+ for Z_batch, _ in dataloader:
960
+ r = self._log_mu(Z_batch)
950
961
  R.append(tensor_to_numpy(r))
951
962
  pbar.update(1)
952
963
 
953
964
  R = np.concatenate(R)
954
965
  return R
955
966
 
956
- def _count(self, concentrate, library_size=None):
967
+ def _count(self, log_mu, library_size=None):
957
968
  if self.loss_func == 'bernoulli':
958
- #counts = self.sigmoid(concentrate)
959
- counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
969
+ counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
960
970
  elif self.loss_func == 'multinomial':
961
- theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
971
+ theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
962
972
  counts = theta * library_size
963
973
  else:
964
- rate = concentrate.exp()
974
+ rate = log_mu.exp()
965
975
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
966
976
  counts = theta * library_size
967
977
  return counts
@@ -983,8 +993,8 @@ class DensityFlow(nn.Module):
983
993
  E = []
984
994
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
985
995
  for Z_batch, L_batch, _ in dataloader:
986
- concentrate = self._get_expression_response(Z_batch)
987
- counts = self._count(concentrate, L_batch)
996
+ log_mu = self._log_mu(Z_batch)
997
+ counts = self._count(log_mu, L_batch)
988
998
  E.append(tensor_to_numpy(counts))
989
999
  pbar.update(1)
990
1000
 
@@ -1357,7 +1367,7 @@ def main():
1357
1367
  df = DensityFlow(
1358
1368
  input_size=input_size,
1359
1369
  cell_factor_size=cell_factor_size,
1360
- inverse_dispersion=args.inverse_dispersion,
1370
+ dispersion=args.dispersion,
1361
1371
  z_dim=args.z_dim,
1362
1372
  hidden_layers=args.hidden_layers,
1363
1373
  hidden_layer_activation=args.hidden_layer_activation,