SURE-tools 2.2.2__py3-none-any.whl → 2.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. See the registry's advisory for this release for more details.

@@ -10,7 +10,7 @@ from torch.distributions.utils import logits_to_probs, probs_to_logits, clamp_pr
10
10
  from torch.distributions import constraints
11
11
  from torch.distributions.transforms import SoftmaxTransform
12
12
 
13
- from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP
13
+ from .utils.custom_mlp import MLP, Exp, ZeroBiasMLP2
14
14
  from .utils.utils import CustomDataset, CustomDataset2, CustomDataset4, tensor_to_numpy, convert_to_tensor
15
15
 
16
16
 
@@ -54,7 +54,7 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class PerturbFlow(nn.Module):
57
+ class PerturbE(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
60
  codebook_size: int = 200,
@@ -62,10 +62,10 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
66
66
  inverse_dispersion: float = 10.0,
67
67
  use_zeroinflate: bool = False,
68
- hidden_layers: list = [300],
68
+ hidden_layers: list = [500],
69
69
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
70
70
  nn_dropout: float = 0.1,
71
71
  post_layer_fct: list = ['layernorm'],
@@ -73,7 +73,6 @@ class PerturbFlow(nn.Module):
73
73
  config_enum: str = 'parallel',
74
74
  use_cuda: bool = True,
75
75
  seed: int = 42,
76
- zero_bias: bool|list = True,
77
76
  dtype = torch.float32, # type: ignore
78
77
  ):
79
78
  super().__init__()
@@ -97,11 +96,6 @@ class PerturbFlow(nn.Module):
97
96
  self.post_layer_fct = post_layer_fct
98
97
  self.post_act_fct = post_act_fct
99
98
  self.hidden_layer_activation = hidden_layer_activation
100
- if type(zero_bias) == list:
101
- self.use_bias = [not x for x in zero_bias]
102
- else:
103
- self.use_bias = [not zero_bias] * self.cell_factor_size
104
- #self.use_bias = not zero_bias
105
99
 
106
100
  self.codebook_weights = None
107
101
 
@@ -200,29 +194,14 @@ class PerturbFlow(nn.Module):
200
194
  )
201
195
 
202
196
  if self.cell_factor_size>0:
203
- self.cell_factor_effect = nn.ModuleList()
204
- for i in np.arange(self.cell_factor_size):
205
- if self.use_bias[i]:
206
- self.cell_factor_effect.append(MLP(
207
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
208
- activation=activate_fct,
209
- output_activation=None,
210
- post_layer_fct=post_layer_fct,
211
- post_act_fct=post_act_fct,
212
- allow_broadcast=self.allow_broadcast,
213
- use_cuda=self.use_cuda,
214
- )
215
- )
216
- else:
217
- self.cell_factor_effect.append(ZeroBiasMLP(
218
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
197
+ self.cell_factor_effect = ZeroBiasMLP2(
198
+ [self.cell_factor_size] + self.decoder_hidden_layers + [self.latent_dim],
219
199
  activation=activate_fct,
220
200
  output_activation=None,
221
201
  post_layer_fct=post_layer_fct,
222
202
  post_act_fct=post_act_fct,
223
203
  allow_broadcast=self.allow_broadcast,
224
204
  use_cuda=self.use_cuda,
225
- )
226
205
  )
227
206
 
228
207
  self.decoder_concentrate = MLP(
@@ -308,7 +287,7 @@ class PerturbFlow(nn.Module):
308
287
  return xs
309
288
 
310
289
  def model1(self, xs):
311
- pyro.module('PerturbFlow', self)
290
+ pyro.module('PerturbE', self)
312
291
 
313
292
  eps = torch.finfo(xs.dtype).eps
314
293
  batch_size = xs.size(0)
@@ -370,7 +349,8 @@ class PerturbFlow(nn.Module):
370
349
  else:
371
350
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
372
351
  elif self.loss_func == 'multinomial':
373
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
352
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
353
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
374
354
  elif self.loss_func == 'bernoulli':
375
355
  if self.use_zeroinflate:
376
356
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -387,7 +367,7 @@ class PerturbFlow(nn.Module):
387
367
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
388
368
 
389
369
  def model2(self, xs, us=None):
390
- pyro.module('PerturbFlow', self)
370
+ pyro.module('PerturbE', self)
391
371
 
392
372
  eps = torch.finfo(xs.dtype).eps
393
373
  batch_size = xs.size(0)
@@ -429,7 +409,7 @@ class PerturbFlow(nn.Module):
429
409
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
430
410
 
431
411
  if self.cell_factor_size>0:
432
- zus = self._total_effects(zns, us)
412
+ zus = self._perturb_effects(us)
433
413
  zs = zns+zus
434
414
  else:
435
415
  zs = zns
@@ -454,7 +434,8 @@ class PerturbFlow(nn.Module):
454
434
  else:
455
435
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
456
436
  elif self.loss_func == 'multinomial':
457
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
437
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
438
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
458
439
  elif self.loss_func == 'bernoulli':
459
440
  if self.use_zeroinflate:
460
441
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -471,7 +452,7 @@ class PerturbFlow(nn.Module):
471
452
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
472
453
 
473
454
  def model3(self, xs, ys, embeds=None):
474
- pyro.module('PerturbFlow', self)
455
+ pyro.module('PerturbE', self)
475
456
 
476
457
  eps = torch.finfo(xs.dtype).eps
477
458
  batch_size = xs.size(0)
@@ -550,7 +531,8 @@ class PerturbFlow(nn.Module):
550
531
  else:
551
532
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
552
533
  elif self.loss_func == 'multinomial':
553
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
534
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
535
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
554
536
  elif self.loss_func == 'bernoulli':
555
537
  if self.use_zeroinflate:
556
538
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -567,7 +549,7 @@ class PerturbFlow(nn.Module):
567
549
  zns = embeds
568
550
 
569
551
  def model4(self, xs, us, ys, embeds=None):
570
- pyro.module('PerturbFlow', self)
552
+ pyro.module('PerturbE', self)
571
553
 
572
554
  eps = torch.finfo(xs.dtype).eps
573
555
  batch_size = xs.size(0)
@@ -631,7 +613,7 @@ class PerturbFlow(nn.Module):
631
613
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
632
614
  # else:
633
615
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
634
- zus = self._total_effects(zns, us)
616
+ zus = self._perturb_effects(us)
635
617
  zs = zns+zus
636
618
  else:
637
619
  zs = zns
@@ -656,7 +638,8 @@ class PerturbFlow(nn.Module):
656
638
  else:
657
639
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
658
640
  elif self.loss_func == 'multinomial':
659
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
641
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
642
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
660
643
  elif self.loss_func == 'bernoulli':
661
644
  if self.use_zeroinflate:
662
645
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -672,13 +655,8 @@ class PerturbFlow(nn.Module):
672
655
  else:
673
656
  zns = embeds
674
657
 
675
- def _total_effects(self, zns, us):
676
- zus = None
677
- for i in np.arange(self.cell_factor_size):
678
- if i==0:
679
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
680
- else:
681
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
658
+ def _perturb_effects(self, us):
659
+ zus = self._cell_response(us)
682
660
  return zus
683
661
 
684
662
  def _get_codebook_identity(self):
@@ -696,7 +674,7 @@ class PerturbFlow(nn.Module):
696
674
  """
697
675
  Return the mean part of metacell codebook
698
676
  """
699
- cb = self._get_metacell_coordinates()
677
+ cb = self._get_codebook()
700
678
  cb = tensor_to_numpy(cb)
701
679
  return cb
702
680
 
@@ -810,23 +788,13 @@ class PerturbFlow(nn.Module):
810
788
  A = np.concatenate(A)
811
789
  return A
812
790
 
813
- def predict(self, xs, us, perturbs_predict:list, perturbs_reference:list, library_sizes=None):
791
+ def predict(self, xs, perturbs_us, library_sizes=None):
814
792
  perturbs_reference = np.array(perturbs_reference)
815
793
 
816
794
  # basal embedding
817
795
  zs = self.get_basal_embedding(xs)
818
- for pert in perturbs_predict:
819
- pert_idx = int(np.where(perturbs_reference==pert)[0])
820
- us_i = us[:,pert_idx].reshape(-1,1)
821
-
822
- # factor effect of xs
823
- dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
824
-
825
- # perturbation effect
826
- ps = np.ones_like(us_i)
827
- dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
828
-
829
- zs = zs + dzs0 + dzs
796
+ dzs = self.get_cell_response(perturbs_us)
797
+ zs = zs + dzs
830
798
 
831
799
  if library_sizes is None:
832
800
  library_sizes = np.sum(xs, axis=1, keepdims=True)
@@ -840,47 +808,32 @@ class PerturbFlow(nn.Module):
840
808
 
841
809
  return counts, zs
842
810
 
843
- def _cell_response(self, xs, factor_idx, perturb):
844
- #zns,_ = self.encoder_zn(xs)
845
- zns,_ = self._get_basal_embedding(xs)
846
- if perturb.ndim==2:
847
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
848
- else:
849
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
850
-
811
+ def _cell_response(self, perturb):
812
+ ms = self.cell_factor_effect(perturb)
851
813
  return ms
852
814
 
853
815
  def get_cell_response(self,
854
- xs,
855
- factor_idx,
856
- perturb,
816
+ perturb_us,
857
817
  batch_size: int = 1024):
858
818
  """
859
819
  Return cells' changes in the latent space induced by specific perturbation of a factor
860
820
 
861
821
  """
862
- xs = self.preprocess(xs)
863
- xs = convert_to_tensor(xs, device=self.get_device())
864
- ps = convert_to_tensor(perturb, device=self.get_device())
865
- dataset = CustomDataset2(xs,ps)
822
+ #xs = self.preprocess(xs)
823
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
824
+ dataset = CustomDataset(ps)
866
825
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
867
826
 
868
827
  Z = []
869
828
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
870
- for X_batch, P_batch, _ in dataloader:
871
- zns = self._cell_response(X_batch, factor_idx, P_batch)
829
+ for P_batch, _ in dataloader:
830
+ zns = self._cell_response(P_batch)
872
831
  Z.append(tensor_to_numpy(zns))
873
832
  pbar.update(1)
874
833
 
875
834
  Z = np.concatenate(Z)
876
835
  return Z
877
836
 
878
- def get_metacell_response(self, factor_idx, perturb):
879
- zs = self._get_codebook()
880
- ps = convert_to_tensor(perturb, device=self.get_device())
881
- ms = self.cell_factor_effect[factor_idx]([zs,ps])
882
- return tensor_to_numpy(ms)
883
-
884
837
  def _get_expression_response(self, delta_zs):
885
838
  return self.decoder_concentrate(delta_zs)
886
839
 
@@ -905,36 +858,28 @@ class PerturbFlow(nn.Module):
905
858
  R = np.concatenate(R)
906
859
  return R
907
860
 
908
- def _count(self,concentrate, library_size=None):
861
+ def _count(self, concentrate, library_size=None):
909
862
  if self.loss_func == 'bernoulli':
910
863
  #counts = self.sigmoid(concentrate)
911
864
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
865
+ elif self.loss_func == 'multinomial':
866
+ theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
867
+ counts = theta * library_size
912
868
  else:
913
869
  rate = concentrate.exp()
914
870
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
915
871
  counts = theta * library_size
916
- #counts = dist.Poisson(rate=rate).to_event(1).mean
917
- return counts
918
-
919
- def _count_sample(self,concentrate):
920
- if self.loss_func == 'bernoulli':
921
- logits = concentrate
922
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
923
- else:
924
- counts = self._count(concentrate=concentrate)
925
- counts = dist.Poisson(rate=counts).to_event(1).sample()
926
872
  return counts
927
873
 
928
874
  def get_counts(self, zs, library_sizes,
929
- batch_size: int = 1024,
930
- use_sampler: bool = False):
875
+ batch_size: int = 1024):
931
876
 
932
877
  zs = convert_to_tensor(zs, device=self.get_device())
933
878
 
934
879
  if type(library_sizes) == list:
935
- library_sizes = np.array(library_sizes).view(-1,1)
880
+ library_sizes = np.array(library_sizes).reshape(-1,1)
936
881
  elif len(library_sizes.shape)==1:
937
- library_sizes = library_sizes.view(-1,1)
882
+ library_sizes = library_sizes.reshape(-1,1)
938
883
  ls = convert_to_tensor(library_sizes, device=self.get_device())
939
884
 
940
885
  dataset = CustomDataset2(zs,ls)
@@ -944,10 +889,7 @@ class PerturbFlow(nn.Module):
944
889
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
945
890
  for Z_batch, L_batch, _ in dataloader:
946
891
  concentrate = self._get_expression_response(Z_batch)
947
- if use_sampler:
948
- counts = self._count_sample(concentrate)
949
- else:
950
- counts = self._count(concentrate, L_batch)
892
+ counts = self._count(concentrate, L_batch)
951
893
  E.append(tensor_to_numpy(counts))
952
894
  pbar.update(1)
953
895
 
@@ -970,7 +912,7 @@ class PerturbFlow(nn.Module):
970
912
  us = None,
971
913
  ys = None,
972
914
  zs = None,
973
- num_epochs: int = 200,
915
+ num_epochs: int = 500,
974
916
  learning_rate: float = 0.0001,
975
917
  batch_size: int = 256,
976
918
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -981,7 +923,7 @@ class PerturbFlow(nn.Module):
981
923
  threshold: int = 0,
982
924
  use_jax: bool = True):
983
925
  """
984
- Train the PerturbFlow model.
926
+ Train the PerturbE model.
985
927
 
986
928
  Parameters
987
929
  ----------
@@ -1007,7 +949,7 @@ class PerturbFlow(nn.Module):
1007
949
  Parameter for optimization.
1008
950
  use_jax
1009
951
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
1010
- the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
952
+ the Python script or Jupyter notebook. It is OK if it is used when runing PerturbE in the shell command.
1011
953
  """
1012
954
  xs = self.preprocess(xs, threshold=threshold)
1013
955
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1125,12 +1067,12 @@ class PerturbFlow(nn.Module):
1125
1067
 
1126
1068
 
1127
1069
  EXAMPLE_RUN = (
1128
- "example run: PerturbFlow --help"
1070
+ "example run: PerturbE --help"
1129
1071
  )
1130
1072
 
1131
1073
  def parse_args():
1132
1074
  parser = argparse.ArgumentParser(
1133
- description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1075
+ description="PerturbE\n{}".format(EXAMPLE_RUN))
1134
1076
 
1135
1077
  parser.add_argument(
1136
1078
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1317,7 +1259,7 @@ def main():
1317
1259
  cell_factor_size = 0 if us is None else us.shape[1]
1318
1260
 
1319
1261
  ###########################################
1320
- perturbflow = PerturbFlow(
1262
+ perturbe = PerturbE(
1321
1263
  input_size=input_size,
1322
1264
  cell_factor_size=cell_factor_size,
1323
1265
  inverse_dispersion=args.inverse_dispersion,
@@ -1336,7 +1278,7 @@ def main():
1336
1278
  dtype=dtype,
1337
1279
  )
1338
1280
 
1339
- perturbflow.fit(xs, us=us,
1281
+ perturbe.fit(xs, us=us,
1340
1282
  num_epochs=args.num_epochs,
1341
1283
  learning_rate=args.learning_rate,
1342
1284
  batch_size=args.batch_size,
@@ -1348,12 +1290,11 @@ def main():
1348
1290
 
1349
1291
  if args.save_model is not None:
1350
1292
  if args.save_model.endswith('gz'):
1351
- PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1293
+ PerturbE.save_model(perturbe, args.save_model, compression=True)
1352
1294
  else:
1353
- PerturbFlow.save_model(perturbflow, args.save_model)
1295
+ PerturbE.save_model(perturbe, args.save_model)
1354
1296
 
1355
1297
 
1356
1298
 
1357
1299
  if __name__ == "__main__":
1358
-
1359
1300
  main()