SURE-tools 2.1.84__tar.gz → 2.2.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic.

Files changed (30)
  1. {sure_tools-2.1.84 → sure_tools-2.2.14}/PKG-INFO +1 -1
  2. sure_tools-2.1.84/SURE/PerturbFlow.py → sure_tools-2.2.14/SURE/DensityFlow.py +43 -54
  3. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/SURE.py +6 -6
  4. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/__init__.py +3 -3
  5. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/flow/flow_stats.py +12 -0
  6. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/perturb/perturb.py +27 -1
  7. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/PKG-INFO +1 -1
  8. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/SOURCES.txt +1 -1
  9. {sure_tools-2.1.84 → sure_tools-2.2.14}/setup.py +1 -1
  10. {sure_tools-2.1.84 → sure_tools-2.2.14}/LICENSE +0 -0
  11. {sure_tools-2.1.84 → sure_tools-2.2.14}/README.md +0 -0
  12. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/assembly/__init__.py +0 -0
  13. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/assembly/assembly.py +0 -0
  14. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/assembly/atlas.py +0 -0
  15. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/atac/__init__.py +0 -0
  16. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/atac/utils.py +0 -0
  17. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/codebook/__init__.py +0 -0
  18. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/codebook/codebook.py +0 -0
  19. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/flow/__init__.py +0 -0
  20. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/flow/plot_quiver.py +0 -0
  21. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/perturb/__init__.py +0 -0
  22. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/utils/__init__.py +0 -0
  23. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/utils/custom_mlp.py +0 -0
  24. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/utils/queue.py +0 -0
  25. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE/utils/utils.py +0 -0
  26. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.1.84 → sure_tools-2.2.14}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.1.84 → sure_tools-2.2.14}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.84
3
+ Version: 2.2.14
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -54,7 +54,7 @@ def set_random_seed(seed):
54
54
  # Set seed for Pyro
55
55
  pyro.set_rng_seed(seed)
56
56
 
57
- class PerturbFlow(nn.Module):
57
+ class DensityFlow(nn.Module):
58
58
  def __init__(self,
59
59
  input_size: int,
60
60
  codebook_size: int = 200,
@@ -62,10 +62,10 @@ class PerturbFlow(nn.Module):
62
62
  supervised_mode: bool = False,
63
63
  z_dim: int = 10,
64
64
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
65
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
66
66
  inverse_dispersion: float = 10.0,
67
- use_zeroinflate: bool = False,
68
- hidden_layers: list = [300],
67
+ use_zeroinflate: bool = True,
68
+ hidden_layers: list = [500],
69
69
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
70
70
  nn_dropout: float = 0.1,
71
71
  post_layer_fct: list = ['layernorm'],
@@ -308,7 +308,7 @@ class PerturbFlow(nn.Module):
308
308
  return xs
309
309
 
310
310
  def model1(self, xs):
311
- pyro.module('PerturbFlow', self)
311
+ pyro.module('DensityFlow', self)
312
312
 
313
313
  eps = torch.finfo(xs.dtype).eps
314
314
  batch_size = xs.size(0)
@@ -387,7 +387,7 @@ class PerturbFlow(nn.Module):
387
387
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
388
388
 
389
389
  def model2(self, xs, us=None):
390
- pyro.module('PerturbFlow', self)
390
+ pyro.module('DensityFlow', self)
391
391
 
392
392
  eps = torch.finfo(xs.dtype).eps
393
393
  batch_size = xs.size(0)
@@ -429,7 +429,7 @@ class PerturbFlow(nn.Module):
429
429
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))
430
430
 
431
431
  if self.cell_factor_size>0:
432
- zus = self._total_effects(zn_loc, us)
432
+ zus = self._total_effects(zns, us)
433
433
  zs = zns+zus
434
434
  else:
435
435
  zs = zns
@@ -471,7 +471,7 @@ class PerturbFlow(nn.Module):
471
471
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))
472
472
 
473
473
  def model3(self, xs, ys, embeds=None):
474
- pyro.module('PerturbFlow', self)
474
+ pyro.module('DensityFlow', self)
475
475
 
476
476
  eps = torch.finfo(xs.dtype).eps
477
477
  batch_size = xs.size(0)
@@ -567,7 +567,7 @@ class PerturbFlow(nn.Module):
567
567
  zns = embeds
568
568
 
569
569
  def model4(self, xs, us, ys, embeds=None):
570
- pyro.module('PerturbFlow', self)
570
+ pyro.module('DensityFlow', self)
571
571
 
572
572
  eps = torch.finfo(xs.dtype).eps
573
573
  batch_size = xs.size(0)
@@ -631,7 +631,7 @@ class PerturbFlow(nn.Module):
631
631
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
632
632
  # else:
633
633
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
634
- zus = self._total_effects(zn_loc, us)
634
+ zus = self._total_effects(zns, us)
635
635
  zs = zns+zus
636
636
  else:
637
637
  zs = zns
@@ -824,9 +824,11 @@ class PerturbFlow(nn.Module):
824
824
 
825
825
  # perturbation effect
826
826
  ps = np.ones_like(us_i)
827
- dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
828
-
829
- zs = zs + dzs0 + dzs
827
+ if np.sum(np.abs(ps-us_i))>=1:
828
+ dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
829
+ zs = zs + dzs0 + dzs
830
+ else:
831
+ zs = zs + dzs0
830
832
 
831
833
  if library_sizes is None:
832
834
  library_sizes = np.sum(xs, axis=1, keepdims=True)
@@ -840,35 +842,36 @@ class PerturbFlow(nn.Module):
840
842
 
841
843
  return counts, zs
842
844
 
843
- def _cell_response(self, xs, factor_idx, perturb):
845
+ def _cell_response(self, zs, perturb_idx, perturb):
844
846
  #zns,_ = self.encoder_zn(xs)
845
- zns,_ = self._get_basal_embedding(xs)
847
+ #zns,_ = self._get_basal_embedding(xs)
848
+ zns = zs
846
849
  if perturb.ndim==2:
847
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
850
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
848
851
  else:
849
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
852
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
850
853
 
851
854
  return ms
852
855
 
853
856
  def get_cell_response(self,
854
- xs,
855
- factor_idx,
856
- perturb,
857
+ zs,
858
+ perturb_idx,
859
+ perturb_us,
857
860
  batch_size: int = 1024):
858
861
  """
859
862
  Return cells' changes in the latent space induced by specific perturbation of a factor
860
863
 
861
864
  """
862
- xs = self.preprocess(xs)
863
- xs = convert_to_tensor(xs, device=self.get_device())
864
- ps = convert_to_tensor(perturb, device=self.get_device())
865
- dataset = CustomDataset2(xs,ps)
865
+ #xs = self.preprocess(xs)
866
+ zs = convert_to_tensor(zs, device=self.get_device())
867
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
868
+ dataset = CustomDataset2(zs,ps)
866
869
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
867
870
 
868
871
  Z = []
869
872
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
870
- for X_batch, P_batch, _ in dataloader:
871
- zns = self._cell_response(X_batch, factor_idx, P_batch)
873
+ for Z_batch, P_batch, _ in dataloader:
874
+ zns = self._cell_response(Z_batch, perturb_idx, P_batch)
872
875
  Z.append(tensor_to_numpy(zns))
873
876
  pbar.update(1)
874
877
 
@@ -905,7 +908,7 @@ class PerturbFlow(nn.Module):
905
908
  R = np.concatenate(R)
906
909
  return R
907
910
 
908
- def _count(self,concentrate, library_size=None):
911
+ def _count(self, concentrate, library_size=None):
909
912
  if self.loss_func == 'bernoulli':
910
913
  #counts = self.sigmoid(concentrate)
911
914
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -913,28 +916,17 @@ class PerturbFlow(nn.Module):
913
916
  rate = concentrate.exp()
914
917
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
915
918
  counts = theta * library_size
916
- #counts = dist.Poisson(rate=rate).to_event(1).mean
917
- return counts
918
-
919
- def _count_sample(self,concentrate):
920
- if self.loss_func == 'bernoulli':
921
- logits = concentrate
922
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
923
- else:
924
- counts = self._count(concentrate=concentrate)
925
- counts = dist.Poisson(rate=counts).to_event(1).sample()
926
919
  return counts
927
920
 
928
921
  def get_counts(self, zs, library_sizes,
929
- batch_size: int = 1024,
930
- use_sampler: bool = False):
922
+ batch_size: int = 1024):
931
923
 
932
924
  zs = convert_to_tensor(zs, device=self.get_device())
933
925
 
934
926
  if type(library_sizes) == list:
935
- library_sizes = np.array(library_sizes).view(-1,1)
927
+ library_sizes = np.array(library_sizes).reshape(-1,1)
936
928
  elif len(library_sizes.shape)==1:
937
- library_sizes = library_sizes.view(-1,1)
929
+ library_sizes = library_sizes.reshape(-1,1)
938
930
  ls = convert_to_tensor(library_sizes, device=self.get_device())
939
931
 
940
932
  dataset = CustomDataset2(zs,ls)
@@ -944,10 +936,7 @@ class PerturbFlow(nn.Module):
944
936
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
945
937
  for Z_batch, L_batch, _ in dataloader:
946
938
  concentrate = self._get_expression_response(Z_batch)
947
- if use_sampler:
948
- counts = self._count_sample(concentrate)
949
- else:
950
- counts = self._count(concentrate, L_batch)
939
+ counts = self._count(concentrate, L_batch)
951
940
  E.append(tensor_to_numpy(counts))
952
941
  pbar.update(1)
953
942
 
@@ -970,7 +959,7 @@ class PerturbFlow(nn.Module):
970
959
  us = None,
971
960
  ys = None,
972
961
  zs = None,
973
- num_epochs: int = 200,
962
+ num_epochs: int = 500,
974
963
  learning_rate: float = 0.0001,
975
964
  batch_size: int = 256,
976
965
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -981,7 +970,7 @@ class PerturbFlow(nn.Module):
981
970
  threshold: int = 0,
982
971
  use_jax: bool = True):
983
972
  """
984
- Train the PerturbFlow model.
973
+ Train the DensityFlow model.
985
974
 
986
975
  Parameters
987
976
  ----------
@@ -1007,7 +996,7 @@ class PerturbFlow(nn.Module):
1007
996
  Parameter for optimization.
1008
997
  use_jax
1009
998
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
1010
- the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
999
+ the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
1011
1000
  """
1012
1001
  xs = self.preprocess(xs, threshold=threshold)
1013
1002
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1125,12 +1114,12 @@ class PerturbFlow(nn.Module):
1125
1114
 
1126
1115
 
1127
1116
  EXAMPLE_RUN = (
1128
- "example run: PerturbFlow --help"
1117
+ "example run: DensityFlow --help"
1129
1118
  )
1130
1119
 
1131
1120
  def parse_args():
1132
1121
  parser = argparse.ArgumentParser(
1133
- description="PerturbFlow\n{}".format(EXAMPLE_RUN))
1122
+ description="DensityFlow\n{}".format(EXAMPLE_RUN))
1134
1123
 
1135
1124
  parser.add_argument(
1136
1125
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1317,7 +1306,7 @@ def main():
1317
1306
  cell_factor_size = 0 if us is None else us.shape[1]
1318
1307
 
1319
1308
  ###########################################
1320
- perturbflow = PerturbFlow(
1309
+ DensityFlow = DensityFlow(
1321
1310
  input_size=input_size,
1322
1311
  cell_factor_size=cell_factor_size,
1323
1312
  inverse_dispersion=args.inverse_dispersion,
@@ -1336,7 +1325,7 @@ def main():
1336
1325
  dtype=dtype,
1337
1326
  )
1338
1327
 
1339
- perturbflow.fit(xs, us=us,
1328
+ DensityFlow.fit(xs, us=us,
1340
1329
  num_epochs=args.num_epochs,
1341
1330
  learning_rate=args.learning_rate,
1342
1331
  batch_size=args.batch_size,
@@ -1348,9 +1337,9 @@ def main():
1348
1337
 
1349
1338
  if args.save_model is not None:
1350
1339
  if args.save_model.endswith('gz'):
1351
- PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
1340
+ DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
1352
1341
  else:
1353
- PerturbFlow.save_model(perturbflow, args.save_model)
1342
+ DensityFlow.save_model(DensityFlow, args.save_model)
1354
1343
 
1355
1344
 
1356
1345
 
@@ -99,17 +99,17 @@ class SURE(nn.Module):
99
99
  cell_factor_size: int = 0,
100
100
  supervised_mode: bool = False,
101
101
  z_dim: int = 10,
102
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
103
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
102
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
103
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
104
104
  inverse_dispersion: float = 10.0,
105
105
  use_zeroinflate: bool = True,
106
- hidden_layers: list = [300],
106
+ hidden_layers: list = [500],
107
107
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
108
108
  nn_dropout: float = 0.1,
109
109
  post_layer_fct: list = ['layernorm'],
110
110
  post_act_fct: list = None,
111
111
  config_enum: str = 'parallel',
112
- use_cuda: bool = False,
112
+ use_cuda: bool = True,
113
113
  seed: int = 42,
114
114
  dtype = torch.float32, # type: ignore
115
115
  ):
@@ -817,7 +817,7 @@ class SURE(nn.Module):
817
817
  us = None,
818
818
  ys = None,
819
819
  zs = None,
820
- num_epochs: int = 200,
820
+ num_epochs: int = 500,
821
821
  learning_rate: float = 0.0001,
822
822
  batch_size: int = 256,
823
823
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -826,7 +826,7 @@ class SURE(nn.Module):
826
826
  decay_rate: float = 0.9,
827
827
  config_enum: str = 'parallel',
828
828
  threshold: int = 0,
829
- use_jax: bool = False):
829
+ use_jax: bool = True):
830
830
  """
831
831
  Train the SURE model.
832
832
 
@@ -1,12 +1,12 @@
1
1
  from .SURE import SURE
2
- from .PerturbFlow import PerturbFlow
2
+ from .DensityFlow import DensityFlow
3
3
 
4
4
  from . import utils
5
5
  from . import codebook
6
6
  from . import SURE
7
- from . import PerturbFlow
7
+ from . import DensityFlow
8
8
  from . import atac
9
9
  from . import flow
10
10
  from . import perturb
11
11
 
12
- __all__ = ['SURE', 'PerturbFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
12
+ __all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
@@ -41,6 +41,18 @@ class VectorFieldEval:
41
41
  divergence[np.isnan(divergence)] = 0
42
42
 
43
43
  return divergence
44
+
45
+ def movement_stats(self,vectors):
46
+ return calculate_movement_stats(vectors)
47
+
48
+ def direction_stats(self, vectors):
49
+ return calculate_direction_stats(vectors)
50
+
51
+ def movement_energy(self, vectors, masses=None):
52
+ return calculate_movement_energy(vectors, masses)
53
+
54
+ def movement_divergence(self, positions, vectors):
55
+ return calculate_movement_divergence(positions, vectors)
44
56
 
45
57
 
46
58
  def calculate_movement_stats(vectors):
@@ -1,5 +1,6 @@
1
1
  import re
2
2
  import numpy as np
3
+ import pandas as pd
3
4
  from numba import njit
4
5
  from itertools import chain
5
6
  from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
8
9
  class LabelMatrix:
9
10
  def __init__(self):
10
11
  self.labels_ = None
12
+ self.control_label = None
13
+ self.sep_pattern = None
11
14
 
12
15
  def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
13
16
  if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
24
27
  mat = np.delete(mat, idx, axis=1)
25
28
  self.labels_ = np.delete(self.labels_, idx)
26
29
 
30
+ self.control_label = control_label
31
+ self.sep_pattern=sep_pattern
32
+
27
33
  return mat
28
-
34
+
35
+ def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
36
+ sep_pattern = self.sep_pattern
37
+ if speedup=='none':
38
+ mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
39
+ elif speedup=='vectorize':
40
+ mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
41
+ elif speedup=='parallel':
42
+ mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
43
+
44
+ mat_df = pd.DataFrame(mat, columns=labels_)
45
+
46
+ labels_valid = [x for x in labels_ if x in self.labels_]
47
+ mat_df = mat_df[labels_valid]
48
+
49
+ mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
50
+ mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
51
+ mat_valid_df[labels_valid] = mat_df
52
+
53
+ return mat_valid_df.values
54
+
29
55
  def inverse_transform(self, matrix):
30
56
  return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
31
57
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.1.84
3
+ Version: 2.2.14
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,7 +1,7 @@
1
1
  LICENSE
2
2
  README.md
3
3
  setup.py
4
- SURE/PerturbFlow.py
4
+ SURE/DensityFlow.py
5
5
  SURE/SURE.py
6
6
  SURE/__init__.py
7
7
  SURE/assembly/__init__.py
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='2.1.84',
8
+ version='2.2.14',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes