SURE-tools 2.1.91__tar.gz → 2.2.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. Click here for more details.

Files changed (30)
  1. {sure_tools-2.1.91 → sure_tools-2.2.14}/PKG-INFO +1 -1
  2. sure_tools-2.1.91/SURE/PerturbFlow.py → sure_tools-2.2.14/SURE/DensityFlow.py +43 -62
  3. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/SURE.py +6 -6
  4. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/__init__.py +3 -3
  5. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/flow/flow_stats.py +12 -0
  6. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/perturb/perturb.py +27 -1
  7. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/PKG-INFO +1 -1
  8. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/SOURCES.txt +1 -1
  9. {sure_tools-2.1.91 → sure_tools-2.2.14}/setup.py +1 -1
  10. {sure_tools-2.1.91 → sure_tools-2.2.14}/LICENSE +0 -0
  11. {sure_tools-2.1.91 → sure_tools-2.2.14}/README.md +0 -0
  12. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/assembly/__init__.py +0 -0
  13. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/assembly/assembly.py +0 -0
  14. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/assembly/atlas.py +0 -0
  15. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/atac/__init__.py +0 -0
  16. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/atac/utils.py +0 -0
  17. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/codebook/__init__.py +0 -0
  18. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/codebook/codebook.py +0 -0
  19. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/flow/__init__.py +0 -0
  20. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/flow/plot_quiver.py +0 -0
  21. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/perturb/__init__.py +0 -0
  22. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/utils/__init__.py +0 -0
  23. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/utils/custom_mlp.py +0 -0
  24. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/utils/queue.py +0 -0
  25. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE/utils/utils.py +0 -0
  26. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.1.91 → sure_tools-2.2.14}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.1.91 → sure_tools-2.2.14}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.91
3
+ Version: 2.2.14
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -54,19 +54,18 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class PerturbFlow(nn.Module):
57
+ class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
60
  codebook_size: int = 200,
61
61
  cell_factor_size: int = 0,
62
- cell_factor_effect_discrete: bool = False,
63
62
  supervised_mode: bool = False,
64
63
  z_dim: int = 10,
65
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
66
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
67
66
  inverse_dispersion: float = 10.0,
68
- use_zeroinflate: bool = False,
69
- hidden_layers: list = [300],
67
+ use_zeroinflate: bool = True,
68
+ hidden_layers: list = [500],
70
69
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
71
70
  nn_dropout: float = 0.1,
72
71
  post_layer_fct: list = ['layernorm'],
@@ -103,7 +102,6 @@ class PerturbFlow(nn.Module):
103
102
  else:
104
103
  self.use_bias = [not zero_bias] * self.cell_factor_size
105
104
  #self.use_bias = not zero_bias
106
- self.enumrate = cell_factor_effect_discrete
107
105
 
108
106
  self.codebook_weights = None
109
107
 
@@ -310,7 +308,7 @@ class PerturbFlow(nn.Module):
310
308
  return xs
311
309
 
312
310
  def model1(self, xs):
313
- pyro.module('PerturbFlow', self)
311
+ pyro.module('DensityFlow', self)
314
312
 
315
313
  eps = torch.finfo(xs.dtype).eps
316
314
  batch_size = xs.size(0)
@@ -389,7 +387,7 @@ class PerturbFlow(nn.Module):
389
387
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
390
388
 
391
389
  def model2(self, xs, us=None):
392
- pyro.module('PerturbFlow', self)
390
+ pyro.module('DensityFlow', self)
393
391
 
394
392
  eps = torch.finfo(xs.dtype).eps
395
393
  batch_size = xs.size(0)
@@ -431,10 +429,7 @@ class PerturbFlow(nn.Module):
431
429
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
432
430
 
433
431
  if self.cell_factor_size>0:
434
- if self.enumrate:
435
- zus = self._total_effects(zn_loc, us)
436
- else:
437
- zus = self._total_effects(zns, us)
432
+ zus = self._total_effects(zns, us)
438
433
  zs = zns+zus
439
434
  else:
440
435
  zs = zns
@@ -476,7 +471,7 @@ class PerturbFlow(nn.Module):
476
471
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
477
472
 
478
473
  def model3(self, xs, ys, embeds=None):
479
- pyro.module('PerturbFlow', self)
474
+ pyro.module('DensityFlow', self)
480
475
 
481
476
  eps = torch.finfo(xs.dtype).eps
482
477
  batch_size = xs.size(0)
@@ -572,7 +567,7 @@ class PerturbFlow(nn.Module):
572
567
  zns = embeds
573
568
 
574
569
  def model4(self, xs, us, ys, embeds=None):
575
- pyro.module('PerturbFlow', self)
570
+ pyro.module('DensityFlow', self)
576
571
 
577
572
  eps = torch.finfo(xs.dtype).eps
578
573
  batch_size = xs.size(0)
@@ -636,10 +631,7 @@ class PerturbFlow(nn.Module):
636
631
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
637
632
  # else:
638
633
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
639
- if self.enumrate:
640
- zus = self._total_effects(zn_loc, us)
641
- else:
642
- zus = self._total_effects(zns, us)
634
+ zus = self._total_effects(zns, us)
643
635
  zs = zns+zus
644
636
  else:
645
637
  zs = zns
@@ -832,9 +824,11 @@ class PerturbFlow(nn.Module):
832
824
 
833
825
  # perturbation effect
834
826
  ps = np.ones_like(us_i)
835
- dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
836
-
837
- zs = zs + dzs0 + dzs
827
+ if np.sum(np.abs(ps-us_i))>=1:
828
+ dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
829
+ zs = zs + dzs0 + dzs
830
+ else:
831
+ zs = zs + dzs0
838
832
 
839
833
  if library_sizes is None:
840
834
  library_sizes = np.sum(xs, axis=1, keepdims=True)
@@ -848,35 +842,36 @@ class PerturbFlow(nn.Module):
848
842
 
849
843
  return counts, zs
850
844
 
851
- def _cell_response(self, xs, factor_idx, perturb):
845
+ def _cell_response(self, zs, perturb_idx, perturb):
852
846
  #zns,_ = self.encoder_zn(xs)
853
- zns,_ = self._get_basal_embedding(xs)
847
+ #zns,_ = self._get_basal_embedding(xs)
848
+ zns = zs
854
849
  if perturb.ndim==2:
855
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
850
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
856
851
  else:
857
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
852
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
858
853
 
859
854
  return ms
860
855
 
861
856
  def get_cell_response(self,
862
- xs,
863
- factor_idx,
864
- perturb,
857
+ zs,
858
+ perturb_idx,
859
+ perturb_us,
865
860
  batch_size: int = 1024):
866
861
  """
867
862
  Return cells' changes in the latent space induced by specific perturbation of a factor
868
863
 
869
864
  """
870
- xs = self.preprocess(xs)
871
- xs = convert_to_tensor(xs, device=self.get_device())
872
- ps = convert_to_tensor(perturb, device=self.get_device())
873
- dataset = CustomDataset2(xs,ps)
865
+ #xs = self.preprocess(xs)
866
+ zs = convert_to_tensor(zs, device=self.get_device())
867
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
868
+ dataset = CustomDataset2(zs,ps)
874
869
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
875
870
 
876
871
  Z = []
877
872
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
878
- for X_batch, P_batch, _ in dataloader:
879
- zns = self._cell_response(X_batch, factor_idx, P_batch)
873
+ for Z_batch, P_batch, _ in dataloader:
874
+ zns = self._cell_response(Z_batch, perturb_idx, P_batch)
880
875
  Z.append(tensor_to_numpy(zns))
881
876
  pbar.update(1)
882
877
 
@@ -913,7 +908,7 @@ class PerturbFlow(nn.Module):
913
908
  R = np.concatenate(R)
914
909
  return R
915
910
 
916
- def _count(self,concentrate, library_size=None):
911
+ def _count(self, concentrate, library_size=None):
917
912
  if self.loss_func == 'bernoulli':
918
913
  #counts = self.sigmoid(concentrate)
919
914
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -921,28 +916,17 @@ class PerturbFlow(nn.Module):
921
916
  rate = concentrate.exp()
922
917
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
923
918
  counts = theta * library_size
924
- #counts = dist.Poisson(rate=rate).to_event(1).mean
925
- return counts
926
-
927
- def _count_sample(self,concentrate):
928
- if self.loss_func == 'bernoulli':
929
- logits = concentrate
930
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
931
- else:
932
- counts = self._count(concentrate=concentrate)
933
- counts = dist.Poisson(rate=counts).to_event(1).sample()
934
919
  return counts
935
920
 
936
921
  def get_counts(self, zs, library_sizes,
937
- batch_size: int = 1024,
938
- use_sampler: bool = False):
922
+ batch_size: int = 1024):
939
923
 
940
924
  zs = convert_to_tensor(zs, device=self.get_device())
941
925
 
942
926
  if type(library_sizes) == list:
943
- library_sizes = np.array(library_sizes).view(-1,1)
927
+ library_sizes = np.array(library_sizes).reshape(-1,1)
944
928
  elif len(library_sizes.shape)==1:
945
- library_sizes = library_sizes.view(-1,1)
929
+ library_sizes = library_sizes.reshape(-1,1)
946
930
  ls = convert_to_tensor(library_sizes, device=self.get_device())
947
931
 
948
932
  dataset = CustomDataset2(zs,ls)
@@ -952,10 +936,7 @@ class PerturbFlow(nn.Module):
952
936
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
953
937
  for Z_batch, L_batch, _ in dataloader:
954
938
  concentrate = self._get_expression_response(Z_batch)
955
- if use_sampler:
956
- counts = self._count_sample(concentrate)
957
- else:
958
- counts = self._count(concentrate, L_batch)
939
+ counts = self._count(concentrate, L_batch)
959
940
  E.append(tensor_to_numpy(counts))
960
941
  pbar.update(1)
961
942
 
@@ -978,7 +959,7 @@ class PerturbFlow(nn.Module):
978
959
  us = None,
979
960
  ys = None,
980
961
  zs = None,
981
- num_epochs: int = 200,
962
+ num_epochs: int = 500,
982
963
  learning_rate: float = 0.0001,
983
964
  batch_size: int = 256,
984
965
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -989,7 +970,7 @@ class PerturbFlow(nn.Module):
989
970
  threshold: int = 0,
990
971
  use_jax: bool = True):
991
972
  """
992
- Train the PerturbFlow model.
973
+ Train the DensityFlow model.
993
974
 
994
975
  Parameters
995
976
  ----------
@@ -1015,7 +996,7 @@ class PerturbFlow(nn.Module):
1015
996
  Parameter for optimization.
1016
997
  use_jax
1017
998
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
1018
- the Python script or Jupyter notebook. It is OK if it is used when running PerturbFlow in the shell command.
999
+ the Python script or Jupyter notebook. It is OK if it is used when running DensityFlow in the shell command.
1019
1000
  """
1020
1001
  xs = self.preprocess(xs, threshold=threshold)
1021
1002
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1133,12 +1114,12 @@ class PerturbFlow(nn.Module):
1133
1114
 
1134
1115
 
1135
1116
  EXAMPLE_RUN = (
1136
- "example run: PerturbFlow --help"
1117
+ "example run: DensityFlow --help"
1137
1118
  )
1138
1119
 
1139
1120
  def parse_args():
1140
1121
  parser = argparse.ArgumentParser(
1141
- description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1122
+ description="DensityFlow\n{}".format(EXAMPLE_RUN))
1142
1123
 
1143
1124
  parser.add_argument(
1144
1125
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1325,7 +1306,7 @@ def main():
1325
1306
  cell_factor_size = 0 if us is None else us.shape[1]
1326
1307
 
1327
1308
  ###########################################
1328
- perturbflow = PerturbFlow(
1309
+ DensityFlow = DensityFlow(
1329
1310
  input_size=input_size,
1330
1311
  cell_factor_size=cell_factor_size,
1331
1312
  inverse_dispersion=args.inverse_dispersion,
@@ -1344,7 +1325,7 @@ def main():
1344
1325
  dtype=dtype,
1345
1326
  )
1346
1327
 
1347
- perturbflow.fit(xs, us=us,
1328
+ DensityFlow.fit(xs, us=us,
1348
1329
  num_epochs=args.num_epochs,
1349
1330
  learning_rate=args.learning_rate,
1350
1331
  batch_size=args.batch_size,
@@ -1356,9 +1337,9 @@ def main():
1356
1337
 
1357
1338
  if args.save_model is not None:
1358
1339
  if args.save_model.endswith('gz'):
1359
- PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1340
+ DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
1360
1341
  else:
1361
- PerturbFlow.save_model(perturbflow, args.save_model)
1342
+ DensityFlow.save_model(DensityFlow, args.save_model)
1362
1343
 
1363
1344
 
1364
1345
 
@@ -99,17 +99,17 @@ class SURE(nn.Module):
99
99
  cell_factor_size: int = 0,
100
100
  supervised_mode: bool = False,
101
101
  z_dim: int = 10,
102
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
103
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
102
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
103
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
104
104
  inverse_dispersion: float = 10.0,
105
105
  use_zeroinflate: bool = True,
106
- hidden_layers: list = [300],
106
+ hidden_layers: list = [500],
107
107
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
108
108
  nn_dropout: float = 0.1,
109
109
  post_layer_fct: list = ['layernorm'],
110
110
  post_act_fct: list = None,
111
111
  config_enum: str = 'parallel',
112
- use_cuda: bool = False,
112
+ use_cuda: bool = True,
113
113
  seed: int = 42,
114
114
  dtype = torch.float32, # type: ignore
115
115
  ):
@@ -817,7 +817,7 @@ class SURE(nn.Module):
817
817
  us = None,
818
818
  ys = None,
819
819
  zs = None,
820
- num_epochs: int = 200,
820
+ num_epochs: int = 500,
821
821
  learning_rate: float = 0.0001,
822
822
  batch_size: int = 256,
823
823
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -826,7 +826,7 @@ class SURE(nn.Module):
826
826
  decay_rate: float = 0.9,
827
827
  config_enum: str = 'parallel',
828
828
  threshold: int = 0,
829
- use_jax: bool = False):
829
+ use_jax: bool = True):
830
830
  """
831
831
  Train the SURE model.
832
832
 
@@ -1,12 +1,12 @@
1
1
  from .SURE import SURE
2
- from .PerturbFlow import PerturbFlow
2
+ from .DensityFlow import DensityFlow
3
3
 
4
4
  from . import utils
5
5
  from . import codebook
6
6
  from . import SURE
7
- from . import PerturbFlow
7
+ from . import DensityFlow
8
8
  from . import atac
9
9
  from . import flow
10
10
  from . import perturb
11
11
 
12
- __all__ = ['SURE', 'PerturbFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
12
+ __all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
@@ -41,6 +41,18 @@ class VectorFieldEval:
41
41
  divergence[np.isnan(divergence)] = 0
42
42
 
43
43
  return divergence
44
+
45
+ def movement_stats(self,vectors):
46
+ return calculate_movement_stats(vectors)
47
+
48
+ def direction_stats(self, vectors):
49
+ return calculate_direction_stats(vectors)
50
+
51
+ def movement_energy(self, vectors, masses=None):
52
+ return calculate_movement_energy(vectors, masses)
53
+
54
+ def movement_divergence(self, positions, vectors):
55
+ return calculate_movement_divergence(positions, vectors)
44
56
 
45
57
 
46
58
  def calculate_movement_stats(vectors):
@@ -1,5 +1,6 @@
1
1
  import re
2
2
  import numpy as np
3
+ import pandas as pd
3
4
  from numba import njit
4
5
  from itertools import chain
5
6
  from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
8
9
  class LabelMatrix:
9
10
  def __init__(self):
10
11
  self.labels_ = None
12
+ self.control_label = None
13
+ self.sep_pattern = None
11
14
 
12
15
  def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
13
16
  if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
24
27
  mat = np.delete(mat, idx, axis=1)
25
28
  self.labels_ = np.delete(self.labels_, idx)
26
29
 
30
+ self.control_label = control_label
31
+ self.sep_pattern=sep_pattern
32
+
27
33
  return mat
28
-
34
+
35
+ def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
36
+ sep_pattern = self.sep_pattern
37
+ if speedup=='none':
38
+ mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
39
+ elif speedup=='vectorize':
40
+ mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
41
+ elif speedup=='parallel':
42
+ mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
43
+
44
+ mat_df = pd.DataFrame(mat, columns=labels_)
45
+
46
+ labels_valid = [x for x in labels_ if x in self.labels_]
47
+ mat_df = mat_df[labels_valid]
48
+
49
+ mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
50
+ mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
51
+ mat_valid_df[labels_valid] = mat_df
52
+
53
+ return mat_valid_df.values
54
+
29
55
  def inverse_transform(self, matrix):
30
56
  return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
31
57
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.91
3
+ Version: 2.2.14
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,7 +1,7 @@
1
1
  LICENSE
2
2
  README.md
3
3
  setup.py
4
- SURE/PerturbFlow.py
4
+ SURE/DensityFlow.py
5
5
  SURE/SURE.py
6
6
  SURE/__init__.py
7
7
  SURE/assembly/__init__.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='2.1.91',
8
+ version='2.2.14',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes