SURE-tools 2.1.87__py3-none-any.whl → 2.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

@@ -10,7 +10,7 @@ from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_pr
10
10
  from torch.distributions import constraints
11
11
  from torch.distributions.transforms import SoftmaxTransform
12
12
 
13
- from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP
13
+ from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP2
14
14
  from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
15
15
 
16
16
 
@@ -54,7 +54,7 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class PerturbFlow(nn.Module):
57
+ class PerturbE(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
60
  codebook_size: int = 200,
@@ -62,10 +62,10 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
66
66
  inverse_dispersion: float = 10.0,
67
67
  use_zeroinflate: bool = False,
68
- hidden_layers: list = [300],
68
+ hidden_layers: list = [500],
69
69
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
70
70
  nn_dropout: float = 0.1,
71
71
  post_layer_fct: list = ['layernorm'],
@@ -73,8 +73,6 @@ class PerturbFlow(nn.Module):
73
73
  config_enum: str = 'parallel',
74
74
  use_cuda: bool = True,
75
75
  seed: int = 42,
76
- zero_bias: bool|list = True,
77
- enumrate: bool = False,
78
76
  dtype = torch.float32, # type: ignore
79
77
  ):
80
78
  super().__init__()
@@ -98,12 +96,6 @@ class PerturbFlow(nn.Module):
98
96
  self.post_layer_fct = post_layer_fct
99
97
  self.post_act_fct = post_act_fct
100
98
  self.hidden_layer_activation = hidden_layer_activation
101
- if type(zero_bias) == list:
102
- self.use_bias = [not x for x in zero_bias]
103
- else:
104
- self.use_bias = [not zero_bias] * self.cell_factor_size
105
- #self.use_bias = not zero_bias
106
- self.enumrate = enumrate
107
99
 
108
100
  self.codebook_weights = None
109
101
 
@@ -202,29 +194,14 @@ class PerturbFlow(nn.Module):
202
194
  )
203
195
 
204
196
  if self.cell_factor_size>0:
205
- self.cell_factor_effect = nn.ModuleList()
206
- for i in np.arange(self.cell_factor_size):
207
- if self.use_bias[i]:
208
- self.cell_factor_effect.append(MLP(
209
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
197
+ self.cell_factor_effect = ZeroBiasMLP2(
198
+ [self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
210
199
  activation=activate_fct,
211
200
  output_activation=None,
212
201
  post_layer_fct=post_layer_fct,
213
202
  post_act_fct=post_act_fct,
214
203
  allow_broadcast=self.allow_broadcast,
215
204
  use_cuda=self.use_cuda,
216
- )
217
- )
218
- else:
219
- self.cell_factor_effect.append(ZeroBiasMLP(
220
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
221
- activation=activate_fct,
222
- output_activation=None,
223
- post_layer_fct=post_layer_fct,
224
- post_act_fct=post_act_fct,
225
- allow_broadcast=self.allow_broadcast,
226
- use_cuda=self.use_cuda,
227
- )
228
205
  )
229
206
 
230
207
  self.decoder_concentrate = MLP(
@@ -310,7 +287,7 @@ class PerturbFlow(nn.Module):
310
287
  return xs
311
288
 
312
289
  def model1(self, xs):
313
- pyro.module('PerturbFlow', self)
290
+ pyro.module('PerturbE', self)
314
291
 
315
292
  eps = torch.finfo(xs.dtype).eps
316
293
  batch_size = xs.size(0)
@@ -372,7 +349,8 @@ class PerturbFlow(nn.Module):
372
349
  else:
373
350
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
374
351
  elif self.loss_func == 'multinomial':
375
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
352
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
353
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
376
354
  elif self.loss_func == 'bernoulli':
377
355
  if self.use_zeroinflate:
378
356
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -389,7 +367,7 @@ class PerturbFlow(nn.Module):
389
367
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
390
368
 
391
369
  def model2(self, xs, us=None):
392
- pyro.module('PerturbFlow', self)
370
+ pyro.module('PerturbE', self)
393
371
 
394
372
  eps = torch.finfo(xs.dtype).eps
395
373
  batch_size = xs.size(0)
@@ -431,12 +409,7 @@ class PerturbFlow(nn.Module):
431
409
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
432
410
 
433
411
  if self.cell_factor_size>0:
434
- if self.enumrate:
435
- idx = torch.argmax(ns, dim=1)
436
- zn_loc = acs_loc[idx]
437
- zus = self._total_effects(zn_loc, us)
438
- else:
439
- zus = self._total_effects(zns, us)
412
+ zus = self._perturb_effects(us)
440
413
  zs = zns+zus
441
414
  else:
442
415
  zs = zns
@@ -461,7 +434,8 @@ class PerturbFlow(nn.Module):
461
434
  else:
462
435
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
463
436
  elif self.loss_func == 'multinomial':
464
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
437
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
438
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
465
439
  elif self.loss_func == 'bernoulli':
466
440
  if self.use_zeroinflate:
467
441
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -478,7 +452,7 @@ class PerturbFlow(nn.Module):
478
452
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
479
453
 
480
454
  def model3(self, xs, ys, embeds=None):
481
- pyro.module('PerturbFlow', self)
455
+ pyro.module('PerturbE', self)
482
456
 
483
457
  eps = torch.finfo(xs.dtype).eps
484
458
  batch_size = xs.size(0)
@@ -557,7 +531,8 @@ class PerturbFlow(nn.Module):
557
531
  else:
558
532
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
559
533
  elif self.loss_func == 'multinomial':
560
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
534
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
535
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
561
536
  elif self.loss_func == 'bernoulli':
562
537
  if self.use_zeroinflate:
563
538
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -574,7 +549,7 @@ class PerturbFlow(nn.Module):
574
549
  zns = embeds
575
550
 
576
551
  def model4(self, xs, us, ys, embeds=None):
577
- pyro.module('PerturbFlow', self)
552
+ pyro.module('PerturbE', self)
578
553
 
579
554
  eps = torch.finfo(xs.dtype).eps
580
555
  batch_size = xs.size(0)
@@ -638,12 +613,7 @@ class PerturbFlow(nn.Module):
638
613
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
639
614
  # else:
640
615
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
641
- if self.enumrate:
642
- idx = torch.argmax(ns, dim=1)
643
- zn_loc = acs_loc[idx]
644
- zus = self._total_effects(zn_loc, us)
645
- else:
646
- zus = self._total_effects(zns, us)
616
+ zus = self._perturb_effects(us)
647
617
  zs = zns+zus
648
618
  else:
649
619
  zs = zns
@@ -668,7 +638,8 @@ class PerturbFlow(nn.Module):
668
638
  else:
669
639
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
670
640
  elif self.loss_func == 'multinomial':
671
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
641
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
642
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
672
643
  elif self.loss_func == 'bernoulli':
673
644
  if self.use_zeroinflate:
674
645
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -684,13 +655,8 @@ class PerturbFlow(nn.Module):
684
655
  else:
685
656
  zns = embeds
686
657
 
687
- def _total_effects(self, zns, us):
688
- zus = None
689
- for i in np.arange(self.cell_factor_size):
690
- if i==0:
691
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
692
- else:
693
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
658
+ def _perturb_effects(self, us):
659
+ zus = self._cell_response(us)
694
660
  return zus
695
661
 
696
662
  def _get_codebook_identity(self):
@@ -708,7 +674,7 @@ class PerturbFlow(nn.Module):
708
674
  """
709
675
  Return the mean part of metacell codebook
710
676
  """
711
- cb = self._get_metacell_coordinates()
677
+ cb = self._get_codebook()
712
678
  cb = tensor_to_numpy(cb)
713
679
  return cb
714
680
 
@@ -822,23 +788,13 @@ class PerturbFlow(nn.Module):
822
788
  A = np.concatenate(A)
823
789
  return A
824
790
 
825
- def predict(self, xs, us, perturbs_predict:list, perturbs_reference:list, library_sizes=None):
791
+ def predict(self, xs, perturbs_us, library_sizes=None):
826
792
  perturbs_reference = np.array(perturbs_reference)
827
793
 
828
794
  # basal embedding
829
795
  zs = self.get_basal_embedding(xs)
830
- for pert in perturbs_predict:
831
- pert_idx = int(np.where(perturbs_reference==pert)[0])
832
- us_i = us[:,pert_idx].reshape(-1,1)
833
-
834
- # factor effect of xs
835
- dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
836
-
837
- # perturbation effect
838
- ps = np.ones_like(us_i)
839
- dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
840
-
841
- zs = zs + dzs0 + dzs
796
+ dzs = self.get_cell_response(perturbs_us)
797
+ zs = zs + dzs
842
798
 
843
799
  if library_sizes is None:
844
800
  library_sizes = np.sum(xs, axis=1, keepdims=True)
@@ -852,47 +808,32 @@ class PerturbFlow(nn.Module):
852
808
 
853
809
  return counts, zs
854
810
 
855
- def _cell_response(self, xs, factor_idx, perturb):
856
- #zns,_ = self.encoder_zn(xs)
857
- zns,_ = self._get_basal_embedding(xs)
858
- if perturb.ndim==2:
859
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
860
- else:
861
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
862
-
811
+ def _cell_response(self, perturb):
812
+ ms = self.cell_factor_effect(perturb)
863
813
  return ms
864
814
 
865
815
  def get_cell_response(self,
866
- xs,
867
- factor_idx,
868
- perturb,
816
+ perturb_us,
869
817
  batch_size: int = 1024):
870
818
  """
871
819
  Return cells' changes in the latent space induced by specific perturbation of a factor
872
820
 
873
821
  """
874
- xs = self.preprocess(xs)
875
- xs = convert_to_tensor(xs, device=self.get_device())
876
- ps = convert_to_tensor(perturb, device=self.get_device())
877
- dataset = CustomDataset2(xs,ps)
822
+ #xs = self.preprocess(xs)
823
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
824
+ dataset = CustomDataset(ps)
878
825
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
879
826
 
880
827
  Z = []
881
828
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
882
- for X_batch, P_batch, _ in dataloader:
883
- zns = self._cell_response(X_batch, factor_idx, P_batch)
829
+ for P_batch, _ in dataloader:
830
+ zns = self._cell_response(P_batch)
884
831
  Z.append(tensor_to_numpy(zns))
885
832
  pbar.update(1)
886
833
 
887
834
  Z = np.concatenate(Z)
888
835
  return Z
889
836
 
890
- def get_metacell_response(self, factor_idx, perturb):
891
- zs = self._get_codebook()
892
- ps = convert_to_tensor(perturb, device=self.get_device())
893
- ms = self.cell_factor_effect[factor_idx]([zs,ps])
894
- return tensor_to_numpy(ms)
895
-
896
837
  def _get_expression_response(self, delta_zs):
897
838
  return self.decoder_concentrate(delta_zs)
898
839
 
@@ -917,36 +858,28 @@ class PerturbFlow(nn.Module):
917
858
  R = np.concatenate(R)
918
859
  return R
919
860
 
920
- def _count(self,concentrate, library_size=None):
861
+ def _count(self, concentrate, library_size=None):
921
862
  if self.loss_func == 'bernoulli':
922
863
  #counts = self.sigmoid(concentrate)
923
864
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
865
+ elif self.loss_func == 'multinomial':
866
+ theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
867
+ counts = theta * library_size
924
868
  else:
925
869
  rate = concentrate.exp()
926
870
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
927
871
  counts = theta * library_size
928
- #counts = dist.Poisson(rate=rate).to_event(1).mean
929
- return counts
930
-
931
- def _count_sample(self,concentrate):
932
- if self.loss_func == 'bernoulli':
933
- logits = concentrate
934
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
935
- else:
936
- counts = self._count(concentrate=concentrate)
937
- counts = dist.Poisson(rate=counts).to_event(1).sample()
938
872
  return counts
939
873
 
940
874
  def get_counts(self, zs, library_sizes,
941
- batch_size: int = 1024,
942
- use_sampler: bool = False):
875
+ batch_size: int = 1024):
943
876
 
944
877
  zs = convert_to_tensor(zs, device=self.get_device())
945
878
 
946
879
  if type(library_sizes) == list:
947
- library_sizes = np.array(library_sizes).view(-1,1)
880
+ library_sizes = np.array(library_sizes).reshape(-1,1)
948
881
  elif len(library_sizes.shape)==1:
949
- library_sizes = library_sizes.view(-1,1)
882
+ library_sizes = library_sizes.reshape(-1,1)
950
883
  ls = convert_to_tensor(library_sizes, device=self.get_device())
951
884
 
952
885
  dataset = CustomDataset2(zs,ls)
@@ -956,10 +889,7 @@ class PerturbFlow(nn.Module):
956
889
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
957
890
  for Z_batch, L_batch, _ in dataloader:
958
891
  concentrate = self._get_expression_response(Z_batch)
959
- if use_sampler:
960
- counts = self._count_sample(concentrate)
961
- else:
962
- counts = self._count(concentrate, L_batch)
892
+ counts = self._count(concentrate, L_batch)
963
893
  E.append(tensor_to_numpy(counts))
964
894
  pbar.update(1)
965
895
 
@@ -982,7 +912,7 @@ class PerturbFlow(nn.Module):
982
912
  us = None,
983
913
  ys = None,
984
914
  zs = None,
985
- num_epochs: int = 200,
915
+ num_epochs: int = 500,
986
916
  learning_rate: float = 0.0001,
987
917
  batch_size: int = 256,
988
918
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -993,7 +923,7 @@ class PerturbFlow(nn.Module):
993
923
  threshold: int = 0,
994
924
  use_jax: bool = True):
995
925
  """
996
- Train the PerturbFlow model.
926
+ Train the PerturbE model.
997
927
 
998
928
  Parameters
999
929
  ----------
@@ -1019,7 +949,7 @@ class PerturbFlow(nn.Module):
1019
949
  Parameter for optimization.
1020
950
  use_jax
1021
951
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
1022
- the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
952
+ the Python script or Jupyter notebook. It is OK if it is used when runing PerturbE in the shell command.
1023
953
  """
1024
954
  xs = self.preprocess(xs, threshold=threshold)
1025
955
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1137,12 +1067,12 @@ class PerturbFlow(nn.Module):
1137
1067
 
1138
1068
 
1139
1069
  EXAMPLE_RUN = (
1140
- "example run: PerturbFlow --help"
1070
+ "example run: PerturbE --help"
1141
1071
  )
1142
1072
 
1143
1073
  def parse_args():
1144
1074
  parser = argparse.ArgumentParser(
1145
- description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1075
+ description="PerturbE\n{}".format(EXAMPLE_RUN))
1146
1076
 
1147
1077
  parser.add_argument(
1148
1078
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1329,7 +1259,7 @@ def main():
1329
1259
  cell_factor_size = 0 if us is None else us.shape[1]
1330
1260
 
1331
1261
  ###########################################
1332
- perturbflow = PerturbFlow(
1262
+ perturbe = PerturbE(
1333
1263
  input_size=input_size,
1334
1264
  cell_factor_size=cell_factor_size,
1335
1265
  inverse_dispersion=args.inverse_dispersion,
@@ -1348,7 +1278,7 @@ def main():
1348
1278
  dtype=dtype,
1349
1279
  )
1350
1280
 
1351
- perturbflow.fit(xs, us=us,
1281
+ perturbe.fit(xs, us=us,
1352
1282
  num_epochs=args.num_epochs,
1353
1283
  learning_rate=args.learning_rate,
1354
1284
  batch_size=args.batch_size,
@@ -1360,12 +1290,11 @@ def main():
1360
1290
 
1361
1291
  if args.save_model is not None:
1362
1292
  if args.save_model.endswith('gz'):
1363
- PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1293
+ PerturbE.save_model(perturbe, args.save_model, compression=True)
1364
1294
  else:
1365
- PerturbFlow.save_model(perturbflow, args.save_model)
1295
+ PerturbE.save_model(perturbe, args.save_model)
1366
1296
 
1367
1297
 
1368
1298
 
1369
1299
  if __name__ == "__main__":
1370
-
1371
1300
  main()
SURE/SURE.py CHANGED
@@ -99,17 +99,17 @@ class SURE(nn.Module):
99
99
  cell_factor_size: int = 0,
100
100
  supervised_mode: bool = False,
101
101
  z_dim: int = 10,
102
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
103
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
102
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
103
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
104
104
  inverse_dispersion: float = 10.0,
105
105
  use_zeroinflate: bool = True,
106
- hidden_layers: list = [300],
106
+ hidden_layers: list = [500],
107
107
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
108
108
  nn_dropout: float = 0.1,
109
109
  post_layer_fct: list = ['layernorm'],
110
110
  post_act_fct: list = None,
111
111
  config_enum: str = 'parallel',
112
- use_cuda: bool = False,
112
+ use_cuda: bool = True,
113
113
  seed: int = 42,
114
114
  dtype = torch.float32, # type: ignore
115
115
  ):
@@ -817,7 +817,7 @@ class SURE(nn.Module):
817
817
  us = None,
818
818
  ys = None,
819
819
  zs = None,
820
- num_epochs: int = 200,
820
+ num_epochs: int = 500,
821
821
  learning_rate: float = 0.0001,
822
822
  batch_size: int = 256,
823
823
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -826,7 +826,7 @@ class SURE(nn.Module):
826
826
  decay_rate: float = 0.9,
827
827
  config_enum: str = 'parallel',
828
828
  threshold: int = 0,
829
- use_jax: bool = False):
829
+ use_jax: bool = True):
830
830
  """
831
831
  Train the SURE model.
832
832