SURE-tools 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of SURE-tools might be problematic. Click here for more details.
- SURE/SURE.py +1203 -0
- SURE/__init__.py +7 -0
- SURE/assembly/__init__.py +3 -0
- SURE/assembly/assembly.py +511 -0
- SURE/assembly/atlas.py +575 -0
- SURE/codebook/__init__.py +4 -0
- SURE/codebook/codebook.py +472 -0
- SURE/utils/__init__.py +19 -0
- SURE/utils/custom_mlp.py +209 -0
- SURE/utils/queue.py +50 -0
- SURE/utils/utils.py +308 -0
- SURE_tools-1.0.1.dist-info/LICENSE +21 -0
- SURE_tools-1.0.1.dist-info/METADATA +68 -0
- SURE_tools-1.0.1.dist-info/RECORD +17 -0
- SURE_tools-1.0.1.dist-info/WHEEL +5 -0
- SURE_tools-1.0.1.dist-info/entry_points.txt +2 -0
- SURE_tools-1.0.1.dist-info/top_level.txt +1 -0
SURE/assembly/assembly.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
1
|
+
import scanpy as sc
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
import scipy as sp
|
|
5
|
+
from scipy.sparse import csr_matrix
|
|
6
|
+
from sklearn.preprocessing import OneHotEncoder, LabelBinarizer
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import shutil
|
|
10
|
+
import tempfile
|
|
11
|
+
import subprocess
|
|
12
|
+
import pkg_resources
|
|
13
|
+
import datatable as dt
|
|
14
|
+
|
|
15
|
+
from ..SURE import SURE
|
|
16
|
+
from ..codebook import codebook_generate, codebook_summarize, codebook_aggregate, codebook_sketch, codebook_bootstrap_sketch
|
|
17
|
+
from ..utils import convert_to_tensor, tensor_to_numpy
|
|
18
|
+
from ..utils import pretty_print
|
|
19
|
+
from ..utils import find_partitions_greedy
|
|
20
|
+
|
|
21
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
22
|
+
from tqdm import tqdm
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Values accepted for `sketch_func` in `assembly` (only consulted when n_adatas > 1).
_SKETCH_FUNCS = ('sum', 'mean', 'simulate', 'sample', 'bootstrap_sample')

# `latent_dist` -> SURE CLI flag. 'lapacian' is the historical misspelling the
# code used to accept; it is kept so existing callers keep working.
_LATENT_DIST_FLAGS = {
    'laplacian': '--z-dist laplacian',
    'lapacian': '--z-dist laplacian',
    'studentt': '--z-dist studentt',
    'cauchy': '--z-dist cauchy',
}

# Parent directory (relative to the current working directory) for per-run temp dirs.
_TEMP_ROOT = './SURE_temp_dir'


def _capped_codebook_size(n_samples, min_gamma, max_size):
    """Return min(max_size, n_samples / min_gamma truncated to int), but at least 1."""
    return max(1, min(max_size, int(n_samples / min_gamma)))


def _run_sure_cli(data_file, uwv_file, model_file, *, cuda_id, jit, learning_rate,
                  n_epochs, batch_size, codebook_size, likelihood, extra_flags,
                  log_file=None, echo=False):
    """Train one model with the `SURE` command-line tool and return `model_file`.

    `extra_flags` are pre-built flag strings (may be empty). When `log_file` is
    given, stdout/stderr are gzip-compressed into it. When `echo` is true the
    command is printed with `pretty_print` first.

    Raises:
        FileNotFoundError: if the run did not produce `model_file`.
    """
    import shlex

    parts = [
        f'CUDA_VISIBLE_DEVICES={shlex.quote(str(cuda_id))}', 'SURE',
        '--data-file', shlex.quote(data_file),
        '--undesired-factor-file', shlex.quote(uwv_file),
        '--seed', '0',
        '--cuda', jit,
        '-lr', shlex.quote(str(learning_rate)),
        '-n', shlex.quote(str(n_epochs)),
        '-bs', shlex.quote(str(batch_size)),
        '-cs', str(codebook_size),
        '-likeli', shlex.quote(str(likelihood)),
        *extra_flags,
        '--save-model', shlex.quote(model_file),
    ]
    if log_file is not None:
        parts += ['2>&1', '|', 'gzip', '>', shlex.quote(log_file)]
    cmd = ' '.join(p for p in parts if p)

    if echo:
        pretty_print(cmd)
    status = subprocess.call(cmd, shell=True)
    # With output piped to gzip the exit status is gzip's, not SURE's, so the
    # presence of the model file is the reliable success signal.
    if not os.path.exists(model_file):
        raise FileNotFoundError(
            f'SURE did not produce a model file at {model_file} (shell exit status '
            f'{status}); re-run with mute=False to see the SURE output.')
    return model_file


def _restrict_to_genes(adata, genes, layer):
    """Return `adata` restricted to `genes`, in that order.

    When every gene is present this is a plain AnnData slice. Otherwise a new
    AnnData is built from `get_subdata` (which zero-fills absent genes) for X
    and, unless `layer` is 'X', for `layer`; obs is carried over.
    """
    available = set(adata.var_names)
    if all(g in available for g in genes):
        return adata[:, genes]

    rebuilt = sc.AnnData(get_subdata(adata, genes, 'X').values, obs=adata.obs)
    rebuilt.var_names = genes
    if layer.lower() != 'x':
        # Store under `layer` (not a hard-coded 'counts') so later
        # get_data(..., layer=layer) calls find it.
        rebuilt.layers[layer] = get_subdata(adata, genes, layer).values
    return rebuilt


def _sketch_adata(model, adata, hvgs, layer, sketch_func, n_target, n_bootstrap_neighbors):
    """Reduce one dataset to the matrix fed into the atlas-level model.

    'sum' / 'mean' apply codebook_aggregate / codebook_summarize to the data and
    the model's hard assignments; 'simulate' draws `n_target` samples via
    codebook_generate + generate_count_data; 'sample' / 'bootstrap_sample' call
    codebook_sketch / codebook_bootstrap_sketch with `n_target`.
    """
    if sketch_func in ('sum', 'mean'):
        xs = get_data(adata, layer=layer).values
        assignments = model.hard_assignments(xs)
        reducer = codebook_aggregate if sketch_func == 'sum' else codebook_summarize
        return reducer(assignments, xs)
    if sketch_func == 'simulate':
        zs, _ = codebook_generate(model, n_target)
        return model.generate_count_data(zs)
    if sketch_func == 'sample':
        data = get_subdata(adata, hvgs=hvgs, layer=layer).values
        xs, _, _ = codebook_sketch(model, data, n_target, even_sample=True)
        return xs
    if sketch_func == 'bootstrap_sample':
        data = get_subdata(adata, hvgs=hvgs, layer=layer).values
        xs, _, _ = codebook_bootstrap_sketch(model, data, n_target, n_bootstrap_neighbors,
                                             even_sample=True)
        return xs
    raise ValueError(f'sketch_func must be one of {_SKETCH_FUNCS}, got {sketch_func!r}')


def assembly(adata_list_, batch_key, preprocessing=True, hvgs=None,
             n_top_genes=5000, hvg_method='cell_ranger', layer='counts', cuda_id=0, use_jax=False,
             codebook_size=800, codebook_size_per_adata=500, min_gamma=20, learning_rate=0.0001,
             batch_size=200, batch_size_per_adata=100, n_epochs=200, latent_dist='studentt',
             use_dirichlet=False, use_dirichlet_per_adata=False,
             zero_inflation=False, zero_inflation_per_adata=False,
             likelihood='multinomial', likelihood_per_adata='multinomial',
             sketch_func='sum', n_bootstrap_neighbors=8,
             n_workers=1, mute=True):
    """Train one SURE model per dataset, then (for >1 dataset) an assembled atlas model.

    Steps:
      1. Copy the inputs. If `hvgs` is None, optionally `preprocess` each copy
         (skipped for hvg_method='seurat_v3') and pick common HVGs with
         `highly_variable_genes_for_adatas`.
      2. Restrict every dataset to `hvgs`, zero-filling genes it lacks.
      3. Train a per-dataset model with the `SURE` CLI, up to `n_workers`
         runs concurrently. The undesired-factor matrix is the one-hot
         `obs[batch_key]`, or a single zero column when batch_key is None.
      4. If there are several datasets, reduce each one with `sketch_func`
         ('sum', 'mean', 'simulate', 'sample', 'bootstrap_sample'). Then train
         one model on the concatenation, using the dataset index as the
         undesired factor.

    Codebook sizes are capped at n_samples / min_gamma (minimum 1).
    latent_dist values other than 'laplacian', 'studentt' and 'cauchy' leave
    the CLI default. Intermediate files go into a fresh directory under
    ./SURE_temp_dir, which is removed afterwards. ./SURE_temp_dir itself is
    removed only if it is then empty.

    Returns:
        (model, models_list, hvgs): the atlas model (the single per-dataset
        model when only one dataset is given), the per-dataset models in input
        order, and the gene list used.

    Raises:
        ValueError: empty `adata_list_`, or unknown `sketch_func` with >1 dataset.
        FileNotFoundError: a SURE CLI run produced no model file.
    """
    n_adatas = len(adata_list_)
    if n_adatas == 0:
        raise ValueError('adata_list_ must contain at least one AnnData object')
    # Fail fast, before hours of per-dataset training, on a typo in sketch_func.
    if n_adatas > 1 and sketch_func not in _SKETCH_FUNCS:
        raise ValueError(f'sketch_func must be one of {_SKETCH_FUNCS}, got {sketch_func!r}')

    adata_list = [ad.copy() for ad in adata_list_]
    jit = '--jit' if use_jax else ''

    # get common hvgs
    if hvgs is None:
        if hvg_method == 'seurat_v3':
            preprocessing = False
        if preprocessing:
            with tqdm(total=n_adatas, desc='Preprocessing', unit='adata') as pbar:
                for i in range(n_adatas):
                    adata_list[i] = preprocess(adata_list[i], layer)
                    pbar.update(1)
        hvgs = highly_variable_genes_for_adatas(adata_list, n_top_genes, hvg_method)
        print(f'{len(hvgs)} common HVGs are found')

    adata_list = [_restrict_to_genes(ad, hvgs, layer) for ad in adata_list]

    cli_common = dict(cuda_id=cuda_id, jit=jit, learning_rate=learning_rate, n_epochs=n_epochs)
    latent_dist_param = _LATENT_DIST_FLAGS.get(latent_dist, '')
    per_adata_flags = [latent_dist_param,
                       '-dirichlet' if use_dirichlet_per_adata else '',
                       '-zi exact' if zero_inflation_per_adata else '']
    atlas_flags = [latent_dist_param,
                   '-dirichlet' if use_dirichlet else '',
                   '-zi exact' if zero_inflation else '']

    os.makedirs(_TEMP_ROOT, exist_ok=True)
    # Private per-run directory: cleanup below never touches other runs' files.
    temp_dir = tempfile.mkdtemp(dir=_TEMP_ROOT)
    try:
        def run_sure(ad, i):
            """Write dataset `i` to disk, train its model, return (i, model path)."""
            X = get_data(ad, layer=layer)
            if batch_key is not None:
                U = batch_encoding(ad, batch_key=batch_key)
            else:
                U = pd.DataFrame(np.zeros((X.shape[0])), columns=['batch'])

            count_file = os.path.join(temp_dir, f'temp_counts_{i}.txt.gz')
            uwv_file = os.path.join(temp_dir, f'temp_uwv_{i}.txt.gz')
            model_file = os.path.join(temp_dir, f'temp_{i}.pth')
            log_file = os.path.join(temp_dir, f'temp_log_{i}.txt.gz')

            dt.Frame(X.round()).to_csv(count_file)
            dt.Frame(U).to_csv(uwv_file)

            _run_sure_cli(
                count_file, uwv_file, model_file,
                batch_size=batch_size_per_adata,
                codebook_size=_capped_codebook_size(X.shape[0], min_gamma, codebook_size_per_adata),
                likelihood=likelihood_per_adata,
                extra_flags=per_adata_flags,
                log_file=log_file if mute else None,
                **cli_common,
            )
            return i, model_file

        model_files = {}
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = [executor.submit(run_sure, ad, i) for i, ad in enumerate(adata_list)]
            for f in tqdm(as_completed(futures), total=n_adatas, unit='adata', desc="Metacell calling"):
                i, path = f.result()
                model_files[i] = path
        models_list = [SURE.load_model(model_files[i]) for i in range(n_adatas)]

        if n_adatas > 1:
            # generate samples from the learned distributions for assembly
            adatas_to_assembly = []
            with tqdm(total=n_adatas, desc='Sampling', unit='adata') as pbar:
                for i, (adata_i, model_i) in enumerate(zip(adata_list, models_list)):
                    xs_i = _sketch_adata(model_i, adata_i, hvgs, layer, sketch_func,
                                         min_gamma * model_i.code_size, n_bootstrap_neighbors)
                    sketched = sc.AnnData(xs_i)
                    sketched.obs['adata_id'] = i
                    adatas_to_assembly.append(sketched)
                    pbar.update(1)

            # assembly
            adata_to_assembly = sc.concat(adatas_to_assembly)
            count_file = os.path.join(temp_dir, 'temp_counts.txt.gz')
            uwv_file = os.path.join(temp_dir, 'temp_uwv.txt.gz')
            model_file = os.path.join(temp_dir, 'temp_model.pth')

            X = get_data(adata_to_assembly, layer='X')
            U = batch_encoding(adata_to_assembly, batch_key='adata_id')
            dt.Frame(X.round()).to_csv(count_file)
            dt.Frame(U).to_csv(uwv_file)

            codebook_size_ = _capped_codebook_size(X.shape[0], min_gamma, codebook_size)
            print(f'Create distribution-preserved atlas with {codebook_size_} metacells from {X.shape[0]} samples')
            _run_sure_cli(
                count_file, uwv_file, model_file,
                batch_size=batch_size,
                codebook_size=codebook_size_,
                likelihood=likelihood,
                extra_flags=atlas_flags,
                echo=True,
                **cli_common,
            )
            model = SURE.load_model(model_file)
        else:
            model = models_list[0]
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
        try:
            # Best effort: only succeeds when no other run still uses the parent.
            os.rmdir(_TEMP_ROOT)
        except OSError:
            pass

    return model, models_list, hvgs
|
|
288
|
+
|
|
289
|
+
def preprocess(adata, layer='counts'):
    """Normalise and log-transform ``adata`` in place, starting from ``layer``.

    ``adata.X`` is replaced by a copy of the dense float32 matrix that
    ``get_data`` returns for ``layer``. That matrix is then scaled to a total
    of 1e4 per cell (``sc.pp.normalize_total``) and log1p-transformed
    (``sc.pp.log1p``).

    Returns the same, now modified, ``adata`` object.
    """
    source_frame = get_data(adata, layer)
    adata.X = source_frame.values.copy()

    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    return adata
|
|
294
|
+
|
|
295
|
+
def highly_variable_genes(adata, n_top_genes, hvg_method):
    """Flag highly variable genes on ``adata`` and return their names.

    Runs ``sc.pp.highly_variable_genes`` with ``flavor=hvg_method``, which
    annotates ``adata.var`` in place. Returns the ``var_names`` entries whose
    ``highly_variable`` flag is set.
    """
    sc.pp.highly_variable_genes(adata, flavor=hvg_method, n_top_genes=n_top_genes)
    is_variable = adata.var['highly_variable']
    return adata.var_names[is_variable]
|
|
299
|
+
|
|
300
|
+
def highly_variable_genes_for_adatas__(adata_list, n_top_genes, hvg_method):
    """Select HVGs jointly on the concatenation of ``adata_list``.

    The objects are concatenated (inner join on genes) with an ``adata_id``
    label identifying the source object. HVGs are then computed with that
    label as ``batch_key``.

    The label is attached only to the concatenated copy via ``label=``, so
    the callers' AnnData ``obs`` tables are left untouched.

    Returns:
        list of gene names flagged ``highly_variable``.
    """
    combined = sc.concat(adata_list, label='adata_id')
    sc.pp.highly_variable_genes(combined, n_top_genes=n_top_genes, flavor=hvg_method,
                                batch_key='adata_id')
    return combined.var_names[combined.var['highly_variable']].tolist()
|
|
310
|
+
|
|
311
|
+
def highly_variable_genes_for_adatas(adata_list, n_top_genes, hvg_method):
    """Pick ``n_top_genes`` HVGs shared across several AnnData objects.

    HVGs are computed separately for each object, which annotates
    ``adata.var`` in place. The per-object HVG statistics are pooled per
    gene: numeric statistics are averaged and ``highly_variable`` is counted.
    Genes are ranked by:

      1. the number of objects in which they are HVGs (descending), then
      2. the mean normalised score (descending, NaN treated as 0).

    The score is ``dispersions_norm`` for the 'seurat' / 'cell_ranger' flavors
    and ``variances_norm`` for 'seurat_v3', which has no dispersion columns.

    Returns:
        list of at most ``n_top_genes`` gene names, best first; [] for empty input.

    Raises:
        ValueError: if the HVG output has neither score column.
    """
    if len(adata_list) == 0:
        return []

    dfs = []
    for adata in adata_list:
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=hvg_method)
        hvgs_i = adata.var_names[adata.var.highly_variable].tolist()
        # rename_axis + reset_index yields a fresh frame with a 'gene' column
        # (no in-place edit of a .loc slice, no pandas>=1.5 `names=` needed).
        dfs.append(adata.var.loc[hvgs_i, :].rename_axis('gene').reset_index())

    df = pd.concat(dfs, axis=0)
    score_col = next((c for c in ('dispersions_norm', 'variances_norm') if c in df.columns), None)
    if score_col is None:
        raise ValueError(
            f'hvg_method={hvg_method!r} produced neither dispersions_norm nor variances_norm')

    df["highly_variable"] = df["highly_variable"].astype(int)
    aggregations = {c: "mean" for c in ("means", "dispersions", "dispersions_norm",
                                        "variances", "variances_norm") if c in df.columns}
    aggregations["highly_variable"] = "sum"
    df = df.groupby("gene", observed=True).agg(aggregations)
    df["highly_variable_nbatches"] = df["highly_variable"]
    df[score_col] = df[score_col].fillna(0)

    df.sort_values(
        ["highly_variable_nbatches", score_col],
        ascending=False,
        na_position="last",
        inplace=True,
    )
    keep = np.arange(df.shape[0]) < n_top_genes
    return df.index[keep].tolist()
|
|
351
|
+
|
|
352
|
+
def highly_variable_genes_for_adatas_(adata_list, n_top_genes, hvg_method):
    """Pick HVGs shared by all objects, topped up from each object's own HVGs.

    Each object is annotated in place by ``sc.pp.highly_variable_genes``.
    Genes flagged in *every* object are taken first, sorted by name so the
    result is deterministic.

    If these are fewer than ``n_top_genes``, the gap is filled from each
    object's remaining HVGs, ranked by ``dispersions_norm`` (or
    ``variances_norm`` for 'seurat_v3'). Every object contributes an equal
    share, and the last object fills whatever is left.

    Returns:
        list of gene names; [] for empty input. It can exceed ``n_top_genes``
        only when the shared set alone already does.
    """
    n_adatas = len(adata_list)
    if n_adatas == 0:
        return []

    shared = None
    for i, adata in enumerate(adata_list):
        print(f'Adata {i+1} / {n_adatas}: Find {n_top_genes} HVGs')
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=hvg_method)
        hvgs_i = set(adata.var_names[adata.var.highly_variable].tolist())
        # Test `is None`, not truthiness: an empty intersection must stay empty
        # instead of being replaced by the next object's full HVG set.
        shared = hvgs_i if shared is None else shared & hvgs_i

    hvgs = sorted(shared)
    if len(hvgs) >= n_top_genes:
        return hvgs

    n_average = (n_top_genes - len(hvgs)) // n_adatas
    for i, adata in enumerate(adata_list):
        if i == n_adatas - 1:
            n_average = n_top_genes - len(hvgs)

        already = set(hvgs)
        candidates = sorted(set(adata.var_names[adata.var.highly_variable].tolist()) - already)
        score_col = ('dispersions_norm' if 'dispersions_norm' in adata.var.columns
                     else 'variances_norm')
        ranked = adata.var.loc[candidates, :].sort_values(by=score_col, ascending=False,
                                                          kind='mergesort')
        hvgs.extend(ranked.index.tolist()[:n_average])

    return hvgs
|
|
384
|
+
|
|
385
|
+
def batch_encoding_bk(adata, batch_key):
    """One-hot encode ``adata.obs[batch_key]`` with scikit-learn's OneHotEncoder.

    Returns a dense DataFrame with one column per category, named by
    ``enc.categories_[0]``.

    The dense-output keyword was renamed from ``sparse`` to ``sparse_output``
    in scikit-learn 1.2. Rather than sniffing versions through the deprecated
    ``pkg_resources``, try the new keyword and fall back on the TypeError that
    older releases raise for it.
    """
    labels = adata.obs[batch_key].to_numpy().reshape(-1, 1)
    try:
        enc = OneHotEncoder(sparse_output=False)
    except TypeError:
        enc = OneHotEncoder(sparse=False)
    onehot = enc.fit_transform(labels)
    return pd.DataFrame(onehot, columns=enc.categories_[0])
|
|
392
|
+
|
|
393
|
+
def batch_encoding(adata, batch_key):
    """One-hot encode ``adata.obs[batch_key]`` as an integer DataFrame.

    Columns are the sorted distinct labels, with one 0/1 column per label.
    This also holds for exactly two labels, where ``LabelBinarizer`` used to
    emit a single column and the DataFrame construction then failed. A single
    distinct label yields one all-zero column, as ``LabelBinarizer`` produced.
    """
    labels = adata.obs[batch_key].to_numpy()
    classes, codes = np.unique(labels, return_inverse=True)
    codes = np.asarray(codes).reshape(-1)

    if len(classes) == 1:
        onehot = np.zeros((len(labels), 1), dtype=int)
    else:
        onehot = (codes[:, None] == np.arange(len(classes))[None, :]).astype(int)
    return pd.DataFrame(onehot, columns=classes)
|
|
396
|
+
|
|
397
|
+
def get_data(adata, layer='counts'):
    """Return ``adata.layers[layer]`` (or ``adata.X`` when layer is 'X'/'x') as a DataFrame.

    The result is a dense float32 DataFrame whose columns are
    ``adata.var_names``; NaN entries are replaced by 0. The data is always
    copied before the NaN replacement, so the AnnData's own matrix is never
    modified. Previously a dense ``X`` or layer was overwritten in place.
    """
    data = adata.layers[layer] if layer.lower() != 'x' else adata.X

    if sp.sparse.issparse(data):
        dense = data.toarray().astype(np.float32, copy=False)  # toarray() is already a fresh array
    else:
        dense = np.array(data, dtype=np.float32)  # np.array copies by default

    dense[np.isnan(dense)] = 0
    return pd.DataFrame(dense, columns=adata.var_names)
|
|
409
|
+
|
|
410
|
+
def get_subdata(adata, hvgs, layer='counts'):
    """Return ``layer`` of ``adata`` as a DataFrame with exactly the columns ``hvgs``, in order.

    If every gene in ``hvgs`` exists in ``adata``, this is a plain gene slice.
    Otherwise an empty AnnData whose ``var`` is exactly ``hvgs`` is
    outer-concatenated with the overlapping genes of ``adata`` (approach
    inspired by SCimilarity). Absent genes thereby become columns, and
    ``get_data`` converts any NaN padding to 0.
    """
    in_adata = pd.Series(hvgs).isin(adata.var_names.tolist())
    if all(in_adata):
        sliced = get_data(adata[:, hvgs], layer)
        return sliced[hvgs]

    template = sc.AnnData(
        X=csr_matrix((0, len(hvgs))),
        var=pd.DataFrame(index=hvgs),
    )
    if layer.lower() != 'x':
        template.layers[layer] = template.X.copy()
    overlapping = adata[:, adata.var.index.isin(template.var.index)]
    padded = sc.concat((template, overlapping), join="outer")
    padded_frame = get_data(padded, layer)
    return padded_frame[hvgs]
|
|
436
|
+
|
|
437
|
+
def get_uns(adata, key):
    """Return ``adata.uns[key]`` as a float32 DataFrame, or None if the key is absent.

    Column names are read from ``adata.uns[f'{key}_columns']``. Sparse
    matrices are densified first.
    """
    if key not in adata.uns:
        return None

    matrix = adata.uns[key]
    column_names = adata.uns[f'{key}_columns']
    if sp.sparse.issparse(matrix):
        matrix = matrix.toarray()
    return pd.DataFrame(matrix.astype('float32'), columns=column_names)
|
|
449
|
+
|
|
450
|
+
def split_by(adata, by:str, copy:bool=False):
    """Split ``adata`` into one object per distinct value of ``adata.obs[by]``.

    Groups appear in the order returned by ``unique()``. With ``copy=True``
    each piece is materialised with ``.copy()``; otherwise the raw subsets
    are returned.
    """
    pieces = []
    for value in adata.obs[by].unique():
        subset = adata[adata.obs[by].isin([value])]
        pieces.append(subset.copy() if copy else subset)
    return pieces
|
|
458
|
+
|
|
459
|
+
def split_batch_by_bk(adata, by, batch_size=30000, copy=False):
    """Split ``adata`` into roughly ``n_obs / batch_size`` pieces without splitting any ``by`` group.

    The count of cells per ``by`` value is handed to ``find_partitions_greedy``
    together with the target number of pieces (at least 1). Each returned
    partition's groups are collected into one subset, copied when
    ``copy=True``.
    """
    group_sizes = adata.obs[by].value_counts().reset_index()
    group_sizes.columns = [by, 'Value']

    n_parts = max(int(np.round(adata.shape[0] / batch_size)), 1)

    pieces = []
    for _, member_rows in find_partitions_greedy(group_sizes['Value'].tolist(), n_parts):
        keys = group_sizes.iloc[member_rows, :][by].tolist()
        subset = adata[adata.obs[by].isin(keys)]
        pieces.append(subset.copy() if copy else subset)
    return pieces
|
|
476
|
+
|
|
477
|
+
def split_batch_by(adata,
                   by:str,
                   batch_size: int = 30000,
                   copy: bool = False):
    """Split ``adata`` per distinct ``obs[by]`` value, then chunk each group with ``split_batch``.

    Group membership uses ``isin([value])``, as in ``split_by``, rather than
    ``== value``. ``==`` never matches a NaN label, so those cells used to be
    dropped and replaced by an empty chunk.

    Returns a flat list of pieces, group by group.
    """
    adata_list = []
    for grp in adata.obs[by].unique():
        members = adata[adata.obs[by].isin([grp])]
        adata_list.extend(split_batch(members, batch_size=batch_size, copy=copy))
    return adata_list
|
|
487
|
+
|
|
488
|
+
def split_batch(adata, batch_size: int=30000, copy: bool=False):
    """Cut ``adata`` into ``round(n_obs / batch_size)`` contiguous chunks (at least one).

    Chunks are made with ``np.array_split`` over ``obs_names``, so their sizes
    differ by at most one. Each chunk is indexed by cell names and copied when
    ``copy=True``.
    """
    n_chunks = max(int(np.round(adata.shape[0] / batch_size)), 1)
    name_groups = np.array_split(adata.obs_names.tolist(), n_chunks)

    pieces = []
    for names in name_groups:
        subset = adata[list(names)]
        pieces.append(subset.copy() if copy else subset)
    return pieces
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def generate_equal_list(total, n):
    """Split ``total`` into ``n`` parts of ``total // n``, adding the remainder to the last part.

    Example: ``generate_equal_list(10, 4) == [2, 2, 2, 4]``.
    """
    share = total // n
    last_share = share + (total - share * n)
    parts = [share for _ in range(n - 1)]
    parts.append(last_share)
    return parts
|