SURE-tools 2.2.7__py3-none-any.whl → 2.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. (The original page linked to more details at this point; that link was not preserved in this copy.)

SURE/DensityFlow.py CHANGED
@@ -59,12 +59,13 @@ class DensityFlow(nn.Module):
59
59
  input_size: int,
60
60
  codebook_size: int = 200,
61
61
  cell_factor_size: int = 0,
62
+ turn_off_cell_specific: bool = False,
62
63
  supervised_mode: bool = False,
63
64
  z_dim: int = 10,
64
65
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
65
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
66
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
66
67
  inverse_dispersion: float = 10.0,
67
- use_zeroinflate: bool = True,
68
+ use_zeroinflate: bool = False,
68
69
  hidden_layers: list = [500],
69
70
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
70
71
  nn_dropout: float = 0.1,
@@ -102,6 +103,7 @@ class DensityFlow(nn.Module):
102
103
  else:
103
104
  self.use_bias = [not zero_bias] * self.cell_factor_size
104
105
  #self.use_bias = not zero_bias
106
+ self.turn_off_cell_specific = turn_off_cell_specific
105
107
 
106
108
  self.codebook_weights = None
107
109
 
@@ -203,27 +205,51 @@ class DensityFlow(nn.Module):
203
205
  self.cell_factor_effect = nn.ModuleList()
204
206
  for i in np.arange(self.cell_factor_size):
205
207
  if self.use_bias[i]:
206
- self.cell_factor_effect.append(MLP(
207
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
208
- activation=activate_fct,
209
- output_activation=None,
210
- post_layer_fct=post_layer_fct,
211
- post_act_fct=post_act_fct,
212
- allow_broadcast=self.allow_broadcast,
213
- use_cuda=self.use_cuda,
208
+ if self.turn_off_cell_specific:
209
+ self.cell_factor_effect.append(MLP(
210
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
211
+ activation=activate_fct,
212
+ output_activation=None,
213
+ post_layer_fct=post_layer_fct,
214
+ post_act_fct=post_act_fct,
215
+ allow_broadcast=self.allow_broadcast,
216
+ use_cuda=self.use_cuda,
217
+ )
218
+ )
219
+ else:
220
+ self.cell_factor_effect.append(MLP(
221
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
222
+ activation=activate_fct,
223
+ output_activation=None,
224
+ post_layer_fct=post_layer_fct,
225
+ post_act_fct=post_act_fct,
226
+ allow_broadcast=self.allow_broadcast,
227
+ use_cuda=self.use_cuda,
228
+ )
214
229
  )
215
- )
216
230
  else:
217
- self.cell_factor_effect.append(ZeroBiasMLP(
218
- [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
219
- activation=activate_fct,
220
- output_activation=None,
221
- post_layer_fct=post_layer_fct,
222
- post_act_fct=post_act_fct,
223
- allow_broadcast=self.allow_broadcast,
224
- use_cuda=self.use_cuda,
231
+ if self.turn_off_cell_specific:
232
+ self.cell_factor_effect.append(ZeroBiasMLP(
233
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
234
+ activation=activate_fct,
235
+ output_activation=None,
236
+ post_layer_fct=post_layer_fct,
237
+ post_act_fct=post_act_fct,
238
+ allow_broadcast=self.allow_broadcast,
239
+ use_cuda=self.use_cuda,
240
+ )
241
+ )
242
+ else:
243
+ self.cell_factor_effect.append(ZeroBiasMLP(
244
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
245
+ activation=activate_fct,
246
+ output_activation=None,
247
+ post_layer_fct=post_layer_fct,
248
+ post_act_fct=post_act_fct,
249
+ allow_broadcast=self.allow_broadcast,
250
+ use_cuda=self.use_cuda,
251
+ )
225
252
  )
226
- )
227
253
 
228
254
  self.decoder_concentrate = MLP(
229
255
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
@@ -676,9 +702,17 @@ class DensityFlow(nn.Module):
676
702
  zus = None
677
703
  for i in np.arange(self.cell_factor_size):
678
704
  if i==0:
679
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
705
+ #if self.turn_off_cell_specific:
706
+ # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
707
+ #else:
708
+ # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
709
+ zus = self._cell_response(zns, i, us)
680
710
  else:
681
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
711
+ #if self.turn_off_cell_specific:
712
+ # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
713
+ #else:
714
+ # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
715
+ zus = zus + self._cell_response(zns, i, us)
682
716
  return zus
683
717
 
684
718
  def _get_codebook_identity(self):
@@ -696,7 +730,7 @@ class DensityFlow(nn.Module):
696
730
  """
697
731
  Return the mean part of metacell codebook
698
732
  """
699
- cb = self._get_metacell_coordinates()
733
+ cb = self._get_codebook()
700
734
  cb = tensor_to_numpy(cb)
701
735
  return cb
702
736
 
@@ -820,12 +854,12 @@ class DensityFlow(nn.Module):
820
854
  us_i = us[:,pert_idx].reshape(-1,1)
821
855
 
822
856
  # factor effect of xs
823
- dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
857
+ dzs0 = self.get_cell_response(zs, factor_idx=pert_idx, perturb=us_i)
824
858
 
825
859
  # perturbation effect
826
860
  ps = np.ones_like(us_i)
827
861
  if np.sum(np.abs(ps-us_i))>=1:
828
- dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
862
+ dzs = self.get_cell_response(zs, factor_idx=pert_idx, perturb=ps)
829
863
  zs = zs + dzs0 + dzs
830
864
  else:
831
865
  zs = zs + dzs0
@@ -842,47 +876,48 @@ class DensityFlow(nn.Module):
842
876
 
843
877
  return counts, zs
844
878
 
845
- def _cell_response(self, xs, factor_idx, perturb):
879
+ def _cell_response(self, zs, perturb_idx, perturb):
846
880
  #zns,_ = self.encoder_zn(xs)
847
- zns,_ = self._get_basal_embedding(xs)
881
+ #zns,_ = self._get_basal_embedding(xs)
882
+ zns = zs
848
883
  if perturb.ndim==2:
849
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
884
+ if self.turn_off_cell_specific:
885
+ ms = self.cell_factor_effect[perturb_idx](perturb)
886
+ else:
887
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
850
888
  else:
851
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
889
+ if self.turn_off_cell_specific:
890
+ ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
891
+ else:
892
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])
852
893
 
853
894
  return ms
854
895
 
855
896
  def get_cell_response(self,
856
- xs,
857
- factor_idx,
858
- perturb,
897
+ zs,
898
+ perturb_idx,
899
+ perturb_us,
859
900
  batch_size: int = 1024):
860
901
  """
861
902
  Return cells' changes in the latent space induced by specific perturbation of a factor
862
903
 
863
904
  """
864
- xs = self.preprocess(xs)
865
- xs = convert_to_tensor(xs, device=self.get_device())
866
- ps = convert_to_tensor(perturb, device=self.get_device())
867
- dataset = CustomDataset2(xs,ps)
905
+ #xs = self.preprocess(xs)
906
+ zs = convert_to_tensor(zs, device=self.get_device())
907
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
908
+ dataset = CustomDataset2(zs,ps)
868
909
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
869
910
 
870
911
  Z = []
871
912
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
872
- for X_batch, P_batch, _ in dataloader:
873
- zns = self._cell_response(X_batch, factor_idx, P_batch)
913
+ for Z_batch, P_batch, _ in dataloader:
914
+ zns = self._cell_response(Z_batch, perturb_idx, P_batch)
874
915
  Z.append(tensor_to_numpy(zns))
875
916
  pbar.update(1)
876
917
 
877
918
  Z = np.concatenate(Z)
878
919
  return Z
879
920
 
880
- def get_metacell_response(self, factor_idx, perturb):
881
- zs = self._get_codebook()
882
- ps = convert_to_tensor(perturb, device=self.get_device())
883
- ms = self.cell_factor_effect[factor_idx]([zs,ps])
884
- return tensor_to_numpy(ms)
885
-
886
921
  def _get_expression_response(self, delta_zs):
887
922
  return self.decoder_concentrate(delta_zs)
888
923
 
@@ -907,7 +942,7 @@ class DensityFlow(nn.Module):
907
942
  R = np.concatenate(R)
908
943
  return R
909
944
 
910
- def _count(self,concentrate, library_size=None):
945
+ def _count(self, concentrate, library_size=None):
911
946
  if self.loss_func == 'bernoulli':
912
947
  #counts = self.sigmoid(concentrate)
913
948
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -917,18 +952,8 @@ class DensityFlow(nn.Module):
917
952
  counts = theta * library_size
918
953
  return counts
919
954
 
920
- def _count_sample(self,concentrate):
921
- if self.loss_func == 'bernoulli':
922
- logits = concentrate
923
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
924
- else:
925
- counts = self._count(concentrate=concentrate)
926
- counts = dist.Poisson(rate=counts).to_event(1).sample()
927
- return counts
928
-
929
955
  def get_counts(self, zs, library_sizes,
930
- batch_size: int = 1024,
931
- use_sampler: bool = False):
956
+ batch_size: int = 1024):
932
957
 
933
958
  zs = convert_to_tensor(zs, device=self.get_device())
934
959
 
@@ -945,10 +970,7 @@ class DensityFlow(nn.Module):
945
970
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
946
971
  for Z_batch, L_batch, _ in dataloader:
947
972
  concentrate = self._get_expression_response(Z_batch)
948
- if use_sampler:
949
- counts = self._count_sample(concentrate)
950
- else:
951
- counts = self._count(concentrate, L_batch)
973
+ counts = self._count(concentrate, L_batch)
952
974
  E.append(tensor_to_numpy(counts))
953
975
  pbar.update(1)
954
976
 
@@ -1093,9 +1115,6 @@ class DensityFlow(nn.Module):
1093
1115
  # Update progress bar
1094
1116
  pbar.set_postfix({'loss': str_loss})
1095
1117
  pbar.update(1)
1096
-
1097
- if self.loss_func == 'negbinomial':
1098
- self.inverse_dispersion = pyro.param("inverse_dispersion")
1099
1118
 
1100
1119
  @classmethod
1101
1120
  def save_model(cls, model, file_path, compression=False):
@@ -1359,5 +1378,4 @@ def main():
1359
1378
 
1360
1379
 
1361
1380
  if __name__ == "__main__":
1362
-
1363
1381
  main()
SURE/perturb/perturb.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import re
2
2
  import numpy as np
3
+ import pandas as pd
3
4
  from numba import njit
4
5
  from itertools import chain
5
6
  from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
8
9
  class LabelMatrix:
9
10
  def __init__(self):
10
11
  self.labels_ = None
12
+ self.control_label = None
13
+ self.sep_pattern = None
11
14
 
12
15
  def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
13
16
  if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
24
27
  mat = np.delete(mat, idx, axis=1)
25
28
  self.labels_ = np.delete(self.labels_, idx)
26
29
 
30
+ self.control_label = control_label
31
+ self.sep_pattern=sep_pattern
32
+
27
33
  return mat
28
-
34
+
35
+ def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
36
+ sep_pattern = self.sep_pattern
37
+ if speedup=='none':
38
+ mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
39
+ elif speedup=='vectorize':
40
+ mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
41
+ elif speedup=='parallel':
42
+ mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
43
+
44
+ mat_df = pd.DataFrame(mat, columns=labels_)
45
+
46
+ labels_valid = [x for x in labels_ if x in self.labels_]
47
+ mat_df = mat_df[labels_valid]
48
+
49
+ mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
50
+ mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
51
+ mat_valid_df[labels_valid] = mat_df
52
+
53
+ return mat_valid_df.values
54
+
29
55
  def inverse_transform(self, matrix):
30
56
  return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)
31
57
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 2.2.7
3
+ Version: 2.2.18
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -1,4 +1,4 @@
1
- SURE/DensityFlow.py,sha256=JE0Cd5jdNPuB5mMQHcOdaUPg8j9lvg20vHqXjxv1dlI,54844
1
+ SURE/DensityFlow.py,sha256=FQ5LT-5xo_O3Qa5I0hxYBMq8f7HvapxfwTQj-oc3iyI,56086
2
2
  SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
3
3
  SURE/__init__.py,sha256=NVp22RCHrhSwHNMomABC-eftoCYvt7vV1XOzim-UZHE,293
4
4
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
@@ -12,14 +12,14 @@ SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
12
12
  SURE/flow/flow_stats.py,sha256=6SzNMT59WRFRP1nC6bvpBPF7BugWnkIS_DSlr4S-Ez0,11338
13
13
  SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
14
14
  SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
15
- SURE/perturb/perturb.py,sha256=1iSsCePcwkA2CyM1nCdq_G8gogUNjhMH0BfhhvhpJQk,5037
15
+ SURE/perturb/perturb.py,sha256=ey7cxsM1tO1MW4UaE_MLpLHK87CjvXzn2CBPtvv1VZ0,6116
16
16
  SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
17
17
  SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
18
18
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
19
19
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
20
- sure_tools-2.2.7.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
- sure_tools-2.2.7.dist-info/METADATA,sha256=JShESEmAZ-N7lzOqnjcfmkwM1s7rXA3ZxcWypCn1f1w,2677
22
- sure_tools-2.2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
- sure_tools-2.2.7.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
- sure_tools-2.2.7.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
- sure_tools-2.2.7.dist-info/RECORD,,
20
+ sure_tools-2.2.18.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
21
+ sure_tools-2.2.18.dist-info/METADATA,sha256=QIr1OsgtniZmJyGGBnA-Pfmm385gGJl6W6vMF1KxgRY,2678
22
+ sure_tools-2.2.18.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
23
+ sure_tools-2.2.18.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
24
+ sure_tools-2.2.18.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
25
+ sure_tools-2.2.18.dist-info/RECORD,,