SURE-tools 2.2.10.tar.gz → 2.2.28.tar.gz

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

Potentially problematic release: this version of SURE-tools might be problematic.

Files changed (30)
  1. {sure_tools-2.2.10 → sure_tools-2.2.28}/PKG-INFO +1 -1
  2. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/DensityFlow.py +91 -71
  3. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/perturb/perturb.py +27 -1
  4. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/utils/custom_mlp.py +8 -2
  5. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE_tools.egg-info/PKG-INFO +1 -1
  6. {sure_tools-2.2.10 → sure_tools-2.2.28}/setup.py +1 -1
  7. {sure_tools-2.2.10 → sure_tools-2.2.28}/LICENSE +0 -0
  8. {sure_tools-2.2.10 → sure_tools-2.2.28}/README.md +0 -0
  9. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/SURE.py +0 -0
  10. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/__init__.py +0 -0
  11. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/assembly/__init__.py +0 -0
  12. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/assembly/assembly.py +0 -0
  13. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/assembly/atlas.py +0 -0
  14. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/atac/__init__.py +0 -0
  15. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/atac/utils.py +0 -0
  16. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/codebook/__init__.py +0 -0
  17. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/codebook/codebook.py +0 -0
  18. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/flow/__init__.py +0 -0
  19. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/flow/flow_stats.py +0 -0
  20. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/flow/plot_quiver.py +0 -0
  21. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/perturb/__init__.py +0 -0
  22. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/utils/__init__.py +0 -0
  23. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/utils/queue.py +0 -0
  24. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE/utils/utils.py +0 -0
  25. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE_tools.egg-info/SOURCES.txt +0 -0
  26. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.2.10 → sure_tools-2.2.28}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.2.10 → sure_tools-2.2.28}/setup.cfg +0 -0
--- sure_tools-2.2.10/PKG-INFO
+++ sure_tools-2.2.28/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.2.10
+Version: 2.2.28
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng
--- sure_tools-2.2.10/SURE/DensityFlow.py
+++ sure_tools-2.2.28/SURE/DensityFlow.py
@@ -59,12 +59,13 @@ class DensityFlow(nn.Module):
                  input_size: int,
                  codebook_size: int = 200,
                  cell_factor_size: int = 0,
+                 turn_off_cell_specific: bool = False,
                  supervised_mode: bool = False,
                  z_dim: int = 10,
                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
-                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
                  inverse_dispersion: float = 10.0,
-                 use_zeroinflate: bool = True,
+                 use_zeroinflate: bool = False,
                  hidden_layers: list = [500],
                  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
                  nn_dropout: float = 0.1,
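The constructor gains a turn_off_cell_specific switch and flips two defaults: loss_func moves from 'negbinomial' to 'multinomial', and use_zeroinflate from True to False. A minimal construction sketch under the new defaults (the import path and argument values are illustrative, not taken from the diff):

import numpy as np
from SURE.DensityFlow import DensityFlow  # assumed import path for this package

model = DensityFlow(
    input_size=2000,               # number of genes, example value
    codebook_size=200,
    cell_factor_size=1,            # one perturbation covariate
    turn_off_cell_specific=True,   # new in 2.2.28, see the MLP wiring below
)
# loss_func now defaults to 'multinomial' (2.2.10 default: 'negbinomial')
# use_zeroinflate now defaults to False (2.2.10 default: True)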
@@ -102,6 +103,7 @@ class DensityFlow(nn.Module):
         else:
             self.use_bias = [not zero_bias] * self.cell_factor_size
         #self.use_bias = not zero_bias
+        self.turn_off_cell_specific = turn_off_cell_specific

         self.codebook_weights = None

@@ -203,27 +205,51 @@ class DensityFlow(nn.Module):
             self.cell_factor_effect = nn.ModuleList()
             for i in np.arange(self.cell_factor_size):
                 if self.use_bias[i]:
-                    self.cell_factor_effect.append(MLP(
-                        [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
-                        activation=activate_fct,
-                        output_activation=None,
-                        post_layer_fct=post_layer_fct,
-                        post_act_fct=post_act_fct,
-                        allow_broadcast=self.allow_broadcast,
-                        use_cuda=self.use_cuda,
+                    if self.turn_off_cell_specific:
+                        self.cell_factor_effect.append(MLP(
+                            [1] + self.decoder_hidden_layers + [self.latent_dim],
+                            activation=activate_fct,
+                            output_activation=None,
+                            post_layer_fct=post_layer_fct,
+                            post_act_fct=post_act_fct,
+                            allow_broadcast=self.allow_broadcast,
+                            use_cuda=self.use_cuda,
+                            )
+                        )
+                    else:
+                        self.cell_factor_effect.append(MLP(
+                            [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
+                            activation=activate_fct,
+                            output_activation=None,
+                            post_layer_fct=post_layer_fct,
+                            post_act_fct=post_act_fct,
+                            allow_broadcast=self.allow_broadcast,
+                            use_cuda=self.use_cuda,
+                            )
                     )
-                    )
                 else:
-                    self.cell_factor_effect.append(ZeroBiasMLP(
-                        [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
-                        activation=activate_fct,
-                        output_activation=None,
-                        post_layer_fct=post_layer_fct,
-                        post_act_fct=post_act_fct,
-                        allow_broadcast=self.allow_broadcast,
-                        use_cuda=self.use_cuda,
+                    if self.turn_off_cell_specific:
+                        self.cell_factor_effect.append(ZeroBiasMLP(
+                            [1] + self.decoder_hidden_layers + [self.latent_dim],
+                            activation=activate_fct,
+                            output_activation=None,
+                            post_layer_fct=post_layer_fct,
+                            post_act_fct=post_act_fct,
+                            allow_broadcast=self.allow_broadcast,
+                            use_cuda=self.use_cuda,
+                            )
+                        )
+                    else:
+                        self.cell_factor_effect.append(ZeroBiasMLP(
+                            [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
+                            activation=activate_fct,
+                            output_activation=None,
+                            post_layer_fct=post_layer_fct,
+                            post_act_fct=post_act_fct,
+                            allow_broadcast=self.allow_broadcast,
+                            use_cuda=self.use_cuda,
+                            )
                     )
-                    )

         self.decoder_concentrate = MLP(
             [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
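When turn_off_cell_specific is set, each factor-effect network takes only the scalar factor value (input width 1) instead of the factor concatenated to the latent state (width latent_dim+1), so the predicted latent shift no longer depends on the individual cell. A plain-PyTorch sketch of the distinction (the package's MLP/ZeroBiasMLP wrappers add broadcasting and masking on top; all names here are illustrative):

import torch
import torch.nn as nn

latent_dim, hidden = 10, 500

# cell-specific: f([z, u]) varies with each cell's latent state z
cell_specific = nn.Sequential(
    nn.Linear(latent_dim + 1, hidden), nn.ReLU(), nn.Linear(hidden, latent_dim))

# cell-agnostic (turn_off_cell_specific=True): f(u) is shared by all cells
cell_agnostic = nn.Sequential(
    nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, latent_dim))

z = torch.randn(32, latent_dim)   # basal embeddings for 32 cells
u = torch.ones(32, 1)             # one factor, switched on
shift_specific = cell_specific(torch.cat([z, u], dim=-1))  # differs per cell
shift_agnostic = cell_agnostic(u)                          # identical rows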
@@ -234,16 +260,6 @@ class DensityFlow(nn.Module):
             allow_broadcast=self.allow_broadcast,
             use_cuda=self.use_cuda,
         )
-        if self.loss_func == 'negbinomial':
-            self.decoder_total_count = MLP(
-                [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
-                activation=activate_fct,
-                output_activation=Exp,
-                post_layer_fct=post_layer_fct,
-                post_act_fct=post_act_fct,
-                allow_broadcast=self.allow_broadcast,
-                use_cuda=self.use_cuda,
-            )

         if self.latent_dist == 'studentt':
             self.codebook = MLP(
@@ -324,9 +340,9 @@ class DensityFlow(nn.Module):
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

-        #if self.loss_func=='negbinomial':
-        #    total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-        #                             xs.new_ones(self.input_size), constraint=constraints.positive)
+        if self.loss_func=='negbinomial':
+            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
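With the amortized decoder_total_count network deleted (previous hunk), the negative binomial inverse dispersion is instead re-enabled here as a single gene-wise pyro.param shared across cells. A self-contained sketch of that pattern (sizes and the theta construction are illustrative):

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

input_size = 2000
theta = torch.rand(8, input_size).clamp(0.01, 0.99)  # per-cell probs in (0, 1)

# One positive dispersion value per gene, learned as a global parameter
total_count = pyro.param("inverse_dispersion",
                         10.0 * torch.ones(input_size),
                         constraint=constraints.positive)
x = pyro.sample("x", dist.NegativeBinomial(total_count=total_count,
                                           probs=theta).to_event(1))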
@@ -370,7 +386,6 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
-            total_count = self.decoder_total_count(zs)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
@@ -404,9 +419,9 @@ class DensityFlow(nn.Module):
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

-        #if self.loss_func=='negbinomial':
-        #    total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-        #                             xs.new_ones(self.input_size), constraint=constraints.positive)
+        if self.loss_func=='negbinomial':
+            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -455,7 +470,6 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
-            total_count = self.decoder_total_count(zs)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
@@ -489,9 +503,9 @@ class DensityFlow(nn.Module):
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

-        #if self.loss_func=='negbinomial':
-        #    total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-        #                             xs.new_ones(self.input_size), constraint=constraints.positive)
+        if self.loss_func=='negbinomial':
+            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -552,7 +566,6 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
-            total_count = self.decoder_total_count(zs)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
@@ -586,9 +599,9 @@ class DensityFlow(nn.Module):
         batch_size = xs.size(0)
         self.options = dict(dtype=xs.dtype, device=xs.device)

-        #if self.loss_func=='negbinomial':
-        #    total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
-        #                             xs.new_ones(self.input_size), constraint=constraints.positive)
+        if self.loss_func=='negbinomial':
+            total_count = pyro.param("inverse_dispersion", self.inverse_dispersion *
+                                     xs.new_ones(self.input_size), constraint=constraints.positive)

         if self.use_zeroinflate:
             gate_logits = pyro.param("dropout_rate", xs.new_zeros(self.input_size))
@@ -659,7 +672,6 @@ class DensityFlow(nn.Module):
             rate = theta * torch.sum(xs, dim=1, keepdim=True)

         if self.loss_func == 'negbinomial':
-            total_count = self.decoder_total_count(zs)
             if self.use_zeroinflate:
                 pyro.sample('x', dist.ZeroInflatedDistribution(dist.NegativeBinomial(total_count=total_count, probs=theta),gate_logits=gate_logits).to_event(1), obs=xs)
             else:
@@ -690,9 +702,17 @@ class DensityFlow(nn.Module):
         zus = None
         for i in np.arange(self.cell_factor_size):
             if i==0:
-                zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                #if self.turn_off_cell_specific:
+                #    zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
+                #else:
+                #    zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
             else:
-                zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                #if self.turn_off_cell_specific:
+                #    zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
+                #else:
+                #    zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+                zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
         return zus

     def _get_codebook_identity(self):
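Both branches of the loop now delegate to the shared _cell_response helper (the inlined calls survive only as comments), so the combined latent offset is a sum of independent per-factor responses: zus = f_0(z, u_0) + f_1(z, u_1) + ... A condensed sketch of that accumulation, with hypothetical stand-ins for self.cell_factor_effect:

import torch

def total_effect(response_fns, z, us):
    # zus = sum_i f_i(z, us[:, i]); mirrors the rewritten loop above
    zus = None
    for i in range(us.shape[1]):
        shift = response_fns[i](z, us[:, i].reshape(-1, 1))
        zus = shift if zus is None else zus + shift
    return zus

fns = [lambda z, u, w=w: w * u.expand_as(z) for w in (0.5, 2.0)]
print(total_effect(fns, torch.randn(8, 10), torch.ones(8, 2)).shape)  # (8, 10)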
@@ -710,7 +730,7 @@ class DensityFlow(nn.Module):
         """
         Return the mean part of metacell codebook
         """
-        cb = self._get_metacell_coordinates()
+        cb = self._get_codebook()
         cb = tensor_to_numpy(cb)
         return cb

@@ -834,12 +854,12 @@ class DensityFlow(nn.Module):
         us_i = us[:,pert_idx].reshape(-1,1)

         # factor effect of xs
-        dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
+        dzs0 = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=us_i)

         # perturbation effect
         ps = np.ones_like(us_i)
         if np.sum(np.abs(ps-us_i))>=1:
-            dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
+            dzs = self.get_cell_response(zs, perturb_idx=pert_idx, perturb_us=ps)
             zs = zs + dzs0 + dzs
         else:
             zs = zs + dzs0
@@ -856,47 +876,48 @@ class DensityFlow(nn.Module):

         return counts, zs

-    def _cell_response(self, xs, factor_idx, perturb):
+    def _cell_response(self, zs, perturb_idx, perturb):
         #zns,_ = self.encoder_zn(xs)
-        zns,_ = self._get_basal_embedding(xs)
+        #zns,_ = self._get_basal_embedding(xs)
+        zns = zs
         if perturb.ndim==2:
-            ms = self.cell_factor_effect[factor_idx]([zns, perturb])
+            if self.turn_off_cell_specific:
+                ms = self.cell_factor_effect[perturb_idx](perturb)
+            else:
+                ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
         else:
-            ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
+            if self.turn_off_cell_specific:
+                ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
+            else:
+                ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])

         return ms

     def get_cell_response(self,
-                          xs,
-                          factor_idx,
-                          perturb,
+                          zs,
+                          perturb_idx,
+                          perturb_us,
                           batch_size: int = 1024):
         """
         Return cells' changes in the latent space induced by specific perturbation of a factor

         """
-        xs = self.preprocess(xs)
-        xs = convert_to_tensor(xs, device=self.get_device())
-        ps = convert_to_tensor(perturb, device=self.get_device())
-        dataset = CustomDataset2(xs,ps)
+        #xs = self.preprocess(xs)
+        zs = convert_to_tensor(zs, device=self.get_device())
+        ps = convert_to_tensor(perturb_us, device=self.get_device())
+        dataset = CustomDataset2(zs,ps)
         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

         Z = []
         with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
-            for X_batch, P_batch, _ in dataloader:
-                zns = self._cell_response(X_batch, factor_idx, P_batch)
+            for Z_batch, P_batch, _ in dataloader:
+                zns = self._cell_response(Z_batch, perturb_idx, P_batch)
                 Z.append(tensor_to_numpy(zns))
                 pbar.update(1)

         Z = np.concatenate(Z)
         return Z

-    def get_metacell_response(self, factor_idx, perturb):
-        zs = self._get_codebook()
-        ps = convert_to_tensor(perturb, device=self.get_device())
-        ms = self.cell_factor_effect[factor_idx]([zs,ps])
-        return tensor_to_numpy(ms)
-
     def _get_expression_response(self, delta_zs):
         return self.decoder_concentrate(delta_zs)

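get_cell_response (and the private _cell_response) now take latent embeddings zs, with keyword names perturb_idx/perturb_us replacing factor_idx/perturb; the raw-count preprocessing path and the get_metacell_response helper are gone. A hedged usage sketch, assuming model is a trained DensityFlow as constructed earlier (how the embeddings are obtained is workflow-specific and not shown in this diff):

import numpy as np

zs = np.random.randn(100, 10).astype(np.float32)  # placeholder latent embeddings
us = np.ones((100, 1), dtype=np.float32)          # drive factor 0 to 1 for all cells

delta_zs = model.get_cell_response(zs, perturb_idx=0, perturb_us=us)
print(delta_zs.shape)  # (100, z_dim): per-cell latent shifts for factor 0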
@@ -1357,5 +1378,4 @@ def main():


 if __name__ == "__main__":
-
     main()
--- sure_tools-2.2.10/SURE/perturb/perturb.py
+++ sure_tools-2.2.28/SURE/perturb/perturb.py
@@ -1,5 +1,6 @@
 import re
 import numpy as np
+import pandas as pd
 from numba import njit
 from itertools import chain
 from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
 class LabelMatrix:
     def __init__(self):
         self.labels_ = None
+        self.control_label = None
+        self.sep_pattern = None

     def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
         if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
             mat = np.delete(mat, idx, axis=1)
             self.labels_ = np.delete(self.labels_, idx)

+        self.control_label = control_label
+        self.sep_pattern=sep_pattern
+
         return mat
-
+
+    def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
+        sep_pattern = self.sep_pattern
+        if speedup=='none':
+            mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='vectorize':
+            mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+        elif speedup=='parallel':
+            mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+
+        mat_df = pd.DataFrame(mat, columns=labels_)
+
+        labels_valid = [x for x in labels_ if x in self.labels_]
+        mat_df = mat_df[labels_valid]
+
+        mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
+        mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
+        mat_valid_df[labels_valid] = mat_df
+
+        return mat_valid_df.values
+
     def inverse_transform(self, matrix):
         return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)

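The new transform method projects unseen label strings onto the vocabulary learned by fit_transform: labels never seen during fitting are dropped, and fitted labels absent from the new batch become all-zero columns, so the output always has len(self.labels_) columns in the fitted order. A hedged usage sketch (import path and label values are illustrative):

from SURE.perturb.perturb import LabelMatrix  # assumed import path

lm = LabelMatrix()
train = ['ctrl', 'drugA', 'drugA_drugB']      # '_' is one of the default separators
mat = lm.fit_transform(train, control_label='ctrl')  # control column is removed
print(lm.labels_)                             # e.g. ['drugA', 'drugB']

# 'drugC' was not fitted, so its column is dropped; columns align to lm.labels_
new = lm.transform(['drugB', 'drugA_drugC'])
print(new.shape)                              # (2, len(lm.labels_))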
--- sure_tools-2.2.10/SURE/utils/custom_mlp.py
+++ sure_tools-2.2.28/SURE/utils/custom_mlp.py
@@ -240,9 +240,15 @@ class ZeroBiasMLP(nn.Module):
         y = self.mlp(x)
         mask = torch.zeros_like(y)
         if len(y.shape)==2:
-            mask[x[1][:,0]>0,:] = 1
+            if type(x)==list:
+                mask[x[1][:,0]>0,:] = 1
+            else:
+                mask[x[:,0]>0,:] = 1
         elif len(y.shape)==3:
-            mask[:,x[1][:,0]>0,:] = 1
+            if type(x)==list:
+                mask[:,x[1][:,0]>0,:] = 1
+            else:
+                mask[:,x[:,0]>0,:] = 1
         return y*mask

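ZeroBiasMLP previously assumed its input was a [latent, factor] pair and indexed x[1]; it now also accepts a bare tensor, matching the cell-agnostic call style where only the factor value is passed. In both cases, rows whose factor value (first column) is not positive have their outputs zeroed. A standalone sketch of that masking rule:

import torch

def zero_bias_mask(x, y):
    # x: either [latent, factor] (list) or the factor tensor alone
    u = x[1] if isinstance(x, list) else x
    mask = torch.zeros_like(y)
    mask[u[:, 0] > 0, :] = 1  # keep rows with an active factor
    return y * mask

y = torch.randn(4, 10)
u = torch.tensor([[1.], [0.], [2.], [0.]])
out = zero_bias_mask(u, y)                            # rows 1 and 3 zeroed
out_list = zero_bias_mask([torch.randn(4, 3), u], y)  # same masking, list input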
--- sure_tools-2.2.10/SURE_tools.egg-info/PKG-INFO
+++ sure_tools-2.2.28/SURE_tools.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: SURE-tools
-Version: 2.2.10
+Version: 2.2.28
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng
--- sure_tools-2.2.10/setup.py
+++ sure_tools-2.2.28/setup.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:

 setup(
     name='SURE-tools',
-    version='2.2.10',
+    version='2.2.28',
     description='Succinct Representation of Single Cells',
     long_description=long_description,
     long_description_content_type="text/markdown",