SURE-tools 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/SURE.py +1203 -0
- SURE/__init__.py +7 -0
- SURE/assembly/__init__.py +3 -0
- SURE/assembly/assembly.py +511 -0
- SURE/assembly/atlas.py +575 -0
- SURE/codebook/__init__.py +4 -0
- SURE/codebook/codebook.py +472 -0
- SURE/utils/__init__.py +19 -0
- SURE/utils/custom_mlp.py +209 -0
- SURE/utils/queue.py +50 -0
- SURE/utils/utils.py +308 -0
- SURE_tools-1.0.1.dist-info/LICENSE +21 -0
- SURE_tools-1.0.1.dist-info/METADATA +68 -0
- SURE_tools-1.0.1.dist-info/RECORD +17 -0
- SURE_tools-1.0.1.dist-info/WHEEL +5 -0
- SURE_tools-1.0.1.dist-info/entry_points.txt +2 -0
- SURE_tools-1.0.1.dist-info/top_level.txt +1 -0
SURE/assembly/atlas.py
ADDED
|
@@ -0,0 +1,575 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import tempfile
|
|
3
|
+
import subprocess
|
|
4
|
+
import scanpy as sc
|
|
5
|
+
import numpy as np
|
|
6
|
+
import scipy as sp
|
|
7
|
+
from scipy import sparse
|
|
8
|
+
from scipy.stats import gaussian_kde
|
|
9
|
+
|
|
10
|
+
# `cumulative_trapezoid` exists since SciPy 1.6 and `cumtrapz` was removed in
# SciPy 1.14. Feature-detect instead of comparing version strings
# lexicographically (which misorders e.g. '1.2.0' vs '1.14.0').
try:
    from scipy.integrate import cumulative_trapezoid as cumtrapz
except ImportError:  # SciPy < 1.6
    from scipy.integrate import cumtrapz
|
|
14
|
+
|
|
15
|
+
import pandas as pd
|
|
16
|
+
from functools import reduce
|
|
17
|
+
|
|
18
|
+
import datatable as dt
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
import umap
|
|
21
|
+
import faiss
|
|
22
|
+
from sklearn.neighbors import NearestNeighbors
|
|
23
|
+
#from cuml.linear_model import LogisticRegression
|
|
24
|
+
from sklearn.preprocessing import LabelEncoder,LabelBinarizer
|
|
25
|
+
|
|
26
|
+
import torch
|
|
27
|
+
import torch.nn as nn
|
|
28
|
+
from torch.utils.data import DataLoader
|
|
29
|
+
import pyro.distributions as dist
|
|
30
|
+
|
|
31
|
+
from ..SURE import SURE
|
|
32
|
+
from .assembly import assembly,get_data,get_subdata,batch_encoding,get_uns
|
|
33
|
+
from ..codebook import codebook_summarize_,codebook_generate,codebook_sketch
|
|
34
|
+
from ..utils import convert_to_tensor, tensor_to_numpy
|
|
35
|
+
from ..utils import CustomDataset
|
|
36
|
+
from ..utils import pretty_print, Colors
|
|
37
|
+
from ..utils import PriorityQueue
|
|
38
|
+
|
|
39
|
+
import networkx as nx
|
|
40
|
+
import matplotlib.pyplot as plt
|
|
41
|
+
import matplotlib.patches as mpatches
|
|
42
|
+
from matplotlib.colors import ListedColormap
|
|
43
|
+
|
|
44
|
+
import dill as pickle
|
|
45
|
+
import gzip
|
|
46
|
+
from packaging.version import Version
|
|
47
|
+
torch_version = torch.__version__
|
|
48
|
+
|
|
49
|
+
from typing import Literal
|
|
50
|
+
|
|
51
|
+
import warnings
|
|
52
|
+
warnings.filterwarnings("ignore")
|
|
53
|
+
|
|
54
|
+
class SingleOmicsAtlas(nn.Module):
    """
    Compressed, distribution-preserved single-cell omics atlas.

    After :meth:`fit`, the atlas holds:

    - ``self.adata``: a metacell x gene AnnData of assignment-weighted
      expression.
    - ``self.adj``: a symmetric metacell adjacency matrix.
    - ``self.cell2metacell_dist``: a Gaussian KDE of the distances from
      training cells to their nearest metacell, used by
      :meth:`detect_outlier`.

    Parameters
    ----------
    atlas_name
        Name of the built atlas.
    hvgs
        Highly variable genes. Replaced by the genes returned from
        ``assembly`` during :meth:`fit`.
    eps
        Small lower bound, stored as ``self.eps``.
    """

    def __init__(self,
                 atlas_name: str = 'Atlas',
                 hvgs: list = None,
                 eps: float = 1e-12):
        super().__init__()
        self.atlas_name = atlas_name
        self.model = None
        self.sure_models_list = None
        self.submodels = None  # per-dataset models returned by `assembly` in `fit`
        self.hvgs = hvgs
        self.adata = None
        self.sample_adata = None
        self.layer = None
        self.n_sure_models = None
        self.umap_metric = 'euclidean'
        self.umap = None
        self.adj = None
        self.subatlas_list = None
        self.n_subatlas = 0
        self.pheno_keys = None
        self.nearest_neighbor_engine = None
        self.knn_k = 5
        self.network = None
        self.network_pos = None
        self.sample_network = None
        self.sample_network_pos = None
        # Populated by `fit`; declared here so an unfitted atlas is detectable.
        self.nearest_metacell_engine = None
        self.cell2metacell_dist = None
        self.eps = eps

    def fit(self, adata_list_,
            batch_key: str = None,
            pheno_keys: list = None,
            preprocessing: bool = True,
            hvgs: list = None,
            n_top_genes: int = 5000,
            hvg_method: Literal['seurat', 'seurat_v3', 'cell_ranger'] = 'seurat',
            layer: str = 'counts',
            cuda_id: int = 0,
            use_jax: bool = False,
            codebook_size: int = 800,
            codebook_size_per_adata: int = 500,
            min_gamma=20,
            min_cell=1000,
            learning_rate: float = 0.0001,
            batch_size: int = 200,
            batch_size_per_adata: int = 100,
            n_epochs: int = 200,
            latent_dist: Literal['normal', 'laplacian', 'studentt', 'cauchy'] = 'studentt',
            use_dirichlet: bool = False,
            use_dirichlet_per_adata: bool = False,
            zero_inflation: bool = False,
            zero_inflation_per_adata: bool = False,
            likelihood: Literal['negbinomial', 'poisson', 'multinomial', 'gaussian'] = 'multinomial',
            likelihood_per_adata: Literal['negbinomial', 'poisson', 'multinomial', 'gaussian'] = 'multinomial',
            sketch_func: Literal['mean', 'sum', 'simulate', 'sample', 'bootstrap_sample'] = 'mean',
            n_bootstrap_neighbors: int = 8,
            n_workers: int = 1,
            mute: bool = True,
            edge_thresh: float = 0.001,
            metric: Literal['euclidean', 'correlation', 'cosine'] = 'euclidean',
            knn_k: int = 5):
        """
        Fit the atlas to a list of AnnData datasets.

        Parameters
        ----------
        adata_list_
            A list of AnnData datasets. Datasets with ``min_cell`` or fewer
            cells are skipped.
        batch_key
            Undesired (batch) factor; passed to ``assembly``.
        pheno_keys
            Phenotype columns of ``.obs`` to summarize per metacell in the
            built atlas.
        preprocessing
            If toggled on, the input datasets go through the standard Scanpy
            preprocessing steps, including normalization and log1p
            transformation.
        hvgs
            If a list of highly variable genes is given, the subsequent steps
            rely on these genes.
        n_top_genes
            Parameter for Scanpy's ``highly_variable_genes``.
        hvg_method
            Parameter for Scanpy's ``highly_variable_genes``.
        layer
            Data layer used for building the atlas.
        cuda_id
            CUDA device.
        use_jax
            If toggled on, JAX is used for speed.
        codebook_size
            Number of metacells in the built atlas.
        codebook_size_per_adata
            Number of metacells for each dataset.
        min_gamma
            Passed to ``assembly``.
        min_cell
            A dataset must have more than this many cells to be used.
        learning_rate
            Parameter for optimization.
        batch_size
            Batch size for building the atlas.
        batch_size_per_adata
            Batch size for calling metacells within each dataset.
        n_epochs
            Number of epochs.
        latent_dist
            Distribution for latent representations.
        use_dirichlet
            Use a Dirichlet model for building the atlas.
        use_dirichlet_per_adata
            Use a Dirichlet model for calling metacells within each dataset.
        zero_inflation
            Use a zero-inflated model for building the atlas.
        zero_inflation_per_adata
            Use a zero-inflated model for calling metacells within each dataset.
        likelihood
            Data generation model for building the atlas.
        likelihood_per_adata
            Data generation model for calling metacells within each dataset.
        sketch_func
            Sketching function; passed to ``assembly``.
        n_bootstrap_neighbors
            Passed to ``assembly`` (presumably the neighbour count for
            ``'bootstrap_sample'`` sketching).
        n_workers
            Passed to ``assembly``.
        mute
            Passed to ``assembly``.
        edge_thresh
            Threshold for network construction. Currently unused because
            network construction is disabled.
        metric
            Distance metric, stored as ``self.umap_metric``.
        knn_k
            ``n_neighbors`` of the metacell nearest-neighbor engine.

        Raises
        ------
        ValueError
            If no dataset has more than ``min_cell`` cells.
        """
        print(Colors.YELLOW + 'Create A Distribution-Preserved Single-Cell Omics Atlas' + Colors.RESET)
        adata_list = [ad for ad in adata_list_ if ad.shape[0] > min_cell]
        n_adatas = len(adata_list)
        if n_adatas == 0:
            raise ValueError(f'No dataset has more than min_cell={min_cell} cells; '
                             'cannot build an atlas.')

        self.layer = layer
        self.n_sure_models = n_adatas
        self.umap_metric = metric
        self.pheno_keys = pheno_keys

        # assembly
        print(f'{n_adatas} adata datasets are given')
        self.model, self.submodels, self.hvgs = assembly(
            adata_list, batch_key,
            preprocessing, hvgs, n_top_genes, hvg_method, layer, cuda_id, use_jax,
            codebook_size, codebook_size_per_adata, min_gamma, learning_rate,
            batch_size, batch_size_per_adata, n_epochs, latent_dist,
            use_dirichlet, use_dirichlet_per_adata,
            zero_inflation, zero_inflation_per_adata,
            likelihood, likelihood_per_adata,
            sketch_func, n_bootstrap_neighbors,
            n_workers, mute)

        self._summarize_expression(adata_list, layer)

        if pheno_keys is not None:
            self._summarize_phenotypes_from_adatas(adata_list, pheno_keys)

        # NOTE: UMAP reference positions, the cell nearest-neighbor engine and
        # network construction were disabled on 2025-07-04, so `umap`,
        # `nearest_neighbor_engine` and `network` stay None after fitting.

        self._fit_cell2metacell_distribution(adata_list, layer, knn_k)

        print(Colors.YELLOW + f'A distribution-preserved atlas has been built from {n_adatas} adata datasets.' + Colors.RESET)

    def _summarize_expression(self, adata_list, layer):
        """Build the metacell expression AnnData ``self.adata`` and the adjacency ``self.adj``.

        Cells are softly assigned to metacells on the highly variable genes.
        The all-gene data in ``layer`` is then pooled per metacell with
        ``codebook_summarize_`` and divided by the total assignment weight.
        """
        X, W, adj = None, None, None
        full = None
        with tqdm(total=len(adata_list), desc=f'Summarize data in {layer}', unit='adata') as pbar:
            for adata in adata_list:
                xs_hvg = get_data(adata[:, self.hvgs].copy(), layer).values
                full = adata.copy()
                xs_all = get_data(full, layer).values
                ws = self.model.soft_assignments(xs_hvg)
                xs_sup = codebook_summarize_(ws, xs_all)
                w_sup = np.sum(ws.T, axis=1, keepdims=True)

                # a_t normalises each metacell's assignment column over cells.
                # (a_t @ a)[j, k] is therefore the mean assignment to metacell k
                # of cells, weighted by their assignment to metacell j.
                a = convert_to_tensor(ws)
                a_t = a.T / torch.sum(a.T, dim=1, keepdim=True)
                adj_i = torch.matmul(a_t, a)

                if X is None:
                    X, W, adj = xs_sup, w_sup, adj_i
                else:
                    X += xs_sup
                    W += w_sup
                    adj += adj_i
                pbar.update(1)

        self.adata = sc.AnnData(X / W)
        self.adata.obs_names = [f'MC{x}' for x in self.adata.obs_names]
        # NOTE(review): gene names come from the last dataset; all datasets are
        # assumed to share the same var_names - confirm.
        self.adata.var_names = full.var_names

        adj = tensor_to_numpy(adj) / self.n_sure_models
        self.adj = (adj + adj.T) / 2
        np.fill_diagonal(self.adj, 0)

    def _fit_cell2metacell_distribution(self, adata_list, layer, knn_k):
        """Fit ``self.cell2metacell_dist``, a KDE of each training cell's distance to its nearest metacell."""
        metacells = self.model.get_metacell_coordinates()
        self.nearest_metacell_engine = NearestNeighbors(n_neighbors=knn_k, n_jobs=-1)
        self.nearest_metacell_engine.fit(metacells)

        distances = []
        with tqdm(total=len(adata_list), desc='Build cell-to-metacell distance distribution', unit='adata') as pbar:
            for adata in adata_list:
                xs = get_data(adata[:, self.hvgs].copy(), layer).values
                zs = self.model.get_cell_coordinates(xs)
                dd, _ = self.nearest_metacell_engine.kneighbors(zs, n_neighbors=1)
                distances.append(dd.ravel())
                pbar.update(1)

        self.cell2metacell_dist = gaussian_kde(np.concatenate(distances))

    @staticmethod
    def _upper_tail_pvalues(kde, x_vector, grid_points: int = 20000):
        """Approximate ``P(X > x)`` under a 1-D ``gaussian_kde`` for every element of ``x_vector``.

        The KDE is evaluated once on a uniform grid spanning the data and the
        queries, padded by 3 standard deviations. It is integrated with the
        trapezoidal rule and interpolated at the query points. The grid size is
        fixed, so the cost is O(grid_points * n_data) rather than quadratic in
        the dataset size.

        Returns an array with the same shape as ``x_vector``.
        """
        x_vector = np.asarray(x_vector, dtype=float)
        if x_vector.size == 0:
            return np.empty_like(x_vector)

        data = np.ravel(kde.dataset)
        pad = 3 * np.std(data)
        x_min = min(data.min(), x_vector.min()) - pad
        x_max = max(data.max(), x_vector.max()) + pad
        x_grid = np.linspace(x_min, x_max, grid_points)

        cdf_grid = cumtrapz(kde(x_grid), x_grid, initial=0)
        cdf_grid /= cdf_grid[-1]
        # The grid covers every query point, so no out-of-range handling is needed.
        return 1.0 - np.interp(x_vector, x_grid, cdf_grid)

    def detect_outlier(self, adata_query, thresh: float = 1e-2, batch_size: int = 1024):
        """
        Flag query cells that are unusually far from their nearest metacell.

        Parameters
        ----------
        adata_query
            Query AnnData; the atlas's ``hvgs`` in ``layer`` are used.
        thresh
            A cell is flagged when ``P(D > d)`` under the training distance KDE
            is below this value.
        batch_size
            Batch size for computing cell coordinates.

        Returns
        -------
        Boolean array of shape ``(n_cells, 1)``, True for outliers.

        Raises
        ------
        RuntimeError
            If the atlas has not been fitted.
        """
        engine = getattr(self, 'nearest_metacell_engine', None)
        kde = getattr(self, 'cell2metacell_dist', None)
        if engine is None or kde is None:
            raise RuntimeError('The atlas has not been fitted; call fit() first.')

        adata_query = adata_query.copy()
        X_query = get_subdata(adata_query, self.hvgs, self.layer).values
        Z_map = self.model.get_cell_coordinates(X_query, batch_size=batch_size)

        dd, _ = engine.kneighbors(Z_map, n_neighbors=1)
        pp = self._upper_tail_pvalues(kde, dd)
        outliers = pp < thresh

        print(f'{np.sum(outliers)} outliers found.')

        return outliers

    def map(self, adata_query,
            batch_size: int = 1024):
        """
        Map query data to the atlas.

        Parameters
        ----------
        adata_query
            Query data. It should be an AnnData object.
        batch_size
            Size of batch processing.

        Returns
        -------
        ``(Z_map, A_map)``: the cells' coordinates from
        ``model.get_cell_coordinates`` and their soft metacell assignments from
        ``model.soft_assignments``.
        """
        adata_query = adata_query.copy()
        X_query = get_subdata(adata_query, self.hvgs, self.layer).values

        Z_map = self.model.get_cell_coordinates(X_query, batch_size=batch_size)
        A_map = self.model.soft_assignments(X_query)

        return Z_map, A_map

    @staticmethod
    def _phenotype_given_metacell(labels, assignments):
        """Return ``P(phenotype | metacell)`` as a (phenotype x metacell) DataFrame.

        ``labels`` gives one phenotype per cell. ``assignments`` is the
        (cells x metacells) soft-assignment matrix. Columns with no mass stay
        zero. The one-hot encoding is built explicitly because
        ``LabelBinarizer`` collapses one- and two-class labels into a single
        column.
        """
        classes, codes = np.unique(np.asarray(labels).astype(str), return_inverse=True)
        codes = np.ravel(codes)
        onehot = np.zeros((codes.size, classes.size))
        onehot[np.arange(codes.size), codes] = 1.0

        joint = onehot.T @ np.asarray(assignments)      # (n_phenotypes, n_metacells)
        mc_mass = joint.sum(axis=0, keepdims=True)
        mc_mass[mc_mass == 0] = 1                       # avoid 0/0 for empty metacells
        return pd.DataFrame(joint / mc_mass,
                            index=classes,
                            columns=[f'MC{x}' for x in range(joint.shape[1])])

    def phenotype_density_estimation(self, adata_list, pheno_key):
        """
        Estimate ``P(pheno_key | metacell)`` and store it in ``self.adata.uns[f'{pheno_key}_density']``.

        The result is a (phenotype x metacell) DataFrame. With several datasets,
        the per-dataset estimates are summed over the union of their labels
        (missing entries count as 0) and divided by the number of datasets.

        Raises
        ------
        ValueError
            If ``adata_list`` is empty.
        """
        n_adatas = len(adata_list)
        if n_adatas == 0:
            raise ValueError('adata_list must contain at least one dataset.')

        densities = []
        with tqdm(total=n_adatas, desc=f'Estimate {pheno_key} density', unit='adata') as pbar:
            for adata in adata_list:
                ad = adata[:, self.hvgs].copy()
                xs = get_data(ad, layer=self.layer).values
                X_MC_mat = self.model.soft_assignments(xs)
                labels = ad.obs[pheno_key].astype(str).to_numpy()
                densities.append(self._phenotype_given_metacell(labels, X_MC_mat))
                pbar.update(1)

        if n_adatas > 1:
            density = aggregate_dataframes(densities).div(n_adatas)
        else:
            density = densities[0]

        self.adata.uns[f'{pheno_key}_density'] = density

    def summarize_phenotypes(self, adata_list=None, pheno_keys=None):
        """
        Summarize phenotype counts per metacell (see ``_summarize_phenotypes_from_adatas``).

        ``pheno_keys`` defaults to the keys given to :meth:`fit`.

        Raises
        ------
        ValueError
            If ``adata_list`` is missing, or no phenotype keys are available.
        """
        if adata_list is None:
            raise ValueError('adata_list is required to summarize phenotypes.')
        if pheno_keys is None:
            pheno_keys = self.pheno_keys
        if pheno_keys is None:
            raise ValueError('No pheno_keys given and none were set during fit().')
        self._summarize_phenotypes_from_adatas(adata_list, pheno_keys)

    @staticmethod
    def _majority_labels(counts):
        """Return the column label with the largest count for each row of ``counts``.

        Rows whose counts are all zero (metacells no cell was assigned to) get
        ``pd.NA`` instead of an arbitrary first column.
        """
        labels = pd.Series(pd.NA, index=counts.index, dtype=object)
        observed = counts.sum(axis=1) > 0
        if observed.any():
            labels.loc[observed] = counts.loc[observed].idxmax(axis=1)
        return labels.tolist()

    def _summarize_phenotypes_from_adatas(self, adata_list, pheno_keys):
        """
        Count, per metacell, how many cells carry each phenotype value.

        Each cell is hard-assigned to its argmax metacell. For every key, the
        (metacell x value) count table is stored in ``self.adata.uns[key]``,
        rows ordered MC0..MC{n-1}. The majority value (or NA for metacells
        without cells) is stored in ``self.adata.obs[key]``. Datasets lacking
        the key are skipped.
        """
        if isinstance(pheno_keys, str):
            pheno_keys = [pheno_keys]
        mc_names = [f'MC{x}' for x in np.arange(self.adata.shape[0])]

        for pheno in pheno_keys:
            per_adata_counts = []
            with tqdm(total=len(adata_list), desc=f'Summarize data in {pheno}', unit='adata') as pbar:
                for adata in adata_list:
                    if pheno in adata.obs.columns:
                        adata_i = adata[:, self.hvgs].copy()
                        xs_i = get_data(adata_i, self.layer).values
                        ws_i = np.argmax(self.model.soft_assignments(xs_i), axis=1)
                        adata_i.obs['metacell'] = [f'MC{x}' for x in ws_i]
                        per_adata_counts.append(
                            adata_i.obs[['metacell', pheno]].value_counts().unstack(fill_value=0))
                    # Advance for every dataset so the bar completes even when some lack `pheno`.
                    pbar.update(1)

            counts = aggregate_dataframes(per_adata_counts)
            # Add all-zero rows for metacells without cells, then restore MC0..MCn
            # order (label alignment sorts the index lexicographically).
            counts = (pd.DataFrame(0, index=mc_names, columns=counts.columns)
                      .add(counts, fill_value=0)
                      .reindex(mc_names))

            self.adata.uns[pheno] = counts
            self.adata.obs[pheno] = self._majority_labels(counts)

    @staticmethod
    def _predict_from_density(assignments, density):
        """Return the phenotype maximizing ``sum_MC w(cell, MC) * P(phenotype | MC)`` for each cell.

        ``assignments`` is (cells x metacells). ``density`` is the
        (phenotype x metacell) DataFrame from ``phenotype_density_estimation``.
        Its columns are aligned to MC0..MC{n-1} before multiplying.
        """
        n_mc = np.shape(assignments)[1]
        density = density.reindex(columns=[f'MC{x}' for x in range(n_mc)], fill_value=0)
        scores = np.asarray(assignments, dtype=float) @ density.to_numpy(dtype=float).T
        return pd.DataFrame(scores, columns=density.index).idxmax(axis=1).tolist()

    def phenotype_predict(self, adata_query, pheno_key, batch_size=1024):
        """
        Predict ``pheno_key`` for each query cell.

        Uses the density stored by :meth:`phenotype_density_estimation`.
        Returns a list with one phenotype label per cell.
        """
        _, ws = self.map(adata_query, batch_size)
        return self._predict_from_density(ws, self.adata.uns[f'{pheno_key}_density'])

    @classmethod
    def save_model(cls, atlas, file_path, compression=False):
        """Pickle ``atlas`` (via ``dill``) to ``file_path``, gzip-compressed if ``compression``.

        ``atlas.sample_adata`` is cleared and the module is put into eval mode
        before saving; both changes apply to the passed object.
        """
        file_path = os.path.abspath(file_path)

        atlas.sample_adata = None
        atlas.eval()

        opener = gzip.open if compression else open
        with opener(file_path, 'wb') as pickle_file:
            pickle.dump(atlas, pickle_file)

        print(f'Model saved to {file_path}')

    @classmethod
    def load_model(cls, file_path, n_samples=10000):
        """Load and return an atlas saved by :meth:`save_model`.

        Gzip compression is detected from the file's magic number, so the file
        name's extension does not matter. ``n_samples`` is unused and kept for
        backward compatibility.

        SECURITY: unpickling can execute arbitrary code; only load trusted files.
        """
        file_path = os.path.abspath(file_path)
        with open(file_path, 'rb') as fh:
            is_gzip = fh.read(2) == b'\x1f\x8b'

        opener = gzip.open if is_gzip else open
        with opener(file_path, 'rb') as pickle_file:
            atlas = pickle.load(pickle_file)

        print(f'Model loaded from {file_path}')
        return atlas
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
#def aggregate_dataframes(df_list):
|
|
498
|
+
# n_dfs = len(df_list)
|
|
499
|
+
# all_columns = set(df_list[0].columns)
|
|
500
|
+
# for i in np.arange(n_dfs-1):
|
|
501
|
+
# all_columns = all_columns.union(set(df_list[i+1].columns))
|
|
502
|
+
#
|
|
503
|
+
# all_indexs = set(df_list[0].index.tolist())
|
|
504
|
+
# for i in np.arange(n_dfs-1):
|
|
505
|
+
# all_indexs = all_indexs.union(set(df_list[i+1].index.tolist()))
|
|
506
|
+
#
|
|
507
|
+
# for col in all_columns:
|
|
508
|
+
# for i in np.arange(n_dfs):
|
|
509
|
+
# if col not in df_list[i]:
|
|
510
|
+
# df_list[i][col] = 0
|
|
511
|
+
#
|
|
512
|
+
# df = pd.DataFrame(0, index=all_indexs, columns=all_columns)
|
|
513
|
+
# df = df_list[0]
|
|
514
|
+
# for i in np.arange(n_dfs-1):
|
|
515
|
+
# df += df_list[i+1]
|
|
516
|
+
#
|
|
517
|
+
# #df /= n_dfs
|
|
518
|
+
# return df
|
|
519
|
+
|
|
520
|
+
def aggregate_dataframes(df_list):
    """Sum DataFrames element-wise after aligning them on the union of their labels.

    Entries missing from a frame count as 0. The result's rows are sorted; its
    columns are the union of all frames' columns.
    """
    union_index = pd.Index([])
    union_columns = pd.Index([])
    for frame in df_list:
        union_index = union_index.union(frame.index)
        union_columns = union_columns.union(frame.columns)

    total = pd.DataFrame(0, index=union_index.sort_values(), columns=union_columns)
    for frame in df_list:
        total = total.add(frame, fill_value=0)
    return total
|
|
530
|
+
|
|
531
|
+
def smooth_y_over_x(xs, ys, knn_k):
    """Replace each ``ys[i]`` by the mean of ``ys`` over the ``knn_k - 1`` nearest other points in ``xs``.

    Parameters
    ----------
    xs
        Points used to find neighbours (one row per observation).
    ys
        Values to smooth, indexed like the rows of ``xs``. Non-float input is
        converted to float64.
    knn_k
        Neighbourhood size including the point itself (must be >= 2).

    Returns
    -------
    Smoothed array with the same shape as ``ys``.

    Raises
    ------
    ValueError
        If ``knn_k < 2``.
    """
    if knn_k < 2:
        raise ValueError(f'knn_k must be at least 2, got {knn_k}.')
    ys = np.asarray(ys)
    if not np.issubdtype(ys.dtype, np.floating):
        # In-place averaging below would fail on integer arrays.
        ys = ys.astype(np.float64)

    nbrs = NearestNeighbors(n_neighbors=knn_k - 1, n_jobs=-1)
    nbrs.fit(xs)
    # Without a query argument, sklearn excludes each point from its own
    # neighbour list (robust even when points are duplicated).
    ids = nbrs.kneighbors(return_distance=False)

    ys_smooth = np.zeros_like(ys)
    for col in range(ids.shape[1]):
        ys_smooth += ys[ids[:, col]]
    ys_smooth /= ids.shape[1]
    return ys_smooth
|
|
542
|
+
|
|
543
|
+
def matrix_dotprod(A, B, dtype=torch.float32):
    """Return the matrix product ``A @ B``, computed in torch with ``dtype`` and returned via ``tensor_to_numpy``."""
    lhs = convert_to_tensor(A, dtype=dtype)
    rhs = convert_to_tensor(B, dtype=dtype)
    return tensor_to_numpy(lhs @ rhs)
|
|
548
|
+
|
|
549
|
+
def matrix_elemprod(A, B):
    """Return the element-wise product ``A * B``, computed in torch and returned via ``tensor_to_numpy``."""
    lhs = convert_to_tensor(A)
    rhs = convert_to_tensor(B)
    return tensor_to_numpy(torch.mul(lhs, rhs))
|
|
554
|
+
|
|
555
|
+
def cdf(density, xs, initial=0):
    """Cumulative distribution of a sampled density, normalised so its last value is 1.

    ``density`` is sampled at ``xs`` and integrated with the trapezoidal rule
    (``cumtrapz``). ``initial`` is forwarded as the starting value.
    """
    cumulative = cumtrapz(density, xs, initial=initial)
    return cumulative / cumulative[-1]
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
class FaissKNeighbors:
    """Exact k-nearest-neighbour search backed by a faiss ``IndexFlatL2``.

    Mirrors a small part of scikit-learn's ``NearestNeighbors`` API. Note that
    ``IndexFlatL2`` reports *squared* Euclidean distances.

    Parameters
    ----------
    n_neighbors
        Default number of neighbours returned by :meth:`kneighbors`.
    """

    def __init__(self, n_neighbors=5):
        self.index = None
        self.k = n_neighbors

    def fit(self, X):
        """Index the rows of ``X`` and return ``self``.

        faiss needs C-contiguous float32 data, so ``X`` is converted
        (``astype`` alone would keep a Fortran-ordered layout).
        """
        X = np.ascontiguousarray(X, dtype=np.float32)
        self.index = faiss.IndexFlatL2(X.shape[1])
        self.index.add(X)
        return self

    def kneighbors(self, X, n_neighbors=None):
        """Return ``(distances, indices)`` of the nearest indexed rows for each row of ``X``.

        ``n_neighbors`` overrides the ``k`` given at construction. Distances
        are squared L2.

        Raises
        ------
        RuntimeError
            If :meth:`fit` has not been called.
        """
        if self.index is None:
            raise RuntimeError('FaissKNeighbors.fit must be called before kneighbors.')
        k = self.k if n_neighbors is None else n_neighbors
        distances, indices = self.index.search(np.ascontiguousarray(X, dtype=np.float32), k=k)
        return distances, indices
|
|
574
|
+
|
|
575
|
+
|