SURE-tools 2.2.24__py3-none-any.whl → 2.4.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/DensityFlow.py CHANGED
@@ -57,16 +57,16 @@ def set_random_seed(seed):
57
57
  class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
- codebook_size: int = 200,
60
+ codebook_size: int = 100,
61
61
  cell_factor_size: int = 0,
62
62
  turn_off_cell_specific: bool = False,
63
63
  supervised_mode: bool = False,
64
- z_dim: int = 10,
65
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
67
- inverse_dispersion: float = 10.0,
64
+ z_dim: int = 50,
65
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'studentt',
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
67
+ dispersion: float = 8.0,
68
68
  use_zeroinflate: bool = False,
69
- hidden_layers: list = [500],
69
+ hidden_layers: list = [1024],
70
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
71
71
  nn_dropout: float = 0.1,
72
72
  post_layer_fct: list = ['layernorm'],
@@ -81,10 +81,11 @@ class DensityFlow(nn.Module):
81
81
 
82
82
  self.input_size = input_size
83
83
  self.cell_factor_size = cell_factor_size
84
- self.inverse_dispersion = inverse_dispersion
84
+ self.dispersion = dispersion
85
85
  self.latent_dim = z_dim
86
86
  self.hidden_layers = hidden_layers
87
87
  self.decoder_hidden_layers = hidden_layers[::-1]
88
+ self.config_enum = config_enum
88
89
  self.allow_broadcast = config_enum == 'parallel'
89
90
  self.use_cuda = use_cuda
90
91
  self.loss_func = loss_func
@@ -107,8 +108,16 @@ class DensityFlow(nn.Module):
107
108
 
108
109
  self.codebook_weights = None
109
110
 
111
+ self.seed = seed
110
112
  set_random_seed(seed)
111
113
  self.setup_networks()
114
+
115
+ print(f"🧬 DensityFlow Initialized:")
116
+ print(f" - Latent Dimension: {self.latent_dim}")
117
+ print(f" - Gene Dimension: {self.input_size}")
118
+ print(f" - Hidden Dimensions: {self.hidden_layers}")
119
+ print(f" - Device: {self.get_device()}")
120
+ print(f" - Parameters: {sum(p.numel() for p in self.parameters()):,}")
112
121
 
113
122
  def setup_networks(self):
114
123
  latent_dim = self.latent_dim
@@ -251,7 +260,7 @@ class DensityFlow(nn.Module):
251
260
  )
252
261
  )
253
262
 
254
- self.decoder_concentrate = MLP(
263
+ self.decoder_log_mu = MLP(
255
264
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
256
265
  activation=activate_fct,
257
266
  output_activation=None,
@@ -341,8 +350,8 @@ class DensityFlow(nn.Module):
341
350
  self.options = dict(dtype=xs.dtype, device=xs.device)
342
351
 
343
352
  if self.loss_func=='negbinomial':
344
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
345
- xs.new_ones(self.input_size), constraint=constraints.positive)
353
+ dispersion = pyro.param("dispersion", self.dispersion *
354
+ xs.new_ones(self.input_size), constraint=constraints.positive)
346
355
 
347
356
  if self.use_zeroinflate:
348
357
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -376,20 +385,25 @@ class DensityFlow(nn.Module):
376
385
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
377
386
 
378
387
  zs = zns
379
- concentrate = self.decoder_concentrate(zs)
388
+ log_mu = self.decoder_log_mu(zs)
380
389
  if self.loss_func in ['bernoulli']:
381
- log_theta = concentrate
390
+ log_theta = log_mu
391
+ elif self.loss_func == 'negbinomial':
392
+ mu = log_mu.exp()
382
393
  else:
383
- rate = concentrate.exp()
394
+ rate = log_mu.exp()
384
395
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
385
396
  if self.loss_func == 'poisson':
386
397
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
387
398
 
388
399
  if self.loss_func == 'negbinomial':
400
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
389
401
  if self.use_zeroinflate:
390
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
402
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
403
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
391
404
  else:
392
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
405
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion,
406
+ logits=logits).to_event(1), obs=xs)
393
407
  elif self.loss_func == 'poisson':
394
408
  if self.use_zeroinflate:
395
409
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -420,8 +434,8 @@ class DensityFlow(nn.Module):
420
434
  self.options = dict(dtype=xs.dtype, device=xs.device)
421
435
 
422
436
  if self.loss_func=='negbinomial':
423
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
424
- xs.new_ones(self.input_size), constraint=constraints.positive)
437
+ dispersion = pyro.param("dispersion", self.dispersion *
438
+ xs.new_ones(self.input_size), constraint=constraints.positive)
425
439
 
426
440
  if self.use_zeroinflate:
427
441
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -460,20 +474,23 @@ class DensityFlow(nn.Module):
460
474
  else:
461
475
  zs = zns
462
476
 
463
- concentrate = self.decoder_concentrate(zs)
477
+ log_mu = self.decoder_log_mu(zs)
464
478
  if self.loss_func in ['bernoulli']:
465
- log_theta = concentrate
479
+ log_theta = log_mu
480
+ elif self.loss_func == 'negbinomial':
481
+ mu = log_mu.exp()
466
482
  else:
467
- rate = concentrate.exp()
483
+ rate = log_mu.exp()
468
484
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
469
485
  if self.loss_func == 'poisson':
470
486
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
471
487
 
472
488
  if self.loss_func == 'negbinomial':
489
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
473
490
  if self.use_zeroinflate:
474
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
491
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion, logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
475
492
  else:
476
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
493
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
477
494
  elif self.loss_func == 'poisson':
478
495
  if self.use_zeroinflate:
479
496
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -504,8 +521,8 @@ class DensityFlow(nn.Module):
504
521
  self.options = dict(dtype=xs.dtype, device=xs.device)
505
522
 
506
523
  if self.loss_func=='negbinomial':
507
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
508
- xs.new_ones(self.input_size), constraint=constraints.positive)
524
+ dispersion = pyro.param("dispersion", self.dispersion *
525
+ xs.new_ones(self.input_size), constraint=constraints.positive)
509
526
 
510
527
  if self.use_zeroinflate:
511
528
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -556,20 +573,24 @@ class DensityFlow(nn.Module):
556
573
 
557
574
  zs = zns
558
575
 
559
- concentrate = self.decoder_concentrate(zs)
576
+ log_mu = self.decoder_log_mu(zs)
560
577
  if self.loss_func in ['bernoulli']:
561
- log_theta = concentrate
578
+ log_theta = log_mu
579
+ elif self.loss_func in ['negbinomial']:
580
+ mu = log_mu.exp()
562
581
  else:
563
- rate = concentrate.exp()
582
+ rate = log_mu.exp()
564
583
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
565
584
  if self.loss_func == 'poisson':
566
585
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
567
586
 
568
587
  if self.loss_func == 'negbinomial':
588
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
569
589
  if self.use_zeroinflate:
570
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
590
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
591
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
571
592
  else:
572
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
593
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
573
594
  elif self.loss_func == 'poisson':
574
595
  if self.use_zeroinflate:
575
596
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -600,8 +621,8 @@ class DensityFlow(nn.Module):
600
621
  self.options = dict(dtype=xs.dtype, device=xs.device)
601
622
 
602
623
  if self.loss_func=='negbinomial':
603
- total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
604
- xs.new_ones(self.input_size), constraint=constraints.positive)
624
+ dispersion = pyro.param("dispersion", self.dispersion *
625
+ xs.new_ones(self.input_size), constraint=constraints.positive)
605
626
 
606
627
  if self.use_zeroinflate:
607
628
  gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -662,20 +683,24 @@ class DensityFlow(nn.Module):
662
683
  else:
663
684
  zs = zns
664
685
 
665
- concentrate = self.decoder_concentrate(zs)
686
+ log_mu = self.decoder_log_mu(zs)
666
687
  if self.loss_func in ['bernoulli']:
667
- log_theta = concentrate
688
+ log_theta = log_mu
689
+ elif self.loss_func in ['negbinomial']:
690
+ mu = log_mu.exp()
668
691
  else:
669
- rate = concentrate.exp()
692
+ rate = log_mu.exp()
670
693
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
671
694
  if self.loss_func == 'poisson':
672
695
  rate = theta * torch.sum(xs, dim=1, keepdim=True)
673
696
 
674
697
  if self.loss_func == 'negbinomial':
698
+ logits = (mu.log()-dispersion.log()).clamp(min=-15, max=15)
675
699
  if self.use_zeroinflate:
676
- pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
700
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=dispersion,
701
+ logits=logits),gate_logits=gate_logits).to_event(1), obs=xs)
677
702
  else:
678
- pyro.sample('x', dist.NegativeBinomial(total_count=total_count, probs=theta).to_event(1), obs=xs)
703
+ pyro.sample('x', dist.NegativeBinomial(total_count=dispersion, logits=logits).to_event(1), obs=xs)
679
704
  elif self.loss_func == 'poisson':
680
705
  if self.use_zeroinflate:
681
706
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Poisson(rate=rate),gate_logits=gate_logits).to_event(1), obs=xs.round())
@@ -706,13 +731,13 @@ class DensityFlow(nn.Module):
706
731
  # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
707
732
  #else:
708
733
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
709
- zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
734
+ zus = self._cell_shift(zns, i, us[:,i].reshape(-1,1))
710
735
  else:
711
736
  #if self.turn_off_cell_specific:
712
737
  # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
713
738
  #else:
714
739
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
715
- zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
740
+ zus = zus + self._cell_shift(zns, i, us[:,i].reshape(-1,1))
716
741
  return zus
717
742
 
718
743
  def _get_codebook_identity(self):
@@ -854,12 +879,12 @@ class DensityFlow(nn.Module):
854
879
  us_i = us[:,pert_idx].reshape(-1,1)
855
880
 
856
881
  # factor effect of xs
857
- dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
882
+ dzs0 = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=us_i)
858
883
 
859
884
  # perturbation effect
860
885
  ps = np.ones_like(us_i)
861
886
  if np.sum(np.abs(ps-us_i))>=1:
862
- dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
887
+ dzs = self.get_cell_shift(zs, perturb_idx=pert_idx, perturb_us=ps)
863
888
  zs = zs + dzs0 + dzs
864
889
  else:
865
890
  zs = zs + dzs0
@@ -873,10 +898,11 @@ class DensityFlow(nn.Module):
873
898
  library_sizes = library_sizes.reshape(-1,1)
874
899
 
875
900
  counts = self.get_counts(zs, library_sizes=library_sizes)
901
+ log_mu = self.get_log_mu(zs)
876
902
 
877
- return counts, zs
903
+ return counts, log_mu
878
904
 
879
- def _cell_response(self, zs, perturb_idx, perturb):
905
+ def _cell_shift(self, zs, perturb_idx, perturb):
880
906
  #zns,_ = self.encoder_zn(xs)
881
907
  #zns,_ = self._get_basal_embedding(xs)
882
908
  zns = zs
@@ -893,7 +919,7 @@ class DensityFlow(nn.Module):
893
919
 
894
920
  return ms
895
921
 
896
- def get_cell_response(self,
922
+ def get_cell_shift(self,
897
923
  zs,
898
924
  perturb_idx,
899
925
  perturb_us,
@@ -911,43 +937,43 @@ class DensityFlow(nn.Module):
911
937
  Z = []
912
938
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
913
939
  for Z_batch, P_batch, _ in dataloader:
914
- zns = self._cell_response(Z_batch, perturb_idx, P_batch)
940
+ zns = self._cell_shift(Z_batch, perturb_idx, P_batch)
915
941
  Z.append(tensor_to_numpy(zns))
916
942
  pbar.update(1)
917
943
 
918
944
  Z = np.concatenate(Z)
919
945
  return Z
920
946
 
921
- def _get_expression_response(self, delta_zs):
922
- return self.decoder_concentrate(delta_zs)
947
+ def _log_mu(self, zs):
948
+ return self.decoder_log_mu(zs)
923
949
 
924
- def get_expression_response(self,
925
- delta_zs,
926
- batch_size: int = 1024):
950
+ def get_log_mu(self, zs, batch_size: int = 1024):
927
951
  """
928
952
  Return cells' changes in the feature space induced by specific perturbation of a factor
929
953
 
930
954
  """
931
- delta_zs = convert_to_tensor(delta_zs, device=self.get_device())
932
- dataset = CustomDataset(delta_zs)
955
+ zs = convert_to_tensor(zs, device=self.get_device())
956
+ dataset = CustomDataset(zs)
933
957
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
934
958
 
935
959
  R = []
936
960
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
937
- for delta_Z_batch, _ in dataloader:
938
- r = self._get_expression_response(delta_Z_batch)
961
+ for Z_batch, _ in dataloader:
962
+ r = self._log_mu(Z_batch)
939
963
  R.append(tensor_to_numpy(r))
940
964
  pbar.update(1)
941
965
 
942
966
  R = np.concatenate(R)
943
967
  return R
944
968
 
945
- def _count(self, concentrate, library_size=None):
969
+ def _count(self, log_mu, library_size=None):
946
970
  if self.loss_func == 'bernoulli':
947
- #counts = self.sigmoid(concentrate)
948
- counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
971
+ counts = dist.Bernoulli(logits=log_mu).to_event(1).mean
972
+ elif self.loss_func == 'multinomial':
973
+ theta = dist.Multinomial(total_count=int(1e8), logits=log_mu).mean
974
+ counts = theta * library_size
949
975
  else:
950
- rate = concentrate.exp()
976
+ rate = log_mu.exp()
951
977
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
952
978
  counts = theta * library_size
953
979
  return counts
@@ -969,8 +995,8 @@ class DensityFlow(nn.Module):
969
995
  E = []
970
996
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
971
997
  for Z_batch, L_batch, _ in dataloader:
972
- concentrate = self._get_expression_response(Z_batch)
973
- counts = self._count(concentrate, L_batch)
998
+ log_mu = self._log_mu(Z_batch)
999
+ counts = self._count(log_mu, L_batch)
974
1000
  E.append(tensor_to_numpy(counts))
975
1001
  pbar.update(1)
976
1002
 
@@ -1116,7 +1142,7 @@ class DensityFlow(nn.Module):
1116
1142
  pbar.set_postfix({'loss': str_loss})
1117
1143
  pbar.update(1)
1118
1144
 
1119
- @classmethod
1145
+ '''@classmethod
1120
1146
  def save_model(cls, model, file_path, compression=False):
1121
1147
  """Save the model to the specified file path."""
1122
1148
  file_path = os.path.abspath(file_path)
@@ -1144,6 +1170,45 @@ class DensityFlow(nn.Module):
1144
1170
  with open(file_path, 'rb') as pickle_file:
1145
1171
  model = pickle.load(pickle_file)
1146
1172
 
1173
+ return model'''
1174
+
1175
+ def save_model(self, path):
1176
+ """Save model checkpoint"""
1177
+ torch.save({
1178
+ 'model_state_dict': self.state_dict(),
1179
+ 'model_config': {
1180
+ 'input_size': self.input_size,
1181
+ 'codebook_size': self.code_size,
1182
+ 'cell_factor_size': self.cell_factor_size,
1183
+ 'turn_off_cell_specific':self.turn_off_cell_specific,
1184
+ 'supervised_mode':self.supervised_mode,
1185
+ 'z_dim': self.latent_dim,
1186
+ 'z_dist': self.latent_dist,
1187
+ 'loss_func': self.loss_func,
1188
+ 'dispersion': self.dispersion,
1189
+ 'use_zeroinflate': self.use_zeroinflate,
1190
+ 'hidden_layers':self.hidden_layers,
1191
+ 'hidden_layer_activation':self.hidden_layer_activation,
1192
+ 'nn_dropout':self.nn_dropout,
1193
+ 'post_layer_fct':self.post_layer_fct,
1194
+ 'post_act_fct':self.post_act_fct,
1195
+ 'config_enum':self.config_enum,
1196
+ 'use_cuda':self.use_cuda,
1197
+ 'seed':self.seed,
1198
+ 'zero_bias':self.use_bias,
1199
+ 'dtype':self.dtype,
1200
+ }
1201
+ }, path)
1202
+
1203
+ @classmethod
1204
+ def load_model(cls, model_path: str):
1205
+ """Load pre-trained model"""
1206
+ checkpoint = torch.load(model_path)
1207
+ model = DensityFlow(**checkpoint.get('model_config'))
1208
+
1209
+ checkpoint = torch.load(model_path, map_location=model.get_device())
1210
+ model.load_state_dict(checkpoint['model_state_dict'])
1211
+
1147
1212
  return model
1148
1213
 
1149
1214
 
@@ -1340,10 +1405,10 @@ def main():
1340
1405
  cell_factor_size = 0 if us is None else us.shape[1]
1341
1406
 
1342
1407
  ###########################################
1343
- DensityFlow = DensityFlow(
1408
+ df = DensityFlow(
1344
1409
  input_size=input_size,
1345
1410
  cell_factor_size=cell_factor_size,
1346
- inverse_dispersion=args.inverse_dispersion,
1411
+ dispersion=args.dispersion,
1347
1412
  z_dim=args.z_dim,
1348
1413
  hidden_layers=args.hidden_layers,
1349
1414
  hidden_layer_activation=args.hidden_layer_activation,
@@ -1359,7 +1424,7 @@ def main():
1359
1424
  dtype=dtype,
1360
1425
  )
1361
1426
 
1362
- DensityFlow.fit(xs, us=us,
1427
+ df.fit(xs, us=us,
1363
1428
  num_epochs=args.num_epochs,
1364
1429
  learning_rate=args.learning_rate,
1365
1430
  batch_size=args.batch_size,
@@ -1371,9 +1436,9 @@ def main():
1371
1436
 
1372
1437
  if args.save_model is not None:
1373
1438
  if args.save_model.endswith('gz'):
1374
- DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
1439
+ DensityFlow.save_model(df, args.save_model, compression=True)
1375
1440
  else:
1376
- DensityFlow.save_model(DensityFlow, args.save_model)
1441
+ DensityFlow.save_model(df, args.save_model)
1377
1442
 
1378
1443
 
1379
1444