SURE-tools 2.1.83__py3-none-any.whl → 2.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SURE/PerturbFlow.py → SURE/DensityFlow.py CHANGED
@@ -54,18 +54,19 @@ def set_random_seed(seed):
  # Set seed for Pyro
  pyro.set_rng_seed(seed)

- class PerturbFlow(nn.Module):
+ class DensityFlow(nn.Module):
  def __init__(self,
  input_size: int,
  codebook_size: int = 200,
  cell_factor_size: int = 0,
+ turn_off_cell_specific: bool = False,
  supervised_mode: bool = False,
  z_dim: int = 10,
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'multinomial',
  inverse_dispersion: float = 10.0,
  use_zeroinflate: bool = False,
- hidden_layers: list = [300],
+ hidden_layers: list = [500],
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
  nn_dropout: float = 0.1,
  post_layer_fct: list = ['layernorm'],
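Besides the rename, this hunk adds a `turn_off_cell_specific` switch and changes two defaults (`loss_func` 'poisson' → 'multinomial', `hidden_layers` [300] → [500]). A minimal construction sketch using only parameters visible in this hunk; the toy counts matrix and covariate are illustrative, not from the package:

    import numpy as np
    from SURE import DensityFlow

    # toy data: 1,000 cells x 2,000 genes, one binary perturbation covariate
    xs = np.random.poisson(1.0, size=(1000, 2000)).astype('float32')
    us = np.random.binomial(1, 0.5, size=(1000, 1)).astype('float32')

    model = DensityFlow(
        input_size=xs.shape[1],
        cell_factor_size=us.shape[1],
        turn_off_cell_specific=False,   # True makes factor effects cell-independent
        loss_func='multinomial',        # new default in 2.2.x
        hidden_layers=[500],            # new default in 2.2.x
    )
    # use_jax=True (the new default) can fail outside shell runs, per the docstring later in this diff
    model.fit(xs, us=us, use_jax=False)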
@@ -102,6 +103,7 @@ class PerturbFlow(nn.Module):
  else:
  self.use_bias = [not zero_bias] * self.cell_factor_size
  #self.use_bias = not zero_bias
+ self.turn_off_cell_specific = turn_off_cell_specific

  self.codebook_weights = None

@@ -203,27 +205,51 @@ class PerturbFlow(nn.Module):
  self.cell_factor_effect = nn.ModuleList()
  for i in np.arange(self.cell_factor_size):
  if self.use_bias[i]:
- self.cell_factor_effect.append(MLP(
- [self.code_size+1] + self.decoder_hidden_layers + [self.latent_dim],
- activation=activate_fct,
- output_activation=None,
- post_layer_fct=post_layer_fct,
- post_act_fct=post_act_fct,
- allow_broadcast=self.allow_broadcast,
- use_cuda=self.use_cuda,
+ if self.turn_off_cell_specific:
+ self.cell_factor_effect.append(MLP(
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
+ activation=activate_fct,
+ output_activation=None,
+ post_layer_fct=post_layer_fct,
+ post_act_fct=post_act_fct,
+ allow_broadcast=self.allow_broadcast,
+ use_cuda=self.use_cuda,
+ )
+ )
+ else:
+ self.cell_factor_effect.append(MLP(
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
+ activation=activate_fct,
+ output_activation=None,
+ post_layer_fct=post_layer_fct,
+ post_act_fct=post_act_fct,
+ allow_broadcast=self.allow_broadcast,
+ use_cuda=self.use_cuda,
+ )
  )
- )
  else:
- self.cell_factor_effect.append(ZeroBiasMLP(
- [self.code_size+1] + self.decoder_hidden_layers + [self.latent_dim],
- activation=activate_fct,
- output_activation=None,
- post_layer_fct=post_layer_fct,
- post_act_fct=post_act_fct,
- allow_broadcast=self.allow_broadcast,
- use_cuda=self.use_cuda,
+ if self.turn_off_cell_specific:
+ self.cell_factor_effect.append(ZeroBiasMLP(
+ [1] + self.decoder_hidden_layers + [self.latent_dim],
+ activation=activate_fct,
+ output_activation=None,
+ post_layer_fct=post_layer_fct,
+ post_act_fct=post_act_fct,
+ allow_broadcast=self.allow_broadcast,
+ use_cuda=self.use_cuda,
+ )
+ )
+ else:
+ self.cell_factor_effect.append(ZeroBiasMLP(
+ [self.latent_dim+1] + self.decoder_hidden_layers + [self.latent_dim],
+ activation=activate_fct,
+ output_activation=None,
+ post_layer_fct=post_layer_fct,
+ post_act_fct=post_act_fct,
+ allow_broadcast=self.allow_broadcast,
+ use_cuda=self.use_cuda,
+ )
  )
- )

  self.decoder_concentrate = MLP(
  [self.latent_dim] + self.decoder_hidden_layers + [self.input_size],
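The substantive change above is the input width of each per-factor effect network: with `turn_off_cell_specific=True` the network sees only the scalar factor value (input size 1), so every cell with the same dose gets the same latent shift; otherwise it sees the latent state concatenated with the dose (input size `latent_dim+1`; note this also replaces the old `code_size+1` conditioning). A shape-only sketch with plain `torch.nn` stand-ins for the package's MLP/ZeroBiasMLP wrappers:

    import torch
    import torch.nn as nn

    latent_dim, hidden = 10, 500

    cell_specific = nn.Sequential(   # input: [z, u] -> latent_dim + 1 features
        nn.Linear(latent_dim + 1, hidden), nn.ReLU(), nn.Linear(hidden, latent_dim))
    shared = nn.Sequential(          # input: u alone -> 1 feature
        nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, latent_dim))

    z = torch.randn(4, latent_dim)   # basal latent positions of 4 cells
    u = torch.ones(4, 1)             # identical dose for all 4 cells

    dz_specific = cell_specific(torch.cat([z, u], dim=-1))  # differs per cell
    dz_shared = shared(u)                                    # identical rows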
@@ -308,7 +334,7 @@ class PerturbFlow(nn.Module):
  return xs

  def model1(self, xs):
- pyro.module('PerturbFlow', self)
+ pyro.module('DensityFlow', self)

  eps = torch.finfo(xs.dtype).eps
  batch_size = xs.size(0)
@@ -387,7 +413,7 @@ class PerturbFlow(nn.Module):
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

  def model2(self, xs, us=None):
- pyro.module('PerturbFlow', self)
+ pyro.module('DensityFlow', self)

  eps = torch.finfo(xs.dtype).eps
  batch_size = xs.size(0)
@@ -429,7 +455,7 @@ class PerturbFlow(nn.Module):
  zns = pyro.sample('zn', dist.Gumbel(zn_loc, zn_scale).to_event(1))

  if self.cell_factor_size>0:
- zus = self._total_effects(ns, us)
+ zus = self._total_effects(zns, us)
  zs = zns+zus
  else:
  zs = zns
@@ -471,7 +497,7 @@ class PerturbFlow(nn.Module):
  ns = pyro.sample('n', dist.OneHotCategorical(logits=alpha))

  def model3(self, xs, ys, embeds=None):
- pyro.module('PerturbFlow', self)
+ pyro.module('DensityFlow', self)

  eps = torch.finfo(xs.dtype).eps
  batch_size = xs.size(0)
@@ -567,7 +593,7 @@ class PerturbFlow(nn.Module):
  zns = embeds

  def model4(self, xs, us, ys, embeds=None):
- pyro.module('PerturbFlow', self)
+ pyro.module('DensityFlow', self)

  eps = torch.finfo(xs.dtype).eps
  batch_size = xs.size(0)
@@ -631,7 +657,7 @@ class PerturbFlow(nn.Module):
  # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
  # else:
  # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
- zus = self._total_effects(ns, us)
+ zus = self._total_effects(zns, us)
  zs = zns+zus
  else:
  zs = zns
@@ -676,9 +702,17 @@ class PerturbFlow(nn.Module):
  zus = None
  for i in np.arange(self.cell_factor_size):
  if i==0:
- zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+ #if self.turn_off_cell_specific:
+ # zus = self.cell_factor_effect[i](us[:,i].reshape(-1,1))
+ #else:
+ # zus = self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+ zus = self._cell_response(zns, i, us[:,i].reshape(-1,1))
  else:
- zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+ #if self.turn_off_cell_specific:
+ # zus = zus + self.cell_factor_effect[i](us[:,i].reshape(-1,1))
+ #else:
+ # zus = zus + self.cell_factor_effect[i]([zns,us[:,i].reshape(-1,1)])
+ zus = zus + self._cell_response(zns, i, us[:,i].reshape(-1,1))
  return zus

  def _get_codebook_identity(self):
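`_total_effects` (the loop above) now routes every factor through `_cell_response` and composes factor effects additively in latent space. Stripped of the commented-out branches, the logic is equivalent to this sketch, with `cell_response` standing in for the bound method:

    import numpy as np

    def total_effects(cell_response, zns, us):
        # zus = sum_i f_i(zns, us[:, i]); each f_i returns a latent-space shift
        zus = None
        for i in range(us.shape[1]):
            dz = cell_response(zns, i, us[:, i].reshape(-1, 1))
            zus = dz if zus is None else zus + dz
        return zus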
@@ -696,7 +730,7 @@ class PerturbFlow(nn.Module):
  """
  Return the mean part of metacell codebook
  """
- cb = self._get_metacell_coordinates()
+ cb = self._get_codebook()
  cb = tensor_to_numpy(cb)
  return cb

@@ -820,13 +854,15 @@ class PerturbFlow(nn.Module):
  us_i = us[:,pert_idx].reshape(-1,1)

  # factor effect of xs
- dzs0 = self.get_cell_response(xs, factor_idx=pert_idx, perturb=us_i)
+ dzs0 = self.get_cell_response(zs, factor_idx=pert_idx, perturb=us_i)

  # perturbation effect
  ps = np.ones_like(us_i)
- dzs = self.get_cell_response(xs, factor_idx=pert_idx, perturb=ps)
-
- zs = zs + dzs0 + dzs
+ if np.sum(np.abs(ps-us_i))>=1:
+ dzs = self.get_cell_response(zs, factor_idx=pert_idx, perturb=ps)
+ zs = zs + dzs0 + dzs
+ else:
+ zs = zs + dzs0

  if library_sizes is None:
  library_sizes = np.sum(xs, axis=1, keepdims=True)
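The rewritten block applies the observed-dose effect `dzs0` unconditionally but adds the unit-dose effect `dzs` only when some cell's dose actually differs from 1; `np.sum(np.abs(ps-us_i))>=1` is that gate, presumably to avoid adding the same shift twice when the counterfactual dose equals the observed one. A tiny illustration of the gate alone:

    import numpy as np

    us_i = np.array([[0.], [1.], [1.]])                 # observed doses
    ps = np.ones_like(us_i)                             # counterfactual unit dose
    add_unit_effect = np.sum(np.abs(ps - us_i)) >= 1    # True here: one cell is at dose 0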
@@ -840,49 +876,48 @@ class PerturbFlow(nn.Module):

  return counts, zs

- def _cell_response(self, xs, factor_idx, perturb):
+ def _cell_response(self, zs, perturb_idx, perturb):
  #zns,_ = self.encoder_zn(xs)
  #zns,_ = self._get_basal_embedding(xs)
- zns = self._soft_assignments(xs)
+ zns = zs
  if perturb.ndim==2:
- ms = self.cell_factor_effect[factor_idx]([zns, perturb])
+ if self.turn_off_cell_specific:
+ ms = self.cell_factor_effect[perturb_idx](perturb)
+ else:
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb])
  else:
- ms = self.cell_factor_effect[factor_idx]([zns, perturb.reshape(-1,1)])
+ if self.turn_off_cell_specific:
+ ms = self.cell_factor_effect[perturb_idx](perturb.reshape(-1,1))
+ else:
+ ms = self.cell_factor_effect[perturb_idx]([zns, perturb.reshape(-1,1)])

  return ms

  def get_cell_response(self,
- xs,
- factor_idx,
- perturb,
+ zs,
+ perturb_idx,
+ perturb_us,
  batch_size: int = 1024):
  """
  Return cells' changes in the latent space induced by specific perturbation of a factor

  """
- xs = self.preprocess(xs)
- xs = convert_to_tensor(xs, device=self.get_device())
- ps = convert_to_tensor(perturb, device=self.get_device())
- dataset = CustomDataset2(xs,ps)
+ #xs = self.preprocess(xs)
+ zs = convert_to_tensor(zs, device=self.get_device())
+ ps = convert_to_tensor(perturb_us, device=self.get_device())
+ dataset = CustomDataset2(zs,ps)
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

  Z = []
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
- for X_batch, P_batch, _ in dataloader:
- zns = self._cell_response(X_batch, factor_idx, P_batch)
+ for Z_batch, P_batch, _ in dataloader:
+ zns = self._cell_response(Z_batch, perturb_idx, P_batch)
  Z.append(tensor_to_numpy(zns))
  pbar.update(1)

  Z = np.concatenate(Z)
  return Z

- def get_metacell_response(self, factor_idx, perturb):
- #zs = self._get_codebook()
- zs = self._get_codebook_identity()
- ps = convert_to_tensor(perturb, device=self.get_device())
- ms = self.cell_factor_effect[factor_idx]([zs,ps])
- return tensor_to_numpy(ms)
-
  def _get_expression_response(self, delta_zs):
  return self.decoder_concentrate(delta_zs)
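`get_cell_response` now consumes latent coordinates directly (`zs`, `perturb_idx`, `perturb_us`) instead of preprocessing raw expression, and `get_metacell_response` is removed without replacement in this hunk. A hedged usage sketch; how `zs` is obtained (some embedding accessor of the fitted model) is an assumption, so random stand-ins keep it self-contained:

    import numpy as np

    # zs: (n_cells, latent_dim) latent coordinates previously produced by the model
    zs = np.random.randn(100, 10).astype('float32')
    us = np.ones((100, 1), dtype='float32')     # unit dose for the first factor

    # model: the fitted DensityFlow instance from the construction sketch above
    dzs = model.get_cell_response(zs, perturb_idx=0, perturb_us=us)
    zs_perturbed = zs + dzs                     # shifted latent positions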
@@ -907,7 +942,7 @@ class PerturbFlow(nn.Module):
  R = np.concatenate(R)
  return R

- def _count(self,concentrate, library_size=None):
+ def _count(self, concentrate, library_size=None):
  if self.loss_func == 'bernoulli':
  #counts = self.sigmoid(concentrate)
  counts = dist.Bernoulli(logits=concentrate).to_event(1).mean
@@ -915,28 +950,17 @@ class PerturbFlow(nn.Module):
  rate = concentrate.exp()
  theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
  counts = theta * library_size
- #counts = dist.Poisson(rate=rate).to_event(1).mean
- return counts
-
- def _count_sample(self,concentrate):
- if self.loss_func == 'bernoulli':
- logits = concentrate
- counts = dist.Bernoulli(logits=logits).to_event(1).sample()
- else:
- counts = self._count(concentrate=concentrate)
- counts = dist.Poisson(rate=counts).to_event(1).sample()
  return counts

  def get_counts(self, zs, library_sizes,
- batch_size: int = 1024,
- use_sampler: bool = False):
+ batch_size: int = 1024):

  zs = convert_to_tensor(zs, device=self.get_device())

  if type(library_sizes) == list:
- library_sizes = np.array(library_sizes).view(-1,1)
+ library_sizes = np.array(library_sizes).reshape(-1,1)
  elif len(library_sizes.shape)==1:
- library_sizes = library_sizes.view(-1,1)
+ library_sizes = library_sizes.reshape(-1,1)
  ls = convert_to_tensor(library_sizes, device=self.get_device())

  dataset = CustomDataset2(zs,ls)
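Two fixes land here: the sampling path (`_count_sample`, `use_sampler`) is removed, and `library_sizes` normalization now uses `reshape(-1,1)`; the old `.view(-1,1)` raises a TypeError on NumPy arrays, whose `view` reinterprets the dtype rather than reshaping. For the multinomial branch of `_count`, `DirichletMultinomial(total_count=1, concentration=rate).mean` is just the normalized concentration, so expected counts reduce to proportions times library size:

    import torch

    concentrate = torch.randn(4, 2000)             # decoder output, log-scale
    library_size = torch.full((4, 1), 10_000.0)

    rate = concentrate.exp()
    theta = rate / rate.sum(dim=-1, keepdim=True)  # = DirichletMultinomial(1, rate).mean
    counts = theta * library_size                  # expected counts per cell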
@@ -946,10 +970,7 @@ class PerturbFlow(nn.Module):
  with tqdm(total=len(dataloader), desc='', unit='batch') as pbar:
  for Z_batch, L_batch, _ in dataloader:
  concentrate = self._get_expression_response(Z_batch)
- if use_sampler:
- counts = self._count_sample(concentrate)
- else:
- counts = self._count(concentrate, L_batch)
+ counts = self._count(concentrate, L_batch)
  E.append(tensor_to_numpy(counts))
  pbar.update(1)

@@ -972,7 +993,7 @@ class PerturbFlow(nn.Module):
  us = None,
  ys = None,
  zs = None,
- num_epochs: int = 200,
+ num_epochs: int = 500,
  learning_rate: float = 0.0001,
  batch_size: int = 256,
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -983,7 +1004,7 @@ class PerturbFlow(nn.Module):
  threshold: int = 0,
  use_jax: bool = True):
  """
- Train the PerturbFlow model.
+ Train the DensityFlow model.

  Parameters
  ----------
@@ -1009,7 +1030,7 @@ class PerturbFlow(nn.Module):
  Parameter for optimization.
  use_jax
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
- the Python script or Jupyter notebook. It is OK if it is used when runing PerturbFlow in the shell command.
+ the Python script or Jupyter notebook. It is OK if it is used when runing DensityFlow in the shell command.
  """
  xs = self.preprocess(xs, threshold=threshold)
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
@@ -1127,12 +1148,12 @@ class PerturbFlow(nn.Module):


  EXAMPLE_RUN = (
- "example run: PerturbFlow --help"
+ "example run: DensityFlow --help"
  )

  def parse_args():
  parser = argparse.ArgumentParser(
- description="PerturbFlow\n{}".format(EXAMPLE_RUN))
+ description="DensityFlow\n{}".format(EXAMPLE_RUN))

  parser.add_argument(
  "--cuda", action="store_true", help="use GPU(s) to speed up training"
@@ -1319,7 +1340,7 @@ def main():
  cell_factor_size = 0 if us is None else us.shape[1]

  ###########################################
- perturbflow = PerturbFlow(
+ DensityFlow = DensityFlow(
  input_size=input_size,
  cell_factor_size=cell_factor_size,
  inverse_dispersion=args.inverse_dispersion,
@@ -1338,7 +1359,7 @@ def main():
  dtype=dtype,
  )

- perturbflow.fit(xs, us=us,
+ DensityFlow.fit(xs, us=us,
  num_epochs=args.num_epochs,
  learning_rate=args.learning_rate,
  batch_size=args.batch_size,
@@ -1350,12 +1371,11 @@ def main():

  if args.save_model is not None:
  if args.save_model.endswith('gz'):
- PerturbFlow.save_model(perturbflow, args.save_model, compression=True)
+ DensityFlow.save_model(DensityFlow, args.save_model, compression=True)
  else:
- PerturbFlow.save_model(perturbflow, args.save_model)
+ DensityFlow.save_model(DensityFlow, args.save_model)



  if __name__ == "__main__":
-
  main()
SURE/SURE.py CHANGED
@@ -99,17 +99,17 @@ class SURE(nn.Module):
  cell_factor_size: int = 0,
  supervised_mode: bool = False,
  z_dim: int = 10,
- z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
- loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
+ z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'gumbel',
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'poisson',
  inverse_dispersion: float = 10.0,
  use_zeroinflate: bool = True,
- hidden_layers: list = [300],
+ hidden_layers: list = [500],
  hidden_layer_activation: Literal['relu','softplus','leakyrelu','linear'] = 'relu',
  nn_dropout: float = 0.1,
  post_layer_fct: list = ['layernorm'],
  post_act_fct: list = None,
  config_enum: str = 'parallel',
- use_cuda: bool = False,
+ use_cuda: bool = True,
  seed: int = 42,
  dtype = torch.float32, # type: ignore
  ):
@@ -817,7 +817,7 @@ class SURE(nn.Module):
  us = None,
  ys = None,
  zs = None,
- num_epochs: int = 200,
+ num_epochs: int = 500,
  learning_rate: float = 0.0001,
  batch_size: int = 256,
  algo: Literal['adam','rmsprop','adamw'] = 'adam',
@@ -826,7 +826,7 @@ class SURE(nn.Module):
  decay_rate: float = 0.9,
  config_enum: str = 'parallel',
  threshold: int = 0,
- use_jax: bool = False):
+ use_jax: bool = True):
  """
  Train the SURE model.

SURE/__init__.py CHANGED
@@ -1,12 +1,12 @@
  from .SURE import SURE
- from .PerturbFlow import PerturbFlow
+ from .DensityFlow import DensityFlow

  from . import utils
  from . import codebook
  from . import SURE
- from . import PerturbFlow
+ from . import DensityFlow
  from . import atac
  from . import flow
  from . import perturb

- __all__ = ['SURE', 'PerturbFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
+ __all__ = ['SURE', 'DensityFlow', 'flow', 'perturb', 'atac', 'utils', 'codebook']
SURE/flow/flow_stats.py CHANGED
@@ -41,6 +41,18 @@ class VectorFieldEval:
  divergence[np.isnan(divergence)] = 0

  return divergence
+
+ def movement_stats(self,vectors):
+ return calculate_movement_stats(vectors)
+
+ def direction_stats(self, vectors):
+ return calculate_direction_stats(vectors)
+
+ def movement_energy(self, vectors, masses=None):
+ return calculate_movement_energy(vectors, masses)
+
+ def movement_divergence(self, positions, vectors):
+ return calculate_movement_divergence(positions, vectors)


  def calculate_movement_stats(vectors):
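The four new methods are thin wrappers exposing the module-level `calculate_*` helpers through `VectorFieldEval`. A usage sketch, assuming the class constructs without arguments (its __init__ is not shown in this diff) and that positions/vectors are (n, d) arrays:

    import numpy as np
    from SURE.flow.flow_stats import VectorFieldEval

    ev = VectorFieldEval()                  # constructor signature assumed
    positions = np.random.randn(200, 2)
    vectors = np.random.randn(200, 2)

    mv = ev.movement_stats(vectors)         # delegates to calculate_movement_stats
    dr = ev.direction_stats(vectors)        # delegates to calculate_direction_stats
    en = ev.movement_energy(vectors)        # masses defaults to None
    dv = ev.movement_divergence(positions, vectors)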
SURE/perturb/perturb.py CHANGED
@@ -1,5 +1,6 @@
  import re
  import numpy as np
+ import pandas as pd
  from numba import njit
  from itertools import chain
  from joblib import Parallel, delayed
@@ -8,6 +9,8 @@ from typing import Literal
  class LabelMatrix:
  def __init__(self):
  self.labels_ = None
+ self.control_label = None
+ self.sep_pattern = None

  def fit_transform(self, labels, control_label=None, sep_pattern=r'[,;_\s]', speedup: Literal['none','vectorize','parallel']='none'):
  if speedup=='none':
@@ -24,8 +27,31 @@ class LabelMatrix:
  mat = np.delete(mat, idx, axis=1)
  self.labels_ = np.delete(self.labels_, idx)

+ self.control_label = control_label
+ self.sep_pattern=sep_pattern
+
  return mat
-
+
+ def transform(self, labels, speedup: Literal['none','vectorize','parallel']='none'):
+ sep_pattern = self.sep_pattern
+ if speedup=='none':
+ mat, labels_ = label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+ elif speedup=='vectorize':
+ mat, labels_ = vectorized_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+ elif speedup=='parallel':
+ mat, labels_ = parallel_label_to_matrix(labels=labels, sep_pattern=sep_pattern)
+
+ mat_df = pd.DataFrame(mat, columns=labels_)
+
+ labels_valid = [x for x in labels_ if x in self.labels_]
+ mat_df = mat_df[labels_valid]
+
+ mat_valid = np.zeros([mat.shape[0], len(self.labels_)])
+ mat_valid_df = pd.DataFrame(mat_valid, columns=self.labels_)
+ mat_valid_df[labels_valid] = mat_df
+
+ return mat_valid_df.values
+
  def inverse_transform(self, matrix):
  return matrix_to_labels(matrix=matrix, unique_labels=self.labels_)

SURE/utils/custom_mlp.py CHANGED
@@ -239,7 +239,10 @@ class ZeroBiasMLP(nn.Module):
  def forward(self, x):
  y = self.mlp(x)
  mask = torch.zeros_like(y)
- mask[x[1][:,0]>0,:] = 1
+ if len(y.shape)==2:
+ mask[x[1][:,0]>0,:] = 1
+ elif len(y.shape)==3:
+ mask[:,x[1][:,0]>0,:] = 1
  return y*mask


sure_tools-2.1.83.dist-info/METADATA → sure_tools-2.2.23.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: SURE-tools
- Version: 2.1.83
+ Version: 2.2.23
  Summary: Succinct Representation of Single Cells
  Home-page: https://github.com/ZengFLab/SURE
  Author: Feng Zeng
sure_tools-2.1.83.dist-info/RECORD → sure_tools-2.2.23.dist-info/RECORD CHANGED
@@ -1,6 +1,6 @@
- SURE/PerturbFlow.py,sha256=5HzS8oB06iSR3JM5AalGfYi-quxbjkTZeTypjih-VBI,54759
- SURE/SURE.py,sha256=g8EhovBxjfpbVJA0AkmVkQ_ZW_JFc8TtkTCg8FCybV4,47750
- SURE/__init__.py,sha256=NOJI_K-eCqPgStXXvgl3wIEMp6d8saMTDYLJ7Ga9MqE,293
+ SURE/DensityFlow.py,sha256=p5Pt3KrsdF_NTLFx0p1cUPuXkIac6wQED1LsLJRG7mI,56124
+ SURE/SURE.py,sha256=MXs7iuvcj-lU4dJ_MwKegpL2Rqk2HB4eFfAgHRA3RtA,47744
+ SURE/__init__.py,sha256=NVp22RCHrhSwHNMomABC-eftoCYvt7vV1XOzim-UZHE,293
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
  SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
  SURE/assembly/atlas.py,sha256=ALjmVWutm_tOHTcT1aqOxmuCEQw-XzrtDoMCV_8oXLk,21794
@@ -9,17 +9,17 @@ SURE/atac/utils.py,sha256=m4NYwpy9O5T1pXTzgCOCcmlwrC6GTi-cQ5sm2wZu2O8,4354
  SURE/codebook/__init__.py,sha256=2T5gjp8JIaBayrXAnOJYSebQHsWprOs87difpR1OPNw,243
  SURE/codebook/codebook.py,sha256=ZlN6gRX9Gj2D2u3P5KeOsbZri0MoMAiJo9lNeL-MK-I,17117
  SURE/flow/__init__.py,sha256=rsAjYsh1xVIrxBCuwOE0Q_6N5th1wBgjJceV0ABPG3c,183
- SURE/flow/flow_stats.py,sha256=_pF7m4-87SKlCHVtVmx3LG2bAGVXOnAfEgMzLhLx4Io,10910
+ SURE/flow/flow_stats.py,sha256=6SzNMT59WRFRP1nC6bvpBPF7BugWnkIS_DSlr4S-Ez0,11338
  SURE/flow/plot_quiver.py,sha256=UbmuScUcgbQHeMmjKmgqxjrIjHhiHx0VWct16UMMwuE,8110
  SURE/perturb/__init__.py,sha256=8TP1dSUhXiZzKpFebHZmm8XMMGbUz_OfQ10xu-6uPPY,43
- SURE/perturb/perturb.py,sha256=1iSsCePcwkA2CyM1nCdq_G8gogUNjhMH0BfhhvhpJQk,5037
+ SURE/perturb/perturb.py,sha256=ey7cxsM1tO1MW4UaE_MLpLHK87CjvXzn2CBPtvv1VZ0,6116
  SURE/utils/__init__.py,sha256=YF5jB-PAHJQ40OlcZ7BCZbsN2q1JKuPT6EppilRXQqM,680
- SURE/utils/custom_mlp.py,sha256=C0EXLGYsWkUQpEL49AyBFPSzKmasb2hdvtnJfxbF-YU,9282
+ SURE/utils/custom_mlp.py,sha256=HuNb7f8-6RFjsvfEu1XOuNpLrHZkGYHgf8TpJfPSNO0,9382
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
- sure_tools-2.1.83.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
- sure_tools-2.1.83.dist-info/METADATA,sha256=H-q3GA7c-UxJp8C3OfR-f7YpSkhqaSQD3oZ_qcg9OJo,2678
- sure_tools-2.1.83.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- sure_tools-2.1.83.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
- sure_tools-2.1.83.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
- sure_tools-2.1.83.dist-info/RECORD,,
+ sure_tools-2.2.23.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
+ sure_tools-2.2.23.dist-info/METADATA,sha256=ckAOsGL19y8unUmL2zYK4yeTRGFyALbaN_3hM18u0tw,2678
+ sure_tools-2.2.23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ sure_tools-2.2.23.dist-info/entry_points.txt,sha256=-nJI8rVe_qqrR0HmfAODzj-JNfEqCcSsyVh6okSqyHk,83
+ sure_tools-2.2.23.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
+ sure_tools-2.2.23.dist-info/RECORD,,