SURE-tools 2.2.2__py3-none-any.whl → 2.4.3__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

SURE/DensityFlow.py CHANGED
@@ -59,12 +59,13 @@ class DensityFlow(nn.Module):
59
59
  input_size: int,
60
60
  codebook_size: int = 200,
61
61
  cell_factor_size: int = 0,
62
+ turn_off_cell_specific: bool = False,
62
63
  supervised_mode: bool = False,
63
64
  z_dim: int = 10,
64
65
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
66
67
  inverse_dispersion: float = 10.0,
67
- use_zeroinflate: bool = True,
68
+ use_zeroinflate: bool = False,
68
69
  hidden_layers: list = [500],
69
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
70
71
  nn_dropout: float = 0.1,
@@ -102,6 +103,7 @@ class DensityFlow(nn.Module):
102
103
  else:
103
104
  self.use_bias = [not zero_bias] * self.cell_factor_size
104
105
  #self.use_bias = not zero_bias
106
+ self.turn_off_cell_specific = turn_off_cell_specific
105
107
 
106
108
  self.codebook_weights = None
107
109
 
@@ -203,27 +205,51 @@ class DensityFlow(nn.Module):
203
205
  self.cell_factor_effect = nn.ModuleList()
204
206
  for i in np.arange(self.cell_factor_size):
205
207
  if self.use_bias[i]:
206
- self.cell_factor_effect.append(MLP(
207
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
208
- activation=activate_fct,
209
- output_activation=None,
210
- post_layer_fct=post_layer_fct,
211
- post_act_fct=post_act_fct,
212
- allow_broadcast=self.allow_broadcast,
213
- use_cuda=self.use_cuda,
208
+ if self.turn_off_cell_specific:
209
+ self.cell_factor_effect.append(MLP(
210
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
211
+ activation=activate_fct,
212
+ output_activation=None,
213
+ post_layer_fct=post_layer_fct,
214
+ post_act_fct=post_act_fct,
215
+ allow_broadcast=self.allow_broadcast,
216
+ use_cuda=self.use_cuda,
217
+ )
218
+ )
219
+ else:
220
+ self.cell_factor_effect.append(MLP(
221
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
222
+ activation=activate_fct,
223
+ output_activation=None,
224
+ post_layer_fct=post_layer_fct,
225
+ post_act_fct=post_act_fct,
226
+ allow_broadcast=self.allow_broadcast,
227
+ use_cuda=self.use_cuda,
228
+ )
214
229
  )
215
- )
216
230
  else:
217
- self.cell_factor_effect.append(ZeroBiasMLP(
218
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
219
- activation=activate_fct,
220
- output_activation=None,
221
- post_layer_fct=post_layer_fct,
222
- post_act_fct=post_act_fct,
223
- allow_broadcast=self.allow_broadcast,
224
- use_cuda=self.use_cuda,
231
+ if self.turn_off_cell_specific:
232
+ self.cell_factor_effect.append(ZeroBiasMLP(
233
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
234
+ activation=activate_fct,
235
+ output_activation=None,
236
+ post_layer_fct=post_layer_fct,
237
+ post_act_fct=post_act_fct,
238
+ allow_broadcast=self.allow_broadcast,
239
+ use_cuda=self.use_cuda,
240
+ )
241
+ )
242
+ else:
243
+ self.cell_factor_effect.append(ZeroBiasMLP(
244
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
245
+ activation=activate_fct,
246
+ output_activation=None,
247
+ post_layer_fct=post_layer_fct,
248
+ post_act_fct=post_act_fct,
249
+ allow_broadcast=self.allow_broadcast,
250
+ use_cuda=self.use_cuda,
251
+ )
225
252
  )
226
- )
227
253
 
228
254
  self.decoder_concentrate = MLP(
229
255
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
@@ -370,7 +396,8 @@ class DensityFlow(nn.Module):
370
396
  else:
371
397
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
372
398
  elif self.loss_func == 'multinomial':
373
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
399
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
400
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
374
401
  elif self.loss_func == 'bernoulli':
375
402
  if self.use_zeroinflate:
376
403
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -454,7 +481,8 @@ class DensityFlow(nn.Module):
454
481
  else:
455
482
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
456
483
  elif self.loss_func == 'multinomial':
457
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
484
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
485
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
458
486
  elif self.loss_func == 'bernoulli':
459
487
  if self.use_zeroinflate:
460
488
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -550,7 +578,8 @@ class DensityFlow(nn.Module):
550
578
  else:
551
579
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
552
580
  elif self.loss_func == 'multinomial':
553
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
581
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
582
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
554
583
  elif self.loss_func == 'bernoulli':
555
584
  if self.use_zeroinflate:
556
585
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -656,7 +685,8 @@ class DensityFlow(nn.Module):
656
685
  else:
657
686
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
658
687
  elif self.loss_func == 'multinomial':
659
- pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
688
+ #pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
689
+ pyro.sample('x', dist.Multinomial(total_count=int(1e8), logits=concentrate), obs=xs)
660
690
  elif self.loss_func == 'bernoulli':
661
691
  if self.use_zeroinflate:
662
692
  pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
@@ -676,9 +706,17 @@ class DensityFlow(nn.Module):
676
706
  zus = None
677
707
  for i in np.arange(self.cell_factor_size):
678
708
  if i==0:
679
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
709
+ #if self.turn_off_cell_specific:
710
+ # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
711
+ #else:
712
+ # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
713
+ zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
680
714
  else:
681
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
715
+ #if self.turn_off_cell_specific:
716
+ # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
717
+ #else:
718
+ # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
719
+ zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
682
720
  return zus
683
721
 
684
722
  def _get_codebook_identity(self):
@@ -696,7 +734,7 @@ class DensityFlow(nn.Module):
696
734
  """
697
735
  Return the mean part of metacell codebook
698
736
  """
699
- cb = self._get_metacell_coordinates()
737
+ cb = self._get_codebook()
700
738
  cb = tensor_to_numpy(cb)
701
739
  return cb
702
740
 
@@ -820,13 +858,15 @@ class DensityFlow(nn.Module):
820
858
  us_i = us[:,pert_idx].reshape(-1,1)
821
859
 
822
860
  # factor effect of xs
823
- dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
861
+ dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)
824
862
 
825
863
  # perturbation effect
826
864
  ps = np.ones_like(us_i)
827
- dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
828
-
829
- zs = zs + dzs0 + dzs
865
+ if np.sum(np.abs(ps-us_i))>=1:
866
+ dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
867
+ zs = zs + dzs0 + dzs
868
+ else:
869
+ zs = zs + dzs0
830
870
 
831
871
  if library_sizes is None:
832
872
  library_sizes = np.sum(xs, axis=1, keepdims=True)
@@ -840,47 +880,48 @@ class DensityFlow(nn.Module):
840
880
 
841
881
  return counts, zs
842
882
 
843
- def _cell_response(self, xs, factor_idx, perturb):
883
+ def _cell_response(self, zs, perturb_idx, perturb):
844
884
  #zns,_ = self.encoder_zn(xs)
845
- zns,_ = self._get_basal_embedding(xs)
885
+ #zns,_ = self._get_basal_embedding(xs)
886
+ zns = zs
846
887
  if perturb.ndim==2:
847
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
888
+ if self.turn_off_cell_specific:
889
+ ms = self.cell_factor_effect[perturb_idx](perturb)
890
+ else:
891
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
848
892
  else:
849
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
893
+ if self.turn_off_cell_specific:
894
+ ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
895
+ else:
896
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
850
897
 
851
898
  return ms
852
899
 
853
900
  def get_cell_response(self,
854
- xs,
855
- factor_idx,
856
- perturb,
901
+ zs,
902
+ perturb_idx,
903
+ perturb_us,
857
904
  batch_size: int = 1024):
858
905
  """
859
906
  Return cells' changes in the latent space induced by specific perturbation of a factor
860
907
 
861
908
  """
862
- xs = self.preprocess(xs)
863
- xs = convert_to_tensor(xs, device=self.get_device())
864
- ps = convert_to_tensor(perturb, device=self.get_device())
865
- dataset = CustomDataset2(xs,ps)
909
+ #xs = self.preprocess(xs)
910
+ zs = convert_to_tensor(zs, device=self.get_device())
911
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
912
+ dataset = CustomDataset2(zs,ps)
866
913
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
867
914
 
868
915
  Z = []
869
916
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
870
- for X_batch, P_batch, _ in dataloader:
871
- zns = self._cell_response(X_batch, factor_idx, P_batch)
917
+ for Z_batch, P_batch, _ in dataloader:
918
+ zns = self._cell_response(Z_batch, perturb_idx, P_batch)
872
919
  Z.append(tensor_to_numpy(zns))
873
920
  pbar.update(1)
874
921
 
875
922
  Z = np.concatenate(Z)
876
923
  return Z
877
924
 
878
- def get_metacell_response(self, factor_idx, perturb):
879
- zs = self._get_codebook()
880
- ps = convert_to_tensor(perturb, device=self.get_device())
881
- ms = self.cell_factor_effect[factor_idx]([zs,ps])
882
- return tensor_to_numpy(ms)
883
-
884
925
  def _get_expression_response(self, delta_zs):
885
926
  return self.decoder_concentrate(delta_zs)
886
927
 
@@ -905,36 +946,28 @@ class DensityFlow(nn.Module):
905
946
  R = np.concatenate(R)
906
947
  return R
907
948
 
908
- def _count(self,concentrate, library_size=None):
949
+ def _count(self, concentrate, library_size=None):
909
950
  if self.loss_func == 'bernoulli':
910
951
  #counts = self.sigmoid(concentrate)
911
952
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
953
+ elif self.loss_func == 'multinomial':
954
+ theta = dist.Multinomial(total_count=int(1e8), logits=concentrate).mean
955
+ counts = theta * library_size
912
956
  else:
913
957
  rate = concentrate.exp()
914
958
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
915
959
  counts = theta * library_size
916
- #counts = dist.Poisson(rate=rate).to_event(1).mean
917
- return counts
918
-
919
- def _count_sample(self,concentrate):
920
- if self.loss_func == 'bernoulli':
921
- logits = concentrate
922
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
923
- else:
924
- counts = self._count(concentrate=concentrate)
925
- counts = dist.Poisson(rate=counts).to_event(1).sample()
926
960
  return counts
927
961
 
928
962
  def get_counts(self, zs, library_sizes,
929
- batch_size: int = 1024,
930
- use_sampler: bool = False):
963
+ batch_size: int = 1024):
931
964
 
932
965
  zs = convert_to_tensor(zs, device=self.get_device())
933
966
 
934
967
  if type(library_sizes) == list:
935
- library_sizes = np.array(library_sizes).view(-1,1)
968
+ library_sizes = np.array(library_sizes).reshape(-1,1)
936
969
  elif len(library_sizes.shape)==1:
937
- library_sizes = library_sizes.view(-1,1)
970
+ library_sizes = library_sizes.reshape(-1,1)
938
971
  ls = convert_to_tensor(library_sizes, device=self.get_device())
939
972
 
940
973
  dataset = CustomDataset2(zs,ls)
@@ -944,10 +977,7 @@ class DensityFlow(nn.Module):
944
977
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
945
978
  for Z_batch, L_batch, _ in dataloader:
946
979
  concentrate = self._get_expression_response(Z_batch)
947
- if use_sampler:
948
- counts = self._count_sample(concentrate)
949
- else:
950
- counts = self._count(concentrate, L_batch)
980
+ counts = self._count(concentrate, L_batch)
951
981
  E.append(tensor_to_numpy(counts))
952
982
  pbar.update(1)
953
983
 
@@ -1317,7 +1347,7 @@ def main():
1317
1347
  cell_factor_size = 0 if us is None else us.shape[1]
1318
1348
 
1319
1349
  ###########################################
1320
- DensityFlow = DensityFlow(
1350
+ df = DensityFlow(
1321
1351
  input_size=input_size,
1322
1352
  cell_factor_size=cell_factor_size,
1323
1353
  inverse_dispersion=args.inverse_dispersion,
@@ -1336,7 +1366,7 @@ def main():
1336
1366
  dtype=dtype,
1337
1367
  )
1338
1368
 
1339
- DensityFlow.fit(xs, us=us,
1369
+ df.fit(xs, us=us,
1340
1370
  num_epochs=args.num_epochs,
1341
1371
  learning_rate=args.learning_rate,
1342
1372
  batch_size=args.batch_size,
@@ -1348,12 +1378,11 @@ def main():
1348
1378
 
1349
1379
  if args.save_model is not None:
1350
1380
  if args.save_model.endswith('gz'):
1351
- DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
1381
+ DensityFlow.save_model(df, args.save_model, compression=True)
1352
1382
  else:
1353
- DensityFlow.save_model(DensityFlow, args.save_model)
1383
+ DensityFlow.save_model(df, args.save_model)
1354
1384
 
1355
1385
 
1356
1386
 
1357
1387
  if __name__ == "__main__":
1358
-
1359
1388
  main()