SURE-tools 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of SURE-tools might be problematic. See the package's registry page for more details.

SURE/SURE.py CHANGED
@@ -58,11 +58,11 @@ class SURE(nn.Module):
58
58
  ----------
59
59
  inpute_size
60
60
  Number of features (e.g., genes, peaks, proteins, etc.) per cell.
61
- undesired_size
62
- Number of undesired factors. It would be used to adjust for undesired variations like batch effect.
63
61
  codebook_size
64
62
  Number of metacells.
65
- latent_dim
63
+ cell_factor_size
64
+ Number of cell-level factors.
65
+ z_dim
66
66
  Dimensionality of latent states and metacells.
67
67
  hidden_layers
68
68
  A list give the numbers of neurons for each hidden layer.
@@ -73,10 +73,7 @@ class SURE(nn.Module):
73
73
  * ``'negbinomial'`` - negative binomial distribution (default)
74
74
  * ``'poisson'`` - poisson distribution
75
75
  * ``'multinomial'`` - multinomial distribution
76
- user_dirichlet
77
- A boolean option. If toggled on, SURE characterizes single-cell data using a hierarchical model, such as
78
- dirichlet-negative binomial.
79
- latent_dist
76
+ z_dist
80
77
  The distribution model for latent states.
81
78
 
82
79
  One of the following:
@@ -100,7 +97,7 @@ class SURE(nn.Module):
100
97
  supervised_mode: bool = False,
101
98
  z_dim: int = 10,
102
99
  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
103
- loss_func: Literal['negbinomial','poisson','multinomial'] = 'negbinomial',
100
+ loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
104
101
  inverse_dispersion: float = 10.0,
105
102
  use_zeroinflate: bool = True,
106
103
  hidden_layers: list = [500],
@@ -368,9 +365,12 @@ class SURE(nn.Module):
368
365
 
369
366
  zs = zns
370
367
  concentrate = self.decoder_concentrate(zs)
371
- rate = concentrate.exp()
372
- if self.loss_func != 'poisson':
373
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
368
+ if self.loss_func == 'bernoulli':
369
+ log_theta = concentrate
370
+ else:
371
+ rate = concentrate.exp()
372
+ if self.loss_func == 'negbinomial':
373
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
374
374
 
375
375
  if self.loss_func == 'negbinomial':
376
376
  if self.use_zeroinflate:
@@ -384,6 +384,11 @@ class SURE(nn.Module):
384
384
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
385
385
  elif self.loss_func == 'multinomial':
386
386
  pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
387
+ elif self.loss_func == 'bernoulli':
388
+ if self.use_zeroinflate:
389
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
390
+ else:
391
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
387
392
 
388
393
  def guide1(self, xs):
389
394
  with pyro.plate('data'):
@@ -442,9 +447,12 @@ class SURE(nn.Module):
442
447
  zs = zns
443
448
 
444
449
  concentrate = self.decoder_concentrate(zs)
445
- rate = concentrate.exp()
446
- if self.loss_func != 'poisson':
447
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
450
+ if self.loss_func == 'bernoulli':
451
+ log_theta = concentrate
452
+ else:
453
+ rate = concentrate.exp()
454
+ if self.loss_func == 'negbinomial':
455
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
448
456
 
449
457
  if self.loss_func == 'negbinomial':
450
458
  if self.use_zeroinflate:
@@ -458,6 +466,11 @@ class SURE(nn.Module):
458
466
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
459
467
  elif self.loss_func == 'multinomial':
460
468
  pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
469
+ elif self.loss_func == 'bernoulli':
470
+ if self.use_zeroinflate:
471
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
472
+ else:
473
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
461
474
 
462
475
  def guide2(self, xs, us=None):
463
476
  with pyro.plate('data'):
@@ -528,9 +541,12 @@ class SURE(nn.Module):
528
541
  zs = zns
529
542
 
530
543
  concentrate = self.decoder_concentrate(zs)
531
- rate = concentrate.exp()
532
- if self.loss_func != 'poisson':
533
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
544
+ if self.loss_func == 'bernoulli':
545
+ log_theta = concentrate
546
+ else:
547
+ rate = concentrate.exp()
548
+ if self.loss_func == 'negbinomial':
549
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
534
550
 
535
551
  if self.loss_func == 'negbinomial':
536
552
  if self.use_zeroinflate:
@@ -544,6 +560,11 @@ class SURE(nn.Module):
544
560
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
545
561
  elif self.loss_func == 'multinomial':
546
562
  pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
563
+ elif self.loss_func == 'bernoulli':
564
+ if self.use_zeroinflate:
565
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
566
+ else:
567
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
547
568
 
548
569
  def guide3(self, xs, ys, embeds=None):
549
570
  with pyro.plate('data'):
@@ -616,9 +637,12 @@ class SURE(nn.Module):
616
637
  zs = zns
617
638
 
618
639
  concentrate = self.decoder_concentrate(zs)
619
- rate = concentrate.exp()
620
- if self.loss_func != 'poisson':
621
- theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
640
+ if self.loss_func == 'bernoulli':
641
+ log_theta = concentrate
642
+ else:
643
+ rate = concentrate.exp()
644
+ if self.loss_func == 'negbinomial':
645
+ theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
622
646
 
623
647
  if self.loss_func == 'negbinomial':
624
648
  if self.use_zeroinflate:
@@ -632,6 +656,11 @@ class SURE(nn.Module):
632
656
  pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
633
657
  elif self.loss_func == 'multinomial':
634
658
  pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
659
+ elif self.loss_func == 'bernoulli':
660
+ if self.use_zeroinflate:
661
+ pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
662
+ else:
663
+ pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
635
664
 
636
665
  def guide4(self, xs, us, ys, embeds=None):
637
666
  with pyro.plate('data'):
@@ -764,7 +793,14 @@ class SURE(nn.Module):
764
793
  A = np.concatenate(A)
765
794
  return A
766
795
 
767
- def preprocess(self, xs):
796
+ def preprocess(self, xs, threshold=0):
797
+ if self.loss_func == 'bernoulli':
798
+ ad = sc.AnnData(xs)
799
+ binarize(ad, threshold=threshold)
800
+ xs = ad.X.copy()
801
+ else:
802
+ xs = np.round(xs)
803
+
768
804
  if sparse.issparse(xs):
769
805
  xs = xs.toarray()
770
806
  return xs
@@ -781,6 +817,7 @@ class SURE(nn.Module):
781
817
  weight_decay: float = 0.005,
782
818
  decay_rate: float = 0.9,
783
819
  config_enum: str = 'parallel',
820
+ threshold: int = 0,
784
821
  use_jax: bool = False):
785
822
  """
786
823
  Train the SURE model.
@@ -790,7 +827,7 @@ class SURE(nn.Module):
790
827
  xs
791
828
  Single-cell experssion matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are features.
792
829
  us
793
- Undesired factor matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are undesired factors.
830
+ cell-level factor matrix.
794
831
  ys
795
832
  Desired factor matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are desired factors.
796
833
  num_epochs
@@ -811,7 +848,7 @@ class SURE(nn.Module):
811
848
  If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
812
849
  the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
813
850
  """
814
- xs = self.preprocess(xs)
851
+ xs = self.preprocess(xs, threshold=threshold)
815
852
  xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
816
853
  if us is not None:
817
854
  us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
@@ -964,11 +1001,11 @@ def parse_args():
964
1001
  help="the data file",
965
1002
  )
966
1003
  parser.add_argument(
967
- "-undesired",
968
- "--undesired-factor-file",
1004
+ "-cf",
1005
+ "--cell-factor-file",
969
1006
  default=None,
970
1007
  type=str,
971
- help="the file for the record of undesired factors",
1008
+ help="the file for the record of cell-level factors",
972
1009
  )
973
1010
  parser.add_argument(
974
1011
  "-delta",
@@ -1148,18 +1185,18 @@ def main():
1148
1185
 
1149
1186
  xs = dt.fread(file=args.data_file, header=True).to_numpy()
1150
1187
  us = None
1151
- if args.undesired_factor_file is not None:
1152
- us = dt.fread(file=args.undesired_factor_file, header=True).to_numpy()
1188
+ if args.cell_factor_file is not None:
1189
+ us = dt.fread(file=args.cell_factor_file, header=True).to_numpy()
1153
1190
 
1154
1191
  input_size = xs.shape[1]
1155
- undesired_size = 0 if us is None else us.shape[1]
1192
+ cell_factor_size = 0 if us is None else us.shape[1]
1156
1193
 
1157
1194
  latent_dist = args.z_dist
1158
1195
 
1159
1196
  ###########################################
1160
1197
  sure = SURE(
1161
1198
  input_size=input_size,
1162
- undesired_size=undesired_size,
1199
+ cell_factor_size=cell_factor_size,
1163
1200
  inverse_dispersion=args.inverse_dispersion,
1164
1201
  latent_dim=args.latent_dim,
1165
1202
  hidden_layers=args.hidden_layers,
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: SURE-tools
3
- Version: 1.0.2
3
+ Version: 1.0.4
4
4
  Summary: Succinct Representation of Single Cells
5
5
  Home-page: https://github.com/ZengFLab/SURE
6
6
  Author: Feng Zeng
@@ -26,6 +26,16 @@ Requires-Dist: networkx
26
26
  Requires-Dist: matplotlib
27
27
  Requires-Dist: seaborn
28
28
  Requires-Dist: fa2-modified
29
+ Dynamic: author
30
+ Dynamic: author-email
31
+ Dynamic: classifier
32
+ Dynamic: description
33
+ Dynamic: description-content-type
34
+ Dynamic: home-page
35
+ Dynamic: license-file
36
+ Dynamic: requires-dist
37
+ Dynamic: requires-python
38
+ Dynamic: summary
29
39
 
30
40
  # SURE: SUccinct REpresentation of cells
31
41
  SURE introduces a vector quantization-based probabilistic generative model for calling metacells and use them as landmarks that form a coordinate system for cell ID. Analyzing single-cell omics data in a manner analogous to reference genome-based genomic analysis.
@@ -1,4 +1,4 @@
1
- SURE/SURE.py,sha256=AMU2EZKJVIE_n4W_K4crVvcpvFOtiiplczoGOCPvvPY,46959
1
+ SURE/SURE.py,sha256=YhsWt3ndKpiIngKTjKOt58_jNyaBDF0wsoeIVNg2Di0,48758
2
2
  SURE/__init__.py,sha256=SbIRwAVBnNhza9vbsUH4N04atr0q_Abp04pCUTBhNio,127
3
3
  SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
4
4
  SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
@@ -9,9 +9,9 @@ SURE/utils/__init__.py,sha256=Htqv4KqVKcRiaaTBsR-6yZ4LSlbhbzutjNKXGD9-uds,660
9
9
  SURE/utils/custom_mlp.py,sha256=07TYX1HgxfEjb_3i5MpiZfNhOhx3dKntuwGkrpteWiM,7036
10
10
  SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
11
11
  SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
12
- SURE_tools-1.0.2.dist-info/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
13
- SURE_tools-1.0.2.dist-info/METADATA,sha256=Fy0Bc3luEPlFIfMRunpTkb18J_tPoqKK4jsfPFnRAlo,2431
14
- SURE_tools-1.0.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
15
- SURE_tools-1.0.2.dist-info/entry_points.txt,sha256=u12payZYgCBy5FCwRHP6AlSQhKCiWSEDwj68r1DVdn8,40
16
- SURE_tools-1.0.2.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
17
- SURE_tools-1.0.2.dist-info/RECORD,,
12
+ sure_tools-1.0.4.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
13
+ sure_tools-1.0.4.dist-info/METADATA,sha256=2dPXR-pUr_8fNXewgDEo4oueSuRE0If05GQU94wypEo,2650
14
+ sure_tools-1.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
15
+ sure_tools-1.0.4.dist-info/entry_points.txt,sha256=u12payZYgCBy5FCwRHP6AlSQhKCiWSEDwj68r1DVdn8,40
16
+ sure_tools-1.0.4.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
17
+ sure_tools-1.0.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5