SURE-tools 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of SURE-tools might be problematic.
- SURE/SURE.py +67 -30
- {SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/METADATA +12 -2
- {SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/RECORD +7 -7
- {SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/WHEEL +1 -1
- {SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/entry_points.txt +0 -0
- {SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info/licenses}/LICENSE +0 -0
- {SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/top_level.txt +0 -0
SURE/SURE.py
CHANGED
@@ -58,11 +58,11 @@ class SURE(nn.Module):
     ----------
     inpute_size
         Number of features (e.g., genes, peaks, proteins, etc.) per cell.
-    undesired_size
-        Number of undesired factors. It would be used to adjust for undesired variations like batch effect.
     codebook_size
         Number of metacells.
-
+    cell_factor_size
+        Number of cell-level factors.
+    z_dim
         Dimensionality of latent states and metacells.
     hidden_layers
         A list give the numbers of neurons for each hidden layer.
@@ -73,10 +73,7 @@ class SURE(nn.Module):
         * ``'negbinomial'`` - negative binomial distribution (default)
         * ``'poisson'`` - poisson distribution
         * ``'multinomial'`` - multinomial distribution
-
-        A boolean option. If toggled on, SURE characterizes single-cell data using a hierarchical model, such as
-        dirichlet-negative binomial.
-    latent_dist
+    z_dist
         The distribution model for latent states.
 
         One of the following:
@@ -100,7 +97,7 @@ class SURE(nn.Module):
                  supervised_mode: bool = False,
                  z_dim: int = 10,
                  z_dist: Literal['normal','studentt','laplacian','cauchy','gumbel'] = 'normal',
-                 loss_func: Literal['negbinomial','poisson','multinomial'] = 'negbinomial',
+                 loss_func: Literal['negbinomial','poisson','multinomial','bernoulli'] = 'negbinomial',
                  inverse_dispersion: float = 10.0,
                  use_zeroinflate: bool = True,
                  hidden_layers: list = [500],
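The constructor change above widens the `loss_func` annotation to a fourth `typing.Literal` value, `'bernoulli'`. A minimal standalone sketch (only the value names come from the diff, nothing else is SURE code) of how such a Literal documents and lets you enumerate the accepted strings:

    # Sketch only: mirrors the annotation added in 1.0.4; not imported from SURE.
    from typing import Literal, get_args

    LossFunc = Literal['negbinomial', 'poisson', 'multinomial', 'bernoulli']

    def check_loss_func(name: str) -> str:
        # Static checkers validate Literal arguments; at runtime we can still
        # enumerate the allowed values and fail early on anything else.
        if name not in get_args(LossFunc):
            raise ValueError(f"loss_func must be one of {get_args(LossFunc)}")
        return name

    print(check_loss_func('bernoulli'))   # 'bernoulli'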
@@ -368,9 +365,12 @@ class SURE(nn.Module):
 
             zs = zns
             concentrate = self.decoder_concentrate(zs)
-
-
-
+            if self.loss_func == 'bernoulli':
+                log_theta = concentrate
+            else:
+                rate = concentrate.exp()
+                if self.loss_func == 'negbinomial':
+                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
 
             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
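In the added branch above, when `loss_func` is `'negbinomial'`, `theta` is the mean of a `DirichletMultinomial` with `total_count=1`, which is simply the decoder rate renormalized into per-feature probabilities. A minimal Pyro sketch with illustrative numbers, not values taken from SURE:

    # Illustrative values only; shows what the added `theta = ...mean` line computes.
    import torch
    import pyro.distributions as dist

    rate = torch.tensor([[2.0, 1.0, 1.0]])   # stand-in for concentrate.exp()
    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
    print(theta)                                                      # tensor([[0.5000, 0.2500, 0.2500]])
    print(torch.allclose(theta, rate / rate.sum(-1, keepdim=True)))   # True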
@@ -384,6 +384,11 @@ class SURE(nn.Module):
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
                 pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            elif self.loss_func == 'bernoulli':
+                if self.use_zeroinflate:
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                else:
+                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
 
     def guide1(self, xs):
         with pyro.plate('data'):
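The new `'bernoulli'` branch (repeated in the hunks that follow for the other model/guide pairs) scores binarized profiles with an optionally zero-inflated Bernoulli likelihood. A self-contained sketch of that likelihood call, with made-up tensors standing in for the decoder output and gate logits:

    # Stand-in tensors; only the distribution calls mirror the diff.
    import torch
    import pyro.distributions as dist

    log_theta = torch.zeros(4, 3)            # plays the role of the decoder's `concentrate`
    gate_logits = torch.full((4, 3), -2.0)   # zero-inflation gate logits
    xs = torch.tensor([[1., 0., 1.],
                       [0., 0., 1.],
                       [1., 1., 0.],
                       [0., 1., 1.]])        # binarized expression, 4 cells x 3 features

    likelihood = dist.ZeroInflatedDistribution(
        dist.Bernoulli(logits=log_theta), gate_logits=gate_logits
    ).to_event(1)
    print(likelihood.log_prob(xs))           # one log-likelihood per cell, shape (4,)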
@@ -442,9 +447,12 @@ class SURE(nn.Module):
             zs = zns
 
             concentrate = self.decoder_concentrate(zs)
-
-
-
+            if self.loss_func == 'bernoulli':
+                log_theta = concentrate
+            else:
+                rate = concentrate.exp()
+                if self.loss_func == 'negbinomial':
+                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
 
             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
@@ -458,6 +466,11 @@ class SURE(nn.Module):
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
                 pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            elif self.loss_func == 'bernoulli':
+                if self.use_zeroinflate:
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                else:
+                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
 
     def guide2(self, xs, us=None):
         with pyro.plate('data'):
@@ -528,9 +541,12 @@ class SURE(nn.Module):
             zs = zns
 
             concentrate = self.decoder_concentrate(zs)
-
-
-
+            if self.loss_func == 'bernoulli':
+                log_theta = concentrate
+            else:
+                rate = concentrate.exp()
+                if self.loss_func == 'negbinomial':
+                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
 
             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
@@ -544,6 +560,11 @@ class SURE(nn.Module):
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
                 pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            elif self.loss_func == 'bernoulli':
+                if self.use_zeroinflate:
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                else:
+                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
 
     def guide3(self, xs, ys, embeds=None):
         with pyro.plate('data'):
@@ -616,9 +637,12 @@ class SURE(nn.Module):
             zs = zns
 
             concentrate = self.decoder_concentrate(zs)
-
-
-
+            if self.loss_func == 'bernoulli':
+                log_theta = concentrate
+            else:
+                rate = concentrate.exp()
+                if self.loss_func == 'negbinomial':
+                    theta = dist.DirichletMultinomial(total_count=1, concentration=rate).mean
 
             if self.loss_func == 'negbinomial':
                 if self.use_zeroinflate:
@@ -632,6 +656,11 @@ class SURE(nn.Module):
                 pyro.sample('x', dist.Poisson(rate=rate).to_event(1), obs=xs.round())
             elif self.loss_func == 'multinomial':
                 pyro.sample('x', dist.Multinomial(total_count=int(1e8), probs=theta), obs=xs)
+            elif self.loss_func == 'bernoulli':
+                if self.use_zeroinflate:
+                    pyro.sample('x', dist.ZeroInflatedDistribution(dist.Bernoulli(logits=log_theta),gate_logits=gate_logits).to_event(1), obs=xs)
+                else:
+                    pyro.sample('x', dist.Bernoulli(logits=log_theta).to_event(1), obs=xs)
 
     def guide4(self, xs, us, ys, embeds=None):
         with pyro.plate('data'):
@@ -764,7 +793,14 @@ class SURE(nn.Module):
         A = np.concatenate(A)
         return A
 
-    def preprocess(self, xs):
+    def preprocess(self, xs, threshold=0):
+        if self.loss_func == 'bernoulli':
+            ad = sc.AnnData(xs)
+            binarize(ad, threshold=threshold)
+            xs = ad.X.copy()
+        else:
+            xs = np.round(xs)
+
         if sparse.issparse(xs):
             xs = xs.toarray()
         return xs
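With the `'bernoulli'` loss, `preprocess` now binarizes the matrix (through an AnnData round-trip and a `binarize` helper whose import is not shown in this diff) instead of rounding it. A plain-NumPy sketch of what thresholding at `threshold` amounts to, assuming values strictly above the threshold map to 1 (the strictness of the comparison is an assumption, not visible in the diff):

    # Hedged sketch: NumPy-only thresholding; the AnnData/binarize plumbing is omitted.
    import numpy as np

    xs = np.array([[0., 3., 1.],
                   [2., 0., 0.]])
    threshold = 0
    xs_bin = (xs > threshold).astype(xs.dtype)
    print(xs_bin)
    # [[0. 1. 1.]
    #  [1. 0. 0.]]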
@@ -781,6 +817,7 @@ class SURE(nn.Module):
              weight_decay: float = 0.005,
              decay_rate: float = 0.9,
              config_enum: str = 'parallel',
+             threshold: int = 0,
              use_jax: bool = False):
         """
         Train the SURE model.
@@ -790,7 +827,7 @@ class SURE(nn.Module):
         xs
             Single-cell experssion matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are features.
         us
-
+            cell-level factor matrix.
         ys
             Desired factor matrix. It should be a Numpy array or a Pytorch Tensor. Rows are cells and columns are desired factors.
         num_epochs
@@ -811,7 +848,7 @@ class SURE(nn.Module):
             If toggled on, Jax will be used for speeding up. CAUTION: This will raise errors because of unknown reasons when it is called in
             the Python script or Jupyter notebook. It is OK if it is used when runing SURE in the shell command.
         """
-        xs = self.preprocess(xs)
+        xs = self.preprocess(xs, threshold=threshold)
         xs = convert_to_tensor(xs, dtype=self.dtype, device=self.get_device())
         if us is not None:
            us = convert_to_tensor(us, dtype=self.dtype, device=self.get_device())
@@ -964,11 +1001,11 @@ def parse_args():
        help="the data file",
    )
    parser.add_argument(
-        "-
-        "--
+        "-cf",
+        "--cell-factor-file",
        default=None,
        type=str,
-        help="the file for the record of
+        help="the file for the record of cell-level factors",
    )
    parser.add_argument(
        "-delta",
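The CLI gains a `-cf/--cell-factor-file` option in place of the removed flag (whose old name is truncated in this extract). A minimal argparse sketch showing how the new long option surfaces as `args.cell_factor_file` in `main()`; the parser and file name below are stand-ins, not SURE's actual parser:

    # Stand-in parser; only the option strings and help text mirror the diff.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-cf", "--cell-factor-file", default=None, type=str,
                        help="the file for the record of cell-level factors")

    args = parser.parse_args(["--cell-factor-file", "factors.csv"])  # hypothetical file
    print(args.cell_factor_file)   # argparse maps --cell-factor-file to args.cell_factor_file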
@@ -1148,18 +1185,18 @@ def main():
 
     xs = dt.fread(file=args.data_file, header=True).to_numpy()
     us = None
-    if args.
-        us = dt.fread(file=args.
+    if args.cell_factor_file is not None:
+        us = dt.fread(file=args.cell_factor_file, header=True).to_numpy()
 
     input_size = xs.shape[1]
-
+    cell_factor_size = 0 if us is None else us.shape[1]
 
     latent_dist = args.z_dist
 
     ###########################################
     sure = SURE(
         input_size=input_size,
-
+        cell_factor_size=cell_factor_size,
         inverse_dispersion=args.inverse_dispersion,
         latent_dim=args.latent_dim,
         hidden_layers=args.hidden_layers,
{SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: SURE-tools
-Version: 1.0.2
+Version: 1.0.4
 Summary: Succinct Representation of Single Cells
 Home-page: https://github.com/ZengFLab/SURE
 Author: Feng Zeng
@@ -26,6 +26,16 @@ Requires-Dist: networkx
 Requires-Dist: matplotlib
 Requires-Dist: seaborn
 Requires-Dist: fa2-modified
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license-file
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 # SURE: SUccinct REpresentation of cells
 SURE introduces a vector quantization-based probabilistic generative model for calling metacells and use them as landmarks that form a coordinate system for cell ID. Analyzing single-cell omics data in a manner analogous to reference genome-based genomic analysis.
{SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-SURE/SURE.py,sha256=
+SURE/SURE.py,sha256=YhsWt3ndKpiIngKTjKOt58_jNyaBDF0wsoeIVNg2Di0,48758
 SURE/__init__.py,sha256=SbIRwAVBnNhza9vbsUH4N04atr0q_Abp04pCUTBhNio,127
 SURE/assembly/__init__.py,sha256=jxZLURXKPzXe21LhrZ09LgZr33iqdjlQy4oSEj5gR2Q,172
 SURE/assembly/assembly.py,sha256=6IMdelPOiRO4mUb4dC7gVCoF1Uvfw86-Map8P_jnUag,21477
@@ -9,9 +9,9 @@ SURE/utils/__init__.py,sha256=Htqv4KqVKcRiaaTBsR-6yZ4LSlbhbzutjNKXGD9-uds,660
 SURE/utils/custom_mlp.py,sha256=07TYX1HgxfEjb_3i5MpiZfNhOhx3dKntuwGkrpteWiM,7036
 SURE/utils/queue.py,sha256=E_5PA5EWcBoGAZj8BkKQnkCK0p4C-4-xcTPqdIXaPXU,1892
 SURE/utils/utils.py,sha256=IUHjDDtYaAYllCWsZyIzqQwaLul6fJRvHRH4vIYcR-c,8462
-
-
-
-
-
-
+sure_tools-1.0.4.dist-info/licenses/LICENSE,sha256=TFHKwmrAViXQbSX5W-NDItkWFjm45HWOeUniDrqmnu0,1065
+sure_tools-1.0.4.dist-info/METADATA,sha256=2dPXR-pUr_8fNXewgDEo4oueSuRE0If05GQU94wypEo,2650
+sure_tools-1.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sure_tools-1.0.4.dist-info/entry_points.txt,sha256=u12payZYgCBy5FCwRHP6AlSQhKCiWSEDwj68r1DVdn8,40
+sure_tools-1.0.4.dist-info/top_level.txt,sha256=BtFTebdiJeqra4r6mm-uEtwVRFLZ_IjYsQ7OnalrOvY,5
+sure_tools-1.0.4.dist-info/RECORD,,
{SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/entry_points.txt
File without changes
{SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info/licenses}/LICENSE
File without changes
{SURE_tools-1.0.2.dist-info → sure_tools-1.0.4.dist-info}/top_level.txt
File without changes