SURE-tools 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of SURE-tools might be problematic.

SURE/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from . import utils
+ from . import codebook
+
+ # import the class last so `SURE` names the class, not the SURE submodule
+ from .SURE import SURE
+
+ __all__ = ['SURE', 'utils', 'codebook']
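
With the class imported last, the top-level name `SURE` resolves to the model class rather than the submodule. A minimal import sketch of the surface declared in `__all__` (the checkpoint path is hypothetical; `SURE.load_model` is used later in this diff):

    from SURE import SURE, utils, codebook

    model = SURE.load_model('model.pth')  # hypothetical checkpoint path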
@@ -0,0 +1,3 @@
+ # Importing specific functions from modules
+ from .assembly import assembly, split_by, split_batch_by, split_batch, get_data, get_subdata
+ from .atlas import SingleOmicsAtlas
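
Assuming this hunk is the subpackage's `__init__.py` (the relative imports in the next file suggest a `SURE/assembly/` layout), the helpers become importable as, for example:

    from SURE.assembly import assembly, split_batch_by, get_data
    from SURE.assembly import SingleOmicsAtlas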
@@ -0,0 +1,511 @@
+ import scanpy as sc
+ import pandas as pd
+ import numpy as np
+ import scipy as sp
+ from scipy.sparse import csr_matrix
+ from sklearn.preprocessing import OneHotEncoder, LabelBinarizer
+
+ import os
+ import shutil
+ import tempfile
+ import subprocess
+ import pkg_resources
+ import datatable as dt
+
+ from ..SURE import SURE
+ from ..codebook import codebook_generate, codebook_summarize, codebook_aggregate, codebook_sketch, codebook_bootstrap_sketch
+ from ..utils import convert_to_tensor, tensor_to_numpy
+ from ..utils import pretty_print
+ from ..utils import find_partitions_greedy
+
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from tqdm import tqdm
+
+
+ def assembly(adata_list_, batch_key, preprocessing=True, hvgs=None,
+              n_top_genes=5000, hvg_method='cell_ranger', layer='counts', cuda_id=0, use_jax=False,
+              codebook_size=800, codebook_size_per_adata=500, min_gamma=20, learning_rate=0.0001,
+              batch_size=200, batch_size_per_adata=100, n_epochs=200, latent_dist='studentt',
+              use_dirichlet=False, use_dirichlet_per_adata=False,
+              zero_inflation=False, zero_inflation_per_adata=False,
+              likelihood='multinomial', likelihood_per_adata='multinomial',
+              sketch_func='sum', n_bootstrap_neighbors=8,
+              n_workers=1, mute=True):
+     adata_list = [ad.copy() for ad in adata_list_]
+     n_adatas = len(adata_list)
+
+     jit = '--jit' if use_jax else ''
+
+     # get common HVGs
+     if hvgs is None:
+         # seurat_v3 expects raw counts, so skip normalization for that flavor
+         preprocessing = False if (hvg_method=='seurat_v3') else preprocessing
+
+         # preprocess
+         if preprocessing:
+             with tqdm(total=n_adatas, desc='Preprocessing', unit='adata') as pbar:
+                 for i in np.arange(n_adatas):
+                     adata_list[i] = preprocess(adata_list[i], layer)
+                     pbar.update(1)
+
+         # find the HVGs shared across all adatas
+         hvgs = highly_variable_genes_for_adatas(adata_list, n_top_genes, hvg_method)
+         print(f'{len(hvgs)} common HVGs are found')
+
+     adata_size_list = []
+     for i in np.arange(n_adatas):
+         adata_size_list.append(adata_list[i].shape[0])
+
+     # restrict each adata to the common HVGs, zero-padding any genes it lacks
+     for i in np.arange(n_adatas):
+         mask = [x in adata_list[i].var_names.tolist() for x in hvgs]
+         if all(mask):
+             adata_list[i] = adata_list[i][:, hvgs]
+         else:
+             adata_i_X = get_subdata(adata_list[i], hvgs, 'X')
+             adata_i_counts = get_subdata(adata_list[i], hvgs, 'counts')
+             adata_i_obs = adata_list[i].obs
+             adata_i_new = sc.AnnData(adata_i_X, obs=adata_i_obs)
+             adata_i_new.var_names = hvgs
+             adata_i_new.layers['counts'] = adata_i_counts
+             adata_list[i] = adata_i_new
+
+     models_list = []
+     models_file_list = []
+     model = None
+     # process
+     try:
+         if not os.path.exists('./SURE_temp_dir'):
+             os.mkdir('./SURE_temp_dir')
+         temp_dir = tempfile.mkdtemp(dir='./SURE_temp_dir')
+
+         if latent_dist == 'laplacian':
+             latent_dist_param = '--z-dist laplacian'
+         elif latent_dist == 'studentt':
+             latent_dist_param = '--z-dist studentt'
+         elif latent_dist == 'cauchy':
+             latent_dist_param = '--z-dist cauchy'
+         else:
+             latent_dist_param = ''
+
+         dirichlet = '-dirichlet' if use_dirichlet else ''
+         dirichlet_per_adata = '-dirichlet' if use_dirichlet_per_adata else ''
+
+         zi = '-zi exact' if zero_inflation else ''
+         zi_per_adata = '-zi exact' if zero_inflation_per_adata else ''
+
+         def run_sure(ad, i):
+             # fit one SURE model for a single adata through the SURE CLI
+             X = get_data(ad, layer=layer)
+             if batch_key is not None:
+                 U = batch_encoding(ad, batch_key=batch_key)
+             else:
+                 U = pd.DataFrame(np.zeros(X.shape[0]), columns=['batch'])
+
+             temp_count_file = os.path.join(temp_dir, f'temp_counts_{i}.txt.gz')
+             temp_uwv_file = os.path.join(temp_dir, f'temp_uwv_{i}.txt.gz')
+             temp_model_file = os.path.join(temp_dir, f'temp_{i}.pth')
+             temp_log_file = os.path.join(temp_dir, f'temp_log_{i}.txt.gz')
+
+             dt.Frame(X.round()).to_csv(temp_count_file)
+             dt.Frame(U).to_csv(temp_uwv_file)
+
+             # cap the codebook size so each metacell averages at least min_gamma cells
+             codebook_size_per_adata_ = np.int32(X.shape[0] / min_gamma)
+             codebook_size_per_adata_ = codebook_size_per_adata_ if codebook_size_per_adata > codebook_size_per_adata_ else codebook_size_per_adata
+
+             mute_cmd = ''
+             if mute:
+                 mute_cmd = f'2>&1 | gzip > {temp_log_file}'
+
+             cmd = f'CUDA_VISIBLE_DEVICES={cuda_id} SURE --data-file "{temp_count_file}" \
+                 --undesired-factor-file "{temp_uwv_file}" \
+                 --seed 0 \
+                 --cuda {jit} \
+                 -lr {learning_rate} \
+                 -n {n_epochs} \
+                 -bs {batch_size_per_adata} \
+                 -cs {codebook_size_per_adata_} \
+                 -likeli {likelihood_per_adata} {latent_dist_param} {dirichlet_per_adata} {zi_per_adata} \
+                 --save-model "{temp_model_file}" {mute_cmd}'
+             subprocess.call(cmd, shell=True)
+             return {i: temp_model_file}
+
+         with ThreadPoolExecutor(max_workers=n_workers) as executor:
+             futures = [executor.submit(run_sure, ad, i) for i, ad in enumerate(adata_list)]
+             results_ = {}
+             for f in tqdm(as_completed(futures), total=len(adata_list), unit='adata', desc='Metacell calling'):
+                 results_.update(f.result())
+             # keep the per-adata model files in input order
+             models_file_list = [results_[k] for k in sorted(results_)]
+
+         # get the distribution structure for each adata
+         for i in np.arange(n_adatas):
+             model_i = SURE.load_model(models_file_list[i])
+             models_list.append(model_i)
+
+         # generate samples from the learned distributions for assembly
+         if n_adatas > 1:
+             adatas_to_assembly = []
+             with tqdm(total=n_adatas, desc='Sampling', unit='adata') as pbar:
+                 for i in np.arange(n_adatas):
+                     model_i = models_list[i]
+                     if sketch_func == 'sum':
+                         # aggregate raw counts of the cells assigned to each metacell
+                         xs_i_ = get_data(adata_list[i], layer=layer).values
+                         ns = models_list[i].hard_assignments(xs_i_)
+                         xs_i = codebook_aggregate(ns, xs_i_)
+                     elif sketch_func == 'mean':
+                         xs_i_ = get_data(adata_list[i], layer=layer).values
+                         ns = models_list[i].hard_assignments(xs_i_)
+                         xs_i = codebook_summarize(ns, xs_i_)
+                     elif sketch_func == 'simulate':
+                         zs_i, _ = codebook_generate(model_i, min_gamma * models_list[i].code_size)
+                         xs_i = model_i.generate_count_data(zs_i)
+                     elif sketch_func == 'sample':
+                         data_i = get_subdata(adata_list[i], hvgs=hvgs, layer=layer).values
+                         xs_i, _, _ = codebook_sketch(model_i, data_i, min_gamma * models_list[i].code_size, even_sample=True)
+                     elif sketch_func == 'bootstrap_sample':
+                         data_i = get_subdata(adata_list[i], hvgs=hvgs, layer=layer).values
+                         xs_i, _, _ = codebook_bootstrap_sketch(model_i, data_i, min_gamma * models_list[i].code_size, n_bootstrap_neighbors, even_sample=True)
+                     else:
+                         raise ValueError(f"unknown sketch_func: '{sketch_func}'")
+
+                     adata_i = sc.AnnData(xs_i)
+                     adata_i.obs['adata_id'] = i
+
+                     adatas_to_assembly.append(adata_i)
+                     pbar.update(1)
+
+             # assembly: fit a joint SURE model over the pooled metacell samples
+             adata_to_assembly = sc.concat(adatas_to_assembly)
+             temp_count_file = os.path.join(temp_dir, 'temp_counts.txt.gz')
+             temp_uwv_file = os.path.join(temp_dir, 'temp_uwv.txt.gz')
+             temp_model_file = os.path.join(temp_dir, 'temp_model.pth')
+
+             X = get_data(adata_to_assembly, layer='X')
+             U = batch_encoding(adata_to_assembly, batch_key='adata_id')
+             dt.Frame(X.round()).to_csv(temp_count_file)
+             dt.Frame(U).to_csv(temp_uwv_file)
+
+             codebook_size_ = np.int32(X.shape[0] / min_gamma)
+             codebook_size_ = codebook_size_ if codebook_size > codebook_size_ else codebook_size
+
+             print(f'Create distribution-preserved atlas with {codebook_size_} metacells from {X.shape[0]} samples')
+             cmd = f'CUDA_VISIBLE_DEVICES={cuda_id} SURE --data-file "{temp_count_file}" \
+                 --undesired-factor-file "{temp_uwv_file}" \
+                 --seed 0 \
+                 --cuda {jit} \
+                 -lr {learning_rate} \
+                 -n {n_epochs} \
+                 -bs {batch_size} \
+                 -cs {codebook_size_} \
+                 -likeli {likelihood} {latent_dist_param} {dirichlet} {zi} \
+                 --save-model "{temp_model_file}" '
+             pretty_print(cmd)
+             subprocess.call(cmd, shell=True)
+             model = SURE.load_model(temp_model_file)
+         else:
+             model = models_list[0]
+     finally:
+         shutil.rmtree('./SURE_temp_dir', ignore_errors=True)
+
+     return model, models_list, hvgs
+
+ def preprocess(adata, layer='counts'):
+     adata.X = get_data(adata, layer).values.copy()
+     sc.pp.normalize_total(adata, target_sum=1e4)
+     sc.pp.log1p(adata)
+     return adata
+
+ def highly_variable_genes(adata, n_top_genes, hvg_method):
+     sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=hvg_method)
+     hvgs = adata.var_names[adata.var.highly_variable]
+     return hvgs
+
+ def highly_variable_genes_for_adatas__(adata_list, n_top_genes, hvg_method):
+     n_adatas = len(adata_list)
+
+     for i in np.arange(n_adatas):
+         adata_list[i].obs['adata_id'] = i
+
+     adata = sc.concat(adata_list)
+     sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=hvg_method, batch_key='adata_id')
+     hvgs = adata.var_names[adata.var.highly_variable].tolist()
+     return hvgs
+
+ def highly_variable_genes_for_adatas(adata_list, n_top_genes, hvg_method):
+     n_adatas = len(adata_list)
+
+     # collect per-adata HVG statistics
+     dfs = []
+     for i in np.arange(n_adatas):
+         adata = adata_list[i]
+         sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=hvg_method)
+         hvgs_i = adata.var_names[adata.var.highly_variable].tolist()
+         df_i = adata.var.loc[hvgs_i, :].copy()
+         df_i.reset_index(drop=False, inplace=True, names=["gene"])
+         dfs.append(df_i)
+
+     # rank genes by how many adatas flag them as HVG, then by mean normalized dispersion
+     df = pd.concat(dfs, axis=0)
+     df["highly_variable"] = df["highly_variable"].astype(int)
+     df = df.groupby("gene", observed=True).agg(
+         dict(
+             means="mean",
+             dispersions="mean",
+             dispersions_norm="mean",
+             highly_variable="sum",
+         )
+     )
+     df["highly_variable_nbatches"] = df["highly_variable"]
+     df["dispersions_norm"] = df["dispersions_norm"].fillna(0)
+
+     df.sort_values(
+         ["highly_variable_nbatches", "dispersions_norm"],
+         ascending=False,
+         na_position="last",
+         inplace=True,
+     )
+     df["highly_variable"] = np.arange(df.shape[0]) < n_top_genes
+
+     df = df[df['highly_variable']]
+     hvgs = df.index.tolist()
+     return hvgs
+
+ def highly_variable_genes_for_adatas_(adata_list, n_top_genes, hvg_method):
+     n_adatas = len(adata_list)
+
+     # start from the HVGs shared by every adata
+     hvgs_ = None
+     for i in np.arange(n_adatas):
+         print(f'Adata {i+1} / {n_adatas}: Find {n_top_genes} HVGs')
+         adata = adata_list[i]
+         sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=hvg_method)
+         if hvgs_:
+             hvgs_ &= set(adata.var_names[adata.var.highly_variable].tolist())
+         else:
+             hvgs_ = set(adata.var_names[adata.var.highly_variable].tolist())
+
+     hvgs = list(hvgs_)
+
+     # top up with each adata's highest-dispersion genes until n_top_genes is reached
+     n_average = (n_top_genes - len(hvgs)) // n_adatas
+     for i in np.arange(n_adatas):
+         if i == n_adatas - 1:
+             n_average = n_top_genes - len(hvgs)
+
+         adata = adata_list[i]
+         hvgs_i = list(set(adata.var_names[adata.var.highly_variable].tolist()) - set(hvgs))
+         df_i = adata.var.loc[hvgs_i, :].copy()
+         df_i.sort_values(by='dispersions_norm', ascending=False, inplace=True)
+         hvgs_i = df_i.index.tolist()
+         hvgs.extend(hvgs_i[:n_average])
+
+     return hvgs
+
+ def batch_encoding_bk(adata, batch_key):
+     # OneHotEncoder renamed `sparse` to `sparse_output` in scikit-learn 1.2
+     sklearn_version = pkg_resources.get_distribution("scikit-learn").version
+     if pkg_resources.parse_version(sklearn_version) < pkg_resources.parse_version("1.2"):
+         enc = OneHotEncoder(sparse=False).fit(adata.obs[batch_key].to_numpy().reshape(-1, 1))
+     else:
+         enc = OneHotEncoder(sparse_output=False).fit(adata.obs[batch_key].to_numpy().reshape(-1, 1))
+     return pd.DataFrame(enc.transform(adata.obs[batch_key].to_numpy().reshape(-1, 1)), columns=enc.categories_[0])
+
+ def batch_encoding(adata, batch_key):
+     enc = LabelBinarizer().fit(adata.obs[batch_key].to_numpy())
+     mat = enc.transform(adata.obs[batch_key].to_numpy())
+     # LabelBinarizer collapses two classes into a single 0/1 column;
+     # expand to full one-hot so the columns match enc.classes_
+     if len(enc.classes_) == 2:
+         mat = np.hstack([1 - mat, mat])
+     return pd.DataFrame(mat, columns=enc.classes_)
+
+ def get_data(adata, layer='counts'):
+     if layer.lower() != 'x':
+         data = adata.layers[layer]
+     else:
+         data = adata.X
+
+     if sp.sparse.issparse(data):
+         data = data.toarray()
+
+     # work on a float copy so NaN replacement cannot mutate the AnnData in place
+     data = np.array(data, dtype='float32')
+     data[np.isnan(data)] = 0
+
+     return pd.DataFrame(data, columns=adata.var_names)
+
+ def get_subdata(adata, hvgs, layer='counts'):
+     hvgs_df = pd.DataFrame({'hvgs': hvgs})
+     mask = hvgs_df['hvgs'].isin(adata.var_names.tolist())
+     if all(mask):
+         X = get_data(adata[:, hvgs], layer)
+         return X[hvgs]
+     else:
+         # inspired by SCimilarity: build an empty shell over the full HVG set,
+         # then outer-concat so missing genes come back as zero columns
+         shell = sc.AnnData(
+             X=csr_matrix((0, len(hvgs))),
+             var=pd.DataFrame(index=hvgs),
+         )
+         if layer.lower() != 'x':
+             shell.layers[layer] = shell.X.copy()
+         shell = sc.concat(
+             (shell, adata[:, adata.var.index.isin(shell.var.index)]), join="outer"
+         )
+         X2 = get_data(shell, layer)
+         return X2[hvgs]
+
+ def get_uns(adata, key):
+     data = None
+
+     if key in adata.uns:
+         data = adata.uns[key]
+         columns = adata.uns[f'{key}_columns']
+         if sp.sparse.issparse(data):
+             data = data.toarray()
+         data = pd.DataFrame(data.astype('float32'),
+                             columns=columns)
+
+     return data
+
+ def split_by(adata, by: str, copy: bool = False):
+     adata_list = []
+     for id in adata.obs[by].unique():
+         if copy:
+             adata_list.append(adata[adata.obs[by].isin([id])].copy())
+         else:
+             adata_list.append(adata[adata.obs[by].isin([id])])
+     return adata_list
+
+ def split_batch_by_bk(adata, by, batch_size=30000, copy=False):
+     df = adata.obs[by].value_counts().reset_index()
+     df.columns = [by, 'Value']
+
+     n = int(np.round(adata.shape[0] / batch_size))
+     n = n if n > 0 else 1
+
+     adata_list = []
+     parts = find_partitions_greedy(df['Value'].tolist(), n)
+     for _, by_ids in parts:
+         ids = df.iloc[by_ids, :][by].tolist()
+         if copy:
+             adata_list.append(adata[adata.obs[by].isin(ids)].copy())
+         else:
+             adata_list.append(adata[adata.obs[by].isin(ids)])
+
+     return adata_list
+
+ def split_batch_by(adata,
+                    by: str,
+                    batch_size: int = 30000,
+                    copy: bool = False):
+     groups = adata.obs[by].unique()
+     adata_list = []
+     for grp in groups:
+         adata_list.extend(split_batch(adata[adata.obs[by] == grp], batch_size=batch_size, copy=copy))
+
+     return adata_list
+
+ def split_batch(adata, batch_size: int = 30000, copy: bool = False):
+     n = int(np.round(adata.shape[0] / batch_size))
+     n = n if n > 0 else 1
+
+     cells = adata.obs_names.tolist()
+     chunks = np.array_split(cells, n)
+
+     adata_list = []
+     for chunk in chunks:
+         chunk = list(chunk)
+
+         if copy:
+             adata_list.append(adata[chunk].copy())
+         else:
+             adata_list.append(adata[chunk])
+
+     return adata_list
+
+
+ def generate_equal_list(total, n):
+     base_value = total // n
+     remainder = total - (base_value * n)
+     result = [base_value] * (n - 1) + [base_value + remainder]
+     return result
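
For orientation, a minimal usage sketch of the `assembly` entry point defined above. The AnnData files and the `donor` batch column are hypothetical; the argument names, defaults, and return values come from the signature in this diff, and the subpackage path assumes the `__init__.py` hunk shown earlier:

    import scanpy as sc
    from SURE.assembly import assembly

    # hypothetical inputs: AnnData objects with a 'counts' layer and a
    # 'donor' column in .obs recording the undesired batch factor
    adata_a = sc.read_h5ad('study_a.h5ad')
    adata_b = sc.read_h5ad('study_b.h5ad')

    # fit one SURE model per adata (in parallel), sketch each into metacells,
    # then assemble a joint codebook over the common highly variable genes
    model, per_adata_models, hvgs = assembly(
        [adata_a, adata_b],
        batch_key='donor',
        n_top_genes=5000,
        sketch_func='sum',
        n_workers=2,
    )

Note that `assembly` shells out to the `SURE` command-line tool via `subprocess`, so the CLI must be on PATH and a CUDA device visible under the given `cuda_id`.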