SURE-tools 2.2.7.tar.gz → 2.2.15.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30):
  1. {sure_tools-2.2.7 → sure_tools-2.2.15}/PKG-INFO +1 -1
  2. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/DensityFlow.py +18 -39
  3. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/perturb/perturb.py +27 -1
  4. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/PKG-INFO +1 -1
  5. {sure_tools-2.2.7 → sure_tools-2.2.15}/setup.py +1 -1
  6. {sure_tools-2.2.7 → sure_tools-2.2.15}/LICENSE +0 -0
  7. {sure_tools-2.2.7 → sure_tools-2.2.15}/README.md +0 -0
  8. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/SURE.py +0 -0
  9. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/__init__.py +0 -0
  10. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/assembly/__init__.py +0 -0
  11. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/assembly/assembly.py +0 -0
  12. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/assembly/atlas.py +0 -0
  13. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/atac/__init__.py +0 -0
  14. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/atac/utils.py +0 -0
  15. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/codebook/__init__.py +0 -0
  16. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/codebook/codebook.py +0 -0
  17. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/flow/__init__.py +0 -0
  18. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/flow/flow_stats.py +0 -0
  19. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/flow/plot_quiver.py +0 -0
  20. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/perturb/__init__.py +0 -0
  21. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/utils/__init__.py +0 -0
  22. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/utils/custom_mlp.py +0 -0
  23. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/utils/queue.py +0 -0
  24. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE/utils/utils.py +0 -0
  25. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/SOURCES.txt +0 -0
  26. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/dependency_links.txt +0 -0
  27. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/entry_points.txt +0 -0
  28. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/requires.txt +0 -0
  29. {sure_tools-2.2.7 → sure_tools-2.2.15}/SURE_tools.egg-info/top_level.txt +0 -0
  30. {sure_tools-2.2.7 → sure_tools-2.2.15}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.2.7
3
+ Version: 2.2.15
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -696,7 +696,7 @@ class DensityFlow(nn.Module):
696
696
  """
697
697
  Return the mean part of metacell codebook
698
698
  """
699
- cb = self._get_metacell_coordinates()
699
+ cb = self._get_codebook()
700
700
  cb = tensor_to_numpy(cb)
701
701
  return cb
702
702
 
@@ -842,47 +842,42 @@ class DensityFlow(nn.Module):
842
842
 
843
843
  return counts, zs
844
844
 
845
- def _cell_response(self, xs, factor_idx, perturb):
845
+ def _cell_response(self, zs, perturb_idx, perturb):
846
846
  #zns,_ = self.encoder_zn(xs)
847
- zns,_ = self._get_basal_embedding(xs)
847
+ #zns,_ = self._get_basal_embedding(xs)
848
+ zns = zs
848
849
  if perturb.ndim==2:
849
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
850
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
850
851
  else:
851
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
852
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
852
853
 
853
854
  return ms
854
855
 
855
856
  def get_cell_response(self,
856
- xs,
857
- factor_idx,
858
- perturb,
857
+ zs,
858
+ perturb_idx,
859
+ perturb_us,
859
860
  batch_size: int = 1024):
860
861
  """
861
862
  Return cells' changes in the latent space induced by specific perturbation of a factor
862
863
 
863
864
  """
864
- xs = self.preprocess(xs)
865
- xs = convert_to_tensor(xs, device=self.get_device())
866
- ps = convert_to_tensor(perturb, device=self.get_device())
867
- dataset = CustomDataset2(xs,ps)
865
+ #xs = self.preprocess(xs)
866
+ zs = convert_to_tensor(zs, device=self.get_device())
867
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
868
+ dataset = CustomDataset2(zs,ps)
868
869
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
869
870
 
870
871
  Z = []
871
872
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
872
- for X_batch, P_batch, _ in dataloader:
873
- zns = self._cell_response(X_batch, factor_idx, P_batch)
873
+ for Z_batch, P_batch, _ in dataloader:
874
+ zns = self._cell_response(Z_batch, perturb_idx, P_batch)
874
875
  Z.append(tensor_to_numpy(zns))
875
876
  pbar.update(1)
876
877
 
877
878
  Z = np.concatenate(Z)
878
879
  return Z
879
880
 
880
- def get_metacell_response(self, factor_idx, perturb):
881
- zs = self._get_codebook()
882
- ps = convert_to_tensor(perturb, device=self.get_device())
883
- ms = self.cell_factor_effect[factor_idx]([zs,ps])
884
- return tensor_to_numpy(ms)
885
-
886
881
  def _get_expression_response(self, delta_zs):
887
882
  return self.decoder_concentrate(delta_zs)
888
883
 
@@ -907,7 +902,7 @@ class DensityFlow(nn.Module):
907
902
  R = np.concatenate(R)
908
903
  return R
909
904
 
910
- def _count(self,concentrate, library_size=None):
905
+ def _count(self, concentrate, library_size=None):
911
906
  if self.loss_func == 'bernoulli':
912
907
  #counts = self.sigmoid(concentrate)
913
908
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -917,18 +912,8 @@ class DensityFlow(nn.Module):
917
912
  counts = theta * library_size
918
913
  return counts
919
914
 
920
- def _count_sample(self,concentrate):
921
- if self.loss_func == 'bernoulli':
922
- logits = concentrate
923
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
924
- else:
925
- counts = self._count(concentrate=concentrate)
926
- counts = dist.Poisson(rate=counts).to_event(1).sample()
927
- return counts
928
-
929
915
  def get_counts(self, zs, library_sizes,
930
- batch_size: int = 1024,
931
- use_sampler: bool = False):
916
+ batch_size: int = 1024):
932
917
 
933
918
  zs = convert_to_tensor(zs, device=self.get_device())
934
919
 
@@ -945,10 +930,7 @@ class DensityFlow(nn.Module):
945
930
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
946
931
  for Z_batch, L_batch, _ in dataloader:
947
932
  concentrate = self._get_expression_response(Z_batch)
948
- if use_sampler:
949
- counts = self._count_sample(concentrate)
950
- else:
951
- counts = self._count(concentrate, L_batch)
933
+ counts = self._count(concentrate, L_batch)
952
934
  E.append(tensor_to_numpy(counts))
953
935
  pbar.update(1)
954
936
 
@@ -1093,9 +1075,6 @@ class DensityFlow(nn.Module):
1093
1075
  # Update progress bar
1094
1076
  pbar.set_postfix({'loss': str_loss})
1095
1077
  pbar.update(1)
1096
-
1097
- if self.loss_func == 'negbinomial':
1098
- self.inverse_dispersion = pyro.param("inverse_dispersion")
1099
1078
 
1100
1079
  @classmethod
1101
1080
  def save_model(cls, model, file_path, compression=False):
@@ -1,5 +1,6 @@
1
1
  import re
2
2
  import numpy as np
3
+ import pandas as pd
3
4
  from numba import njit
4
5
  from itertools import chain
5
6
  from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
8
9
  class LabelMatrix:
9
10
  def __init__(self):
10
11
  self.labels_ = None
12
+ self.control_label = None
13
+ self.sep_pattern = None
11
14
 
12
15
  def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
13
16
  if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
24
27
  mat = np.delete(mat, idx, axis=1)
25
28
  self.labels_ = np.delete(self.labels_, idx)
26
29
 
30
+ self.control_label = control_label
31
+ self.sep_pattern=sep_pattern
32
+
27
33
  return mat
28
-
34
+
35
+ def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
36
+ sep_pattern = self.sep_pattern
37
+ if speedup=='none':
38
+ mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
39
+ elif speedup=='vectorize':
40
+ mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
41
+ elif speedup=='parallel':
42
+ mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
43
+
44
+ mat_df = pd.DataFrame(mat, columns=labels_)
45
+
46
+ labels_valid = [x for x in labels_ if x in self.labels_]
47
+ mat_df = mat_df[labels_valid]
48
+
49
+ mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
50
+ mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
51
+ mat_valid_df[labels_valid] = mat_df
52
+
53
+ return mat_valid_df.values
54
+
29
55
  def inverse_transform(self, matrix):
30
56
  return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
31
57
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.2.7
3
+ Version: 2.2.15
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
5
 
6
6
  setup(
7
7
  name='SURE-tools',
8
- version='2.2.7',
8
+ version='2.2.15',
9
9
  description='Succinct Representation of Single Cells',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
File without changes
File without changes
File without changes
File without changes
File without changes