fastccc-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastccc/__init__.py ADDED
@@ -0,0 +1,8 @@
+ import sys
+ from loguru import logger
+
+ logger.remove()
+ logger.add(sys.stdout, level="INFO", format='<cyan>{time:YYYY-MM-DD HH:mm:ss}</cyan> | <level>{level: <8}</level> | <level>{message}</level>')
+
+ from .core import Cauchy_combination_of_statistical_analysis_methods
+ from .core import statistical_analysis_method
@@ -0,0 +1,455 @@
+ import scanpy as sc
+ import numpy as np
+ import pandas as pd
+ from scipy.sparse import issparse
+ from tqdm import tqdm
+ from loguru import logger
+ from .preprocess import get_interactions
+ from . import preproc_utils
+ from .core import calculate_cluster_percents, analyze_interactions_percents
+ from .distrib_digit import Distribution_digit, get_pmf_array_from_samples_for_digitized_bins, get_minimum_distribution_for_digit
+ from . import dist_complex
+ from . import dist_lr
+ from . import score
+ import itertools
+ from scipy.signal import fftconvolve
+ from collections import Counter
+ import os
+ import pickle
+
+ def digitize_transform(x, n_bins=50):
+     def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray:
+         assert x.ndim == 1 and bins.ndim == 1
+         left_digits = np.digitize(x, bins)
+         if side == "one":
+             return left_digits
+         right_digits = np.digitize(x, bins, right=True)
+         rands = np.random.rand(len(x))  # uniform random numbers used to break ties
+         digits = rands * (right_digits - left_digits) + left_digits
+         digits = np.ceil(digits).astype(np.int64)
+         return digits
+
+     # non_zero_ids = x.nonzero()
+     # non_zero_row = x[non_zero_ids]
+     '''
+     The input x holds only the non-zero values, taken directly from csr/coo data.
+     '''
+     bins = np.quantile(x, np.linspace(0, 1, n_bins - 1))
+     non_zero_digits = _digitize(x, bins)
+     return non_zero_digits
+
+
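A minimal sketch (toy values, not from the package) of what the quantile digitization above computes; `digitize_transform` maps non-zero expression values to integer bins, breaking ties at random:

    import numpy as np

    x = np.array([0.5, 1.2, 1.2, 3.7, 8.0])       # hypothetical non-zero expression values
    bins = np.quantile(x, np.linspace(0, 1, 49))  # n_bins - 1 = 49 quantile edges
    left = np.digitize(x, bins)                   # lowest bin consistent with each value
    right = np.digitize(x, bins, right=True)      # highest bin; ties span [left, right]
    # A uniform draw in between (then np.ceil) assigns tied values, such as the
    # two 1.2 entries, to a random bin inside their shared quantile range.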
+ def calculate_L_R_and_IS_score(mean_counts, interactions):
+     p1_index = []
+     p2_index = []
+     all_index = []
+     for i in itertools.product(sorted(mean_counts.index), sorted(mean_counts.index)):
+         p1_index.append(i[0])
+         p2_index.append(i[1])
+         all_index.append('|'.join(i))
+     p1 = mean_counts.loc[p1_index, interactions['multidata_1_id']]
+     p2 = mean_counts.loc[p2_index, interactions['multidata_2_id']]
+     p1.columns = interactions.index
+     p2.columns = interactions.index
+     p1.index = all_index
+     p2.index = all_index
+     interactions_strength = (p1 + p2) / 2 * (p1 > 0) * (p2 > 0)
+     return p1, p2, interactions_strength
+
+ def calculate_L_R_and_IS_percents(cluster_percents, interactions, threshold=0.1, sep='|'):
+     p1_index = []
+     p2_index = []
+     all_index = []
+     for i in itertools.product(sorted(cluster_percents.index), sorted(cluster_percents.index)):
+         p1_index.append(i[0])
+         p2_index.append(i[1])
+         all_index.append(sep.join(i))
+
+     p1 = cluster_percents.loc[p1_index, interactions['multidata_1_id']]
+     p2 = cluster_percents.loc[p2_index, interactions['multidata_2_id']]
+     p1.columns = interactions.index
+     p2.columns = interactions.index
+     p1.index = all_index
+     p2.index = all_index
+
+     interactions_strength = (p1 > threshold) * (p2 > threshold)
+     return p1, p2, interactions_strength
+
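A toy sketch (invented frame) of the sender|receiver index construction shared by the two helpers above:

    import itertools
    import pandas as pd

    # Hypothetical cluster means for two cell types and three multidata ids.
    mean_counts = pd.DataFrame([[1.0, 0.0, 2.0], [0.5, 3.0, 0.0]],
                               index=['B', 'T'], columns=[10, 11, 12])
    pairs = list(itertools.product(sorted(mean_counts.index), repeat=2))
    all_index = ['|'.join(p) for p in pairs]  # ['B|B', 'B|T', 'T|B', 'T|T']
    p1 = mean_counts.loc[[a for a, b in pairs], [10, 11]]  # sender rows, ligand columns
    p2 = mean_counts.loc[[b for a, b in pairs], [11, 12]]  # receiver rows, receptor columns
    # After realigning indices, (p1.values + p2.values) / 2 scores every
    # cell-type pair against every interaction in one vectorized step.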
+ def calculate_mean_pmfs(counts_df, labels_df, complex_table, gene_pmf_dict, n_fft=100):
+     meta_dict = Counter(labels_df.cell_type)
+     ####### clusters_mean #######
+     clusters_mean_dict = {}
+     for celltype in sorted(meta_dict):
+         clusters_mean_dict[celltype] = {}
+         n_sum = meta_dict[celltype]
+         for gene in counts_df.columns:
+             if gene not in gene_pmf_dict:
+                 continue
+             if n_sum < n_fft:
+                 # small clusters: use the precomputed n_sum-cell distribution directly
+                 clusters_mean_dict[celltype][gene] = gene_pmf_dict[gene][n_sum]
+             else:
+                 # larger clusters: derive it from the single-cell distribution via the
+                 # operator overloads on Distribution_digit (n_sum-fold power, rescaled)
+                 clusters_mean_dict[celltype][gene] = gene_pmf_dict[gene][1] ** n_sum / n_sum
+     mean_pmfs = pd.DataFrame(clusters_mean_dict).T
+     complex_func = get_minimum_distribution_for_digit
+     mean_pmfs = dist_complex.combine_complex_distribution_df(mean_pmfs, complex_table, complex_func)
+     return mean_pmfs
+
+
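The null model above uses the fact that the PMF of a sum of n i.i.d. draws is the n-fold convolution of the single-draw PMF. A self-contained sketch with a made-up PMF:

    import numpy as np
    from scipy.signal import fftconvolve

    pmf1 = np.array([0.5, 0.3, 0.2])  # hypothetical PMF of one cell's value on {0, 1, 2}
    pmf2 = fftconvolve(pmf1, pmf1)    # sum of two i.i.d. cells, supported on {0, ..., 4}
    pmf3 = fftconvolve(pmf2, pmf1)    # sum of three cells, supported on {0, ..., 6}
    assert np.isclose(pmf3.sum(), 1.0)
    # Iterating this convolution n_fft - 1 times is how gene_sum_pmf_dict[gene][n]
    # is populated in fastccc_for_reference below.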
+ def rank_preprocess(adata):
+     np.random.seed(42)  # fixed seed to keep the random tie-breaking reproducible
+     assert issparse(adata.X), "Anndata.X should be a sparse matrix format."
+     if adata.shape[1] < 5000:
+         logger.warning("Are you using the whole transcriptome? Raw data without gene filtering should work better.")
+
+     # Note: the indptr-based slicing below assumes adata.X is in CSR format.
+     for i in tqdm(range(adata.shape[0]), desc="Ranking genes for cells", unit="cell",
+                   bar_format="{l_bar}{bar} | {n_fmt}/{total_fmt} cells completed", leave=False):
+         indices = slice(adata.X.indptr[i], adata.X.indptr[i+1])
+         x = adata.X.data[indices]
+         adata.X.data[indices] = digitize_transform(x)
+     logger.success("Rank preprocess done.")
+     return adata
+
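A small sketch (toy matrix, not package data) of the CSR indptr slicing idiom used in the loop above:

    import numpy as np
    from scipy.sparse import csr_matrix

    X = csr_matrix(np.array([[0, 2, 0], [1, 0, 3]]))
    for i in range(X.shape[0]):
        row = slice(X.indptr[i], X.indptr[i + 1])  # stored (non-zero) entries of row i
        print(i, X.data[row])                      # row 0 -> [2], row 1 -> [1 3]
    # Assigning back through X.data[row] updates the matrix in place, which lets
    # rank_preprocess digitize each cell without densifying the matrix.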
+ def get_fastccc_input(adata, lrdb_file_path, convert_type='hgnc_symbol'):
+     logger.info("Loading LRIs database. hgnc_symbol as gene name is requested.")
+     interactions = get_interactions(lrdb_file_path)
+     ##### gene_table ########
+     gene_table = pd.read_csv(os.path.join(lrdb_file_path, 'gene_table.csv'))
+     protein_table = pd.read_csv(os.path.join(lrdb_file_path, 'protein_table.csv'))
+     gene_table = gene_table.merge(protein_table, left_on='protein_id', right_on='id_protein')
+     #########################
+
+     ##### complex_table ######
+     complex_composition = pd.read_csv(os.path.join(lrdb_file_path, 'complex_composition_table.csv'))
+     complex_table = pd.read_csv(os.path.join(lrdb_file_path, 'complex_table.csv'))
+     complex_table = complex_table.merge(complex_composition, left_on='complex_multidata_id', right_on='complex_multidata_id')
+
+     # keep only the two columns we need
+     complex_table = complex_table[['complex_multidata_id', 'protein_multidata_id']]
+     '''
+     complex_table (pandas.DataFrame):
+     =======================================================
+          | complex_multidata_id | protein_multidata_id
+     -------------------------------------------------------
+     0    | 1355                 | 1134
+     1    | 1356                 | 1175
+     2    | 1357                 | 1167
+     =======================================================
+     '''
+     ##########################
+
+     ##### feature to id conversion ######
+     # genes that are not in the database gene list are dropped
+     tmp = gene_table[[convert_type, 'protein_multidata_id']]
+     tmp = tmp.drop_duplicates()
+     tmp.set_index('protein_multidata_id', inplace=True)
+
+     select_columns = []
+     columns_names = []
+     for foo, boo in zip(tmp.index, tmp[convert_type]):  # foo: protein_multidata_id, boo: gene symbol
+         if boo in adata.var_names:
+             select_columns.append(boo)
+             columns_names.append(foo)
+
+     reduced_counts = adata[:, select_columns].to_df()
+     reduced_counts.columns = columns_names
+     # average duplicated columns; the transposed form avoids the deprecated
+     # DataFrame.groupby(..., axis=1) (FutureWarning in recent pandas)
+     reduced_counts = reduced_counts.T.groupby(reduced_counts.columns).mean().T
+     ######################################
+
+     ########## filter genes ############
+     # drop genes that are zero across all cells
+     reduced_counts = preproc_utils.filter_empty_genes(reduced_counts)
+     ######################################
+
+     ######################################################################
+     #                       3. Other DF filtered                         #
+     ######################################################################
+     # An interaction may have part A present while part B is missing.
+     # If either part is missing, the other part need not enter any
+     # downstream computation, so both are dropped.
+
+     ##### delete items not involved in interactions ####
+
+     foo_dict = complex_table.groupby('complex_multidata_id').apply(lambda x: list(x['protein_multidata_id'].values), include_groups=False).to_dict()
+     '''
+     dictionary complex_id: [protein_id_1, pid2, pid3, ...]
+     foo_dict = {
+         1355: [1134],
+         1356: [1175],
+         xxxx: [AAAA, BBBB, CCCC],
+     }
+     '''
+
+     def __content__(key):
+         if key not in foo_dict:
+             return [key]
+         else:
+             return foo_dict[key]
+
+     def __exist__(key, df):
+         # the current complex policy requires every subunit to be present;
+         # testing shows this matches the strategy used by CellPhoneDB
+         for item in __content__(key):
+             if item not in df.columns:
+                 return False
+         return True
+
+     temp_list = []
+     temp_dict = {}
+     for item in reduced_counts.columns:
+         temp_dict[item] = False
+
+     # the commented print below was used to verify that this filtering
+     # strategy matches the original interactions table exactly
+     # print(interactions)
+     select_index = []
+     for partA, partB in zip(interactions.multidata_1_id, interactions.multidata_2_id):
+         if __exist__(partA, reduced_counts) and __exist__(partB, reduced_counts):
+             temp_list.extend([partA, partB])
+             select_index.append(True)
+         else:
+             select_index.append(False)
+     interactions_filtered = interactions[select_index]
+
+     for item in temp_list:
+         for subitem in __content__(item):
+             if subitem in temp_dict:
+                 temp_dict[subitem] = True
+     select_index = [key for key in temp_dict if temp_dict[key]]
+     reduced_counts = reduced_counts[select_index]
+
+     counts_df = reduced_counts
+     temp_list = set(temp_list)
+     select_index = [item in temp_list for item in complex_table.complex_multidata_id]
+     complex_table = complex_table[select_index]
+     interactions = interactions_filtered
+     logger.success("Requested data for fastccc is prepared.")
+     return counts_df, complex_table, interactions
+
+
+ def fastccc_for_reference(reference_name, save_path, counts_df, labels_df, complex_table, interactions, min_percentile=0.1, ref_debug_mode=False, query_debug_mode=False, for_uploading=False):
+     logger.info("Running FastCCC.")
+     mean_counts = score.calculate_cluster_mean(counts_df, labels_df)
+     complex_func = score.calculate_complex_min_func
+     mean_counts = score.combine_complex_distribution_df(mean_counts, complex_table, complex_func)
+     percents = calculate_cluster_percents(counts_df, labels_df, complex_table)
+
+     n_bins = 50
+     precision_digit = 0.01
+     pmf_bins_digit = np.arange(0, n_bins + precision_digit - 1e-10, precision_digit)
+
+     #######
+     logger.info("Calculating null distributions.")
+     n_fft = 100
+     gene_sum_pmf_dict = {}
+     basic_info_dict = {}
+     for gene in counts_df.columns:
+         samples = counts_df[gene].values
+
+         loc = np.mean(samples)
+         scale = np.std(samples)
+         basic_info_dict[gene] = {'loc': loc, 'scale': scale}
+
+         gene_sum_pmf_dict[gene] = {1: get_pmf_array_from_samples_for_digitized_bins(samples)}
+         basic_info_dict[gene]['expr_dist'] = gene_sum_pmf_dict[gene][1]
+
+         # PMF of the sum of `item` i.i.d. cells = repeated self-convolution
+         for item in range(2, n_fft):
+             gene_sum_pmf_dict[gene][item] = fftconvolve(gene_sum_pmf_dict[gene][item-1], gene_sum_pmf_dict[gene][1])
+
+     gene_pmf_dict = {}
+     for gene in counts_df.columns:
+         gene_pmf_dict[gene] = {}
+         loc = basic_info_dict[gene]['loc']
+         scale = basic_info_dict[gene]['scale']
+         for item in range(1, n_fft):
+             pmf = gene_sum_pmf_dict[gene][item]
+             cdf = np.cumsum(pmf)
+             # rescale the sum PMF onto the grid of the mean (sum / item)
+             pmf_array = np.diff(cdf[np.int64(pmf_bins_digit * item)], prepend=0)
+             if item == 1:
+                 gene_pmf_dict[gene][item] = Distribution_digit('other', pmf_array=pmf_array, loc=loc, scale=scale, is_align=True)
+             else:
+                 gene_pmf_dict[gene][item] = Distribution_digit('other', pmf_array=pmf_array, is_align=True)
+
+     if ref_debug_mode or query_debug_mode:
+         # these steps are identical to calculate_L_R_and_IS_score and
+         # calculate_mean_pmfs above, so the helpers are reused here
+         p1, p2, interactions_strength = calculate_L_R_and_IS_score(mean_counts, interactions)
+         L_perc, R_perc, percents_analysis = calculate_L_R_and_IS_percents(percents, interactions, threshold=min_percentile)
+         mean_pmfs = calculate_mean_pmfs(counts_df, labels_df, complex_table, gene_pmf_dict, n_fft)
+
+         logger.info("Calculating sig. LRIs.")
+         pvals = dist_lr.calculate_key_interactions_pvalue(
+             mean_pmfs, interactions, interactions_strength, percents_analysis, method='Arithmetic'
+         )
+
+         if query_debug_mode:
+             pvals.to_csv(f'{save_path}/debug_pvals.txt', sep='\t')
+             return
+
+         if ref_debug_mode:
+             pvals.to_csv(f'{save_path}/ref_pvals.txt', sep='\t')
+             percents_analysis.to_csv(f'{save_path}/ref_percents_analysis.txt', sep='\t')
+             L_perc.to_csv(f'{save_path}/ref_percents_L.txt', sep='\t')
+             R_perc.to_csv(f'{save_path}/ref_percents_R.txt', sep='\t')
+             p1.to_csv(f'{save_path}/ref_interactions_strength_L.txt', sep='\t')
+             p2.to_csv(f'{save_path}/ref_interactions_strength_R.txt', sep='\t')
+
+     ####### save reference results #######
+     logger.info("Saving reference.")
+     if for_uploading:
+         with open(f'{save_path}/basic_info_dict.pkl', 'wb') as f:
+             pickle.dump(basic_info_dict, f)
+     else:
+         with open(f'{save_path}/ref_gene_pmf_dict.pkl', 'wb') as f:
+             pickle.dump(gene_pmf_dict, f)
+         with open(f'{save_path}/ref_percents.pkl', 'wb') as f:
+             pickle.dump(percents, f)
+         with open(f'{save_path}/ref_mean_counts.pkl', 'wb') as f:
+             pickle.dump(mean_counts, f)
+         with open(f'{save_path}/complex_table.pkl', 'wb') as f:
+             pickle.dump(complex_table, f)
+         with open(f'{save_path}/interactions.pkl', 'wb') as f:
+             pickle.dump(interactions, f)
+
+
+ def record_hk_genes(adata):
+     from .hk_genes import housekeeping_genes
+     select_index = [item for item in adata.var_names if item in housekeeping_genes]
+     hk_adata = adata[:, select_index]
+     mean_hk_rnk = hk_adata.X.mean(axis=0)  # mean digitized rank of each housekeeping gene
+     return mean_hk_rnk, select_index
+
+ def record_adjustment_info(adata, save_path):
+     mean_hk_rnk, gene_index = record_hk_genes(adata)
+     ref_hk = pd.DataFrame(np.array(mean_hk_rnk).flatten(), index=gene_index, columns=['ref_hk'])
+     ref_hk.to_csv(f'{save_path}/ref_hk.txt', sep='\t')
+
+
+ reference_config = {}
+
+ from datetime import date
+
+ def dumps(toml_dict, table=""):
+     document = []
+     for key, value in toml_dict.items():
+         match value:
+             case dict():
+                 table_key = f"{table}.{key}" if table else key
+                 document.append(
+                     f"\n[{table_key}]\n{_dumps_dict(value)}"
+                 )
+             case _:
+                 document.append(f"{key} = {_dumps_value(value)}")
+     return "\n".join(document)
+
+ def _dumps_dict(toml_dict):
+     document = []
+     for key, value in toml_dict.items():
+         key = f'"{key}"'
+         document.append(f"{key} = {_dumps_value(value)}")
+     return "\n".join(document)
+
+ def _dumps_value(value):
+     match value:
+         case bool():  # must precede int(): bool is a subclass of int
+             return "true" if value else "false"
+         case float() | int():
+             return str(value)
+         case str():
+             return f'"{value}"'
+         case date():
+             return value.isoformat()
+         case list():
+             return f"[{', '.join(_dumps_value(v) for v in value)}]"
+         case _:
+             raise TypeError(
+                 f"{type(value).__name__} {value!r} is not supported"
+             )
+
+ def save_config(save_path):
+     logger.info("Saving reference config.")
+     save_content = dumps(reference_config)
+     with open(f'{save_path}/config.toml', 'w') as f:
+         f.write(save_content)
+
+
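A quick sketch (invented values) of what this minimal TOML writer produces:

    config = {
        'reference_name': 'demo_ref',
        'min_percentile': 0.1,
        'celltype': {'B': 120, 'T': 340},  # nested dicts become [table] sections
    }
    print(dumps(config))
    # reference_name = "demo_ref"
    # min_percentile = 0.1
    #
    # [celltype]
    # "B" = 120
    # "T" = 340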
+ def build_reference_workflow(database_file_path, reference_counts_file_path, celltype_file_path, reference_name, save_path, meta_key=None, min_percentile=0.1, debug_mode=False, for_uploading=False):
+     logger.info("Start building CCC reference.")
+
+     reference_config['reference_name'] = reference_name
+     reference_config['min_percentile'] = min_percentile
+     if database_file_path.endswith('/'):
+         reference_config['LRI_database'] = database_file_path[:-1].split('/')[-1]
+     else:
+         reference_config['LRI_database'] = database_file_path.split('/')[-1]
+
+     logger.info(f"Reference_name = {reference_config['reference_name']}")
+     logger.info(f"min_percentile = {reference_config['min_percentile']}")
+     logger.info(f"LRI database = {reference_config['LRI_database']}")
+
+     save_path = os.path.join(save_path, reference_name)
+     if not os.path.exists(save_path):
+         os.makedirs(save_path)
+         logger.success(f"Reference save dir {save_path} is created.")
+     else:
+         logger.warning(f"{save_path} already exists, all files will be overwritten.")
+
+     reference = sc.read_h5ad(reference_counts_file_path)
+     sc.pp.filter_cells(reference, min_genes=50)
+     logger.info(f"Reading reference adata, {reference.shape[0]} cells x {reference.shape[1]} genes.")
+
+     if meta_key is not None:
+         labels_df = pd.DataFrame(reference.obs[meta_key])
+         labels_df.columns = ['cell_type']
+         labels_df.index.name = 'barcode_sample'
+     else:
+         labels_df = pd.read_csv(celltype_file_path, sep='\t', index_col=0)
+         for barcode in reference.obs_names:
+             assert barcode in labels_df.index, "The index of the reference data doesn't match the index of the labels"
+         labels_df = labels_df.loc[reference.obs_names, :]
+
+     ct_counter = Counter(labels_df['cell_type'])
+     reference_config['celltype'] = ct_counter
+
+     reference = rank_preprocess(reference)
+     record_adjustment_info(reference, save_path)
+     counts_df, complex_table, interactions = get_fastccc_input(reference, database_file_path)
+     fastccc_for_reference(reference_name, save_path, counts_df, labels_df, complex_table, interactions, min_percentile, ref_debug_mode=debug_mode, for_uploading=for_uploading)
+     save_config(save_path)
+     logger.success(f"Reference '{reference_name}' is built.")
+
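A hedged usage sketch of the workflow above; every path here is hypothetical, and the import path is assumed since this file's name is not shown in the diff:

    from fastccc import build_reference_workflow  # import path assumed

    build_reference_workflow(
        database_file_path='./LRI_db/',           # folder holding the gene/protein/complex CSV tables
        reference_counts_file_path='./ref.h5ad',  # raw counts with a sparse .X
        celltype_file_path='./celltypes.tsv',     # ignored when meta_key is given
        reference_name='demo_ref',
        save_path='./references/',
        meta_key='cell_type',                     # column of adata.obs holding the labels
    )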
@@ -0,0 +1,55 @@
+ import os, glob
+ import numpy as np
+ import pandas as pd
+ from loguru import logger
+
+ def cauthy_combine(fastCCC_dir, task_id=None):
+     if task_id is None:
+         logger.warning("No task_id is provided, all pvals files will be combined.")
+         pval_paths = glob.glob(fastCCC_dir + os.sep + '*pvals.tsv')
+     else:
+         logger.info(f"Task ID for combining is: {task_id}")
+         pval_paths = glob.glob(fastCCC_dir + os.sep + f'{task_id}*pvals.tsv')
+     logger.info(f"There are {len(pval_paths)} pval files.")
+     joined_path = '\n'.join(pval_paths)
+     logger.debug(f"\n{joined_path}")
+
+     ct_pairs, cpis = None, None
+     comb_dict = dict()
+     for pval_path in pval_paths:
+         comb = os.path.basename(pval_path)
+         comb = comb.replace('pvals.tsv', '')
+         pval_df = pd.read_csv(pval_path, header=0, index_col=0, sep='\t')
+         if ct_pairs is None:
+             ct_pairs = pval_df.index.tolist()
+         if cpis is None:
+             cpis = pval_df.columns.tolist()
+         comb_dict[comb] = pval_df.values
+
+     # stack the per-method p-value matrices into shape (ct_pairs, methods, cpis)
+     pval_mat = []
+     for comb, values in comb_dict.items():
+         pval_mat.append(np.expand_dims(values, axis=1))
+     pval_mat = np.concatenate(pval_mat, axis=1)
+     weight = np.ones(len(comb_dict)) / len(comb_dict)  # equal weights
+
+     # Cauchy combination: T = sum_i w_i * tan(pi * (0.5 - p_i))
+     T = pval_mat.copy()
+     is_one = pval_mat == 1
+     T[is_one] = np.tan(-np.pi / 2)                 # p = 1 maps to the extreme negative tail
+     T[~is_one] = np.tan(np.pi * (0.5 - T[~is_one]))
+     T = weight @ T
+     P = 0.5 - np.arctan(T) / np.pi                 # back-transform the statistic to a p-value
+
+     T_df = pd.DataFrame(T, index=ct_pairs, columns=cpis)
+     P_df = pd.DataFrame(P, index=ct_pairs, columns=cpis)
+
+     if task_id is None:
+         T_df.to_csv(fastCCC_dir + os.sep + 'Cauchy_stats.tsv', sep='\t')
+         P_df.to_csv(fastCCC_dir + os.sep + 'Cauchy_pvals.tsv', sep='\t')
+     else:
+         T_df.to_csv(fastCCC_dir + os.sep + f'{task_id}_Cauchy_stats.tsv', sep='\t')
+         P_df.to_csv(fastCCC_dir + os.sep + f'{task_id}_Cauchy_pvals.tsv', sep='\t')
+
+ # if __name__ == "__main__":
+ #     cauthy_combine(fastCCC_dir)
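A self-contained numeric sketch (made-up p-values) of the Cauchy combination rule implemented above:

    import numpy as np

    p = np.array([0.01, 0.20, 0.50])           # hypothetical p-values from three methods
    w = np.ones_like(p) / p.size               # equal weights, as in cauthy_combine
    T = np.sum(w * np.tan(np.pi * (0.5 - p)))  # combined Cauchy statistic
    p_comb = 0.5 - np.arctan(T) / np.pi        # ~0.029, dominated by the smallest p-value
    print(p_comb)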
fastccc/ccc_utils.py ADDED
@@ -0,0 +1,75 @@
+ import numpy as np
+ import pandas as pd
+ import os
+ import psutil
+
+
+ def create_significant_interactions_df(
+     pvals,
+     LRI_db_path,
+     save=False,
+     save_path='./temp/',
+     save_file='significant_interaction_list',
+     seperator='|'
+ ):
+     from . import preprocess
+     interactions = preprocess.get_interactions(LRI_db_path)
+     id2symbol_dict = (pd.read_csv(f'{LRI_db_path}/gene_table.csv')[['protein_id', 'hgnc_symbol']]
+                       .set_index('protein_id')['hgnc_symbol']).to_dict()
+     ##### complex_table ######
+     complex_composition = pd.read_csv(os.path.join(LRI_db_path, 'complex_composition_table.csv'))
+     complex_table = pd.read_csv(os.path.join(LRI_db_path, 'complex_table.csv'))
+     complex_table = complex_table.merge(complex_composition, left_on='complex_multidata_id', right_on='complex_multidata_id')
+     foo_dict = complex_table.groupby('complex_multidata_id').apply(lambda x: list(x['protein_multidata_id'].values), include_groups=False).to_dict()
+     # map each complex id to the comma-joined symbols of its subunits
+     for key in foo_dict:
+         value = ','.join([id2symbol_dict[item] for item in foo_dict[key]])
+         id2symbol_dict[key] = value
+
+     x, y = np.where(pvals.values < 0.05)
+     lines = []
+     for subx, suby in zip(x, y):
+         celltype_A, celltype_B = pvals.index[subx].split(seperator)
+         interaction_ID = pvals.columns[suby]
+         lines.append((celltype_A, celltype_B, interaction_ID, pvals.iloc[subx, suby]))
+
+     output_df = pd.DataFrame(lines, columns=['sender_celltype', 'receiver_celltype', 'LRI_ID', 'p-value'])
+     output_df = output_df.merge(interactions, left_on='LRI_ID', right_index=True, how='left')
+
+     output_df.multidata_1_id = [id2symbol_dict[ligand_gene_name] for ligand_gene_name in output_df.multidata_1_id.tolist()]
+     output_df.multidata_2_id = [id2symbol_dict[receptor_gene_name] for receptor_gene_name in output_df.multidata_2_id.tolist()]
+
+     output_df = output_df[['sender_celltype', 'receiver_celltype', 'LRI_ID', 'multidata_1_id', 'multidata_2_id', 'p-value']]
+     output_df = output_df.rename(columns={'multidata_1_id': 'ligand', 'multidata_2_id': 'receptor'})
+
+     save_file = os.path.join(save_path, save_file)
+     if save:
+         output_df.to_excel(f'{save_file}.xlsx')
+     return output_df
+
+
+ def create_significant_interactions_with_flag_df(pvals, significant_flag, interactions, save_path='./temp/', save_file='significant_interaction_list'):
+     seperator = '|'
+     x, y = np.where(pvals.values < 0.05)
+     lines = []
+     for subx, suby in zip(x, y):
+         if not significant_flag.iloc[subx, suby]:
+             continue
+         celltype_A, celltype_B = pvals.index[subx].split(seperator)
+         interaction_ID = pvals.columns[suby]
+         lines.append((celltype_A, celltype_B, interaction_ID, pvals.iloc[subx, suby]))
+     output_df = pd.DataFrame(lines, columns=['Ligand_celltype', 'Receptor_celltype', 'Interaction_ID', 'P-val'])
+     output_df = output_df.merge(interactions, left_on='Interaction_ID', right_index=True, how='left')
+     save_file = os.path.join(save_path, save_file)
+     output_df.to_excel(f'{save_file}.xlsx')
+     return output_df
+
+ def get_current_memory():
+     """
+     Get the current memory usage of this process.
+     usage:
+         current_memory = get_current_memory()
+         print("Current memory usage: {:.2f} MB".format(current_memory))
+     """
+     process = psutil.Process()
+     return process.memory_info().rss / (1024 * 1024)  # convert bytes to MB
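A toy sketch (invented frame) of the np.where pattern both helpers above use to flatten a cell-type-pair x interaction p-value matrix into one row per significant hit:

    import numpy as np
    import pandas as pd

    pvals = pd.DataFrame([[0.01, 0.30], [0.70, 0.04]],
                         index=['B|T', 'T|B'], columns=['LRI_1', 'LRI_2'])
    x, y = np.where(pvals.values < 0.05)
    for subx, suby in zip(x, y):
        sender, receiver = pvals.index[subx].split('|')
        print(sender, receiver, pvals.columns[suby], pvals.iloc[subx, suby])
    # B T LRI_1 0.01
    # T B LRI_2 0.04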