py2ls 0.1.10.12__py3-none-any.whl → 0.2.7.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (72)
  1. py2ls/.DS_Store +0 -0
  2. py2ls/.git/.DS_Store +0 -0
  3. py2ls/.git/index +0 -0
  4. py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
  5. py2ls/.git/objects/.DS_Store +0 -0
  6. py2ls/.git/refs/.DS_Store +0 -0
  7. py2ls/ImageLoader.py +621 -0
  8. py2ls/__init__.py +7 -5
  9. py2ls/apptainer2ls.py +3940 -0
  10. py2ls/batman.py +164 -42
  11. py2ls/bio.py +2595 -0
  12. py2ls/cell_image_clf.py +1632 -0
  13. py2ls/container2ls.py +4635 -0
  14. py2ls/corr.py +475 -0
  15. py2ls/data/.DS_Store +0 -0
  16. py2ls/data/email/email_html_template.html +88 -0
  17. py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
  18. py2ls/data/hyper_param_tabrepo_2024.py +1753 -0
  19. py2ls/data/mygenes_fields_241022.txt +355 -0
  20. py2ls/data/re_common_pattern.json +173 -0
  21. py2ls/data/sns_info.json +74 -0
  22. py2ls/data/styles/.DS_Store +0 -0
  23. py2ls/data/styles/example/.DS_Store +0 -0
  24. py2ls/data/styles/stylelib/.DS_Store +0 -0
  25. py2ls/data/styles/stylelib/grid.mplstyle +15 -0
  26. py2ls/data/styles/stylelib/high-contrast.mplstyle +6 -0
  27. py2ls/data/styles/stylelib/high-vis.mplstyle +4 -0
  28. py2ls/data/styles/stylelib/ieee.mplstyle +15 -0
  29. py2ls/data/styles/stylelib/light.mplstyl +6 -0
  30. py2ls/data/styles/stylelib/muted.mplstyle +6 -0
  31. py2ls/data/styles/stylelib/nature-reviews-latex.mplstyle +616 -0
  32. py2ls/data/styles/stylelib/nature-reviews.mplstyle +616 -0
  33. py2ls/data/styles/stylelib/nature.mplstyle +31 -0
  34. py2ls/data/styles/stylelib/no-latex.mplstyle +10 -0
  35. py2ls/data/styles/stylelib/notebook.mplstyle +36 -0
  36. py2ls/data/styles/stylelib/paper.mplstyle +290 -0
  37. py2ls/data/styles/stylelib/paper2.mplstyle +305 -0
  38. py2ls/data/styles/stylelib/retro.mplstyle +4 -0
  39. py2ls/data/styles/stylelib/sans.mplstyle +10 -0
  40. py2ls/data/styles/stylelib/scatter.mplstyle +7 -0
  41. py2ls/data/styles/stylelib/science.mplstyle +48 -0
  42. py2ls/data/styles/stylelib/std-colors.mplstyle +4 -0
  43. py2ls/data/styles/stylelib/vibrant.mplstyle +6 -0
  44. py2ls/data/tiles.csv +146 -0
  45. py2ls/data/usages_pd.json +1417 -0
  46. py2ls/data/usages_sns.json +31 -0
  47. py2ls/docker2ls.py +5446 -0
  48. py2ls/ec2ls.py +61 -0
  49. py2ls/fetch_update.py +145 -0
  50. py2ls/ich2ls.py +1955 -296
  51. py2ls/im2.py +8242 -0
  52. py2ls/image_ml2ls.py +2100 -0
  53. py2ls/ips.py +33909 -3418
  54. py2ls/ml2ls.py +7700 -0
  55. py2ls/mol.py +289 -0
  56. py2ls/mount2ls.py +1307 -0
  57. py2ls/netfinder.py +873 -351
  58. py2ls/nl2ls.py +283 -0
  59. py2ls/ocr.py +1581 -458
  60. py2ls/plot.py +10394 -314
  61. py2ls/rna2ls.py +311 -0
  62. py2ls/ssh2ls.md +456 -0
  63. py2ls/ssh2ls.py +5933 -0
  64. py2ls/ssh2ls_v01.py +2204 -0
  65. py2ls/stats.py +66 -172
  66. py2ls/temp20251124.py +509 -0
  67. py2ls/translator.py +2 -0
  68. py2ls/utils/decorators.py +3564 -0
  69. py2ls/utils_bio.py +3453 -0
  70. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/METADATA +113 -224
  71. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/RECORD +72 -16
  72. {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/WHEEL +0 -0
py2ls/bio.py ADDED
@@ -0,0 +1,2595 @@
1
+ # #======== 1. GEO data Processing Pipeline======
2
+ # # Load and integrate multiple datasets
3
+ # geo_data = load_geo(datasets, dir_save)
4
+ # complete_data = get_data(geo_data, dataset)
5
+
6
+ # # Quality control and normalization
7
+ # data_type = get_data_type(complete_data)
8
+ # if data_type == "read counts":
9
+ # normalized_data = counts2expression(complete_data, method='TMM')
10
+
11
+ # # Batch correction for multiple datasets
12
+ # corrected_data = batch_effect([data1, data2, data3], datasets)
13
+
14
+ # #======== 2. Differential Expression + Enrichment Pipeline======
15
+ # # DESeq2 analysis
16
+ # dds, diff_results, stats, norm_counts = counts_deseq(counts, metadata)
17
+
18
+ # # Enrichment analysis on significant genes
19
+ # sig_genes = diff_results[diff_results.padj < 0.05].gene.tolist()
20
+ # enrichment_results = get_enrichr(sig_genes, 'KEGG_2021_Human')
21
+
22
+ # # Visualization
23
+ # plot_enrichr(enrichment_results, kind='dotplot')
24
+
25
+ # #======== 3. Network Analysis Pipeline======
26
+ # # PPI network construction
27
+ # interactions = get_ppi(target_genes, species=9606, ci=0.7)
28
+
29
+ # # Network visualization and analysis
30
+ # G, ax = plot_ppi(interactions, layout='degree')
31
+ # key_proteins = top_ppi(interactions, n_top=10)
32
+
33
+ # #======== Dependencies ======
34
+ # GEOparse: GEO data access
35
+ # gseapy: Enrichment analysis
36
+ # pydeseq2: Differential expression
37
+ # rnanorm: Count normalization
38
+ # mygene: Gene identifier conversion
39
+ # networkx: Network analysis
40
+ # pyvis: Interactive network visualization
41
+
42
+ # This toolbox provides end-to-end capabilities for genomics data analysis from raw
43
+ # data loading through advanced network biology, with particular strengths in multi-
44
+ # dataset integration and interactive visualization.
45
+
46
+ import GEOparse
47
+ import gseapy as gp
48
+ from typing import Union
49
+ import pandas as pd
50
+ import numpy as np
51
+ import os
52
+ import logging
53
+
54
+ from sympy import use
55
+ from . import ips
56
+ from . import plot
57
+ import matplotlib.pyplot as plt
+ from IPython.display import display  # `display` is used throughout for DataFrame previews
58
+
59
+ def load_geo(
60
+ datasets: Union[list, str] = ["GSE00000", "GSE00001"],
61
+ dir_save: str = "./datasets",
62
+ verbose=False,
63
+ ) -> dict:
64
+ """
65
+ Purpose: Downloads and loads GEO datasets from NCBI database
66
+ Principle: Uses GEOparse library to fetch and parse GEO SOFT files. Checks local cache first to avoid redundant downloads.
67
+ Key Operations:
68
+ * Verifies if datasets exist locally in specified directory
69
+ * Downloads missing datasets using GEOparse API
70
+ * Returns dictionary of GEO objects for further processing
71
+
72
+ Parameters:
73
+ datasets (list): List of GEO dataset IDs to download.
74
+ dir_save (str): Directory where datasets will be stored.
75
+
76
+ Returns:
77
+ dict: A dictionary containing the GEO objects for each dataset.
78
+ """
79
+ use_str = """
80
+ get_meta(geo: dict, dataset: str = "GSE25097")
81
+ get_expression_data(geo: dict, dataset: str = "GSE25097")
82
+ get_probe(geo: dict, dataset: str = "GSE25097", platform_id: str = "GPL10687")
83
+ get_data(geo: dict, dataset: str = "GSE25097")
84
+ """
85
+ print(f"You could do further analysis with: \n{use_str}")
86
+ if not verbose:
87
+ logging.getLogger("GEOparse").setLevel(logging.WARNING)
88
+ else:
89
+ logging.getLogger("GEOparse").setLevel(logging.DEBUG)
90
+ # Create the directory if it doesn't exist
91
+ if not os.path.exists(dir_save):
92
+ os.makedirs(dir_save)
93
+ print(f"Created directory: {dir_save}")
94
+ if isinstance(datasets, str):
95
+ datasets = [datasets]
96
+ geo_data = {}
97
+ for dataset in datasets:
98
+ # Check if the dataset file already exists in the directory
99
+ dataset_file = os.path.join(dir_save, f"{dataset}_family.soft.gz")
100
+
101
+ if not os.path.isfile(dataset_file):
102
+ print(f"\n\nDataset {dataset} not found locally. Downloading...")
103
+ geo = GEOparse.get_GEO(geo=dataset, destdir=dir_save)
104
+ else:
105
+ print(f"\n\nDataset {dataset} already exists locally. Loading...")
106
+ geo = GEOparse.get_GEO(filepath=dataset_file)
107
+
108
+ geo_data[dataset] = geo
109
+
110
+ return geo_data
111
+
112
+
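A minimal usage sketch (an editor's illustration, not part of the released file; the accession and cache directory are the defaults used in the docstrings below):

geo = load_geo(datasets="GSE25097", dir_save="./datasets", verbose=False)
df_meta = get_meta(geo, dataset="GSE25097")     # sample-level metadata
df_probe = get_probe(geo, dataset="GSE25097")   # probe-to-gene annotation
df_all = get_data(geo, dataset="GSE25097")      # merged, analysis-ready table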
113
+ def get_meta(geo: dict, dataset: str = "GSE25097", verbose=True) -> pd.DataFrame:
114
+ """
115
+ Purpose: Extracts comprehensive metadata from GEO datasets
116
+ Principle: Parses hierarchical structure of GEO objects (study, platform, sample metadata) and flattens into DataFrame
117
+ Key Operations:
118
+ Combines study-level, platform-level, and sample-level metadata
119
+ Handles list-type metadata values by concatenation
120
+ Removes irrelevant columns (contact info, technical details)
121
+ Output: DataFrame with samples as rows and all available metadata as columns
122
+
123
+ df_meta = get_meta(geo, dataset="GSE25097")
124
+ Extracts metadata from a specific GEO dataset and returns it as a DataFrame.
125
+ The function dynamically extracts all available metadata fields from the given dataset.
126
+
127
+ Parameters:
128
+ geo (dict): A dictionary containing the GEO objects for different datasets.
129
+ dataset (str): The name of the dataset to extract metadata from (default is "GSE25097").
130
+
131
+ Returns:
132
+ pd.DataFrame: A DataFrame containing structured metadata from the specified GEO dataset.
133
+ """
134
+ # Check if the dataset is available in the provided GEO dictionary
135
+ if dataset not in geo:
136
+ raise ValueError(f"Dataset '{dataset}' not found in the provided GEO data.")
137
+
138
+ # List to store metadata dictionaries
139
+ meta_list = []
140
+
141
+ # Extract the GEO object for the specified dataset
142
+ geo_obj = geo[dataset]
143
+
144
+ # Overall Study Metadata
145
+ study_meta = geo_obj.metadata
146
+ study_metadata = {key: study_meta[key] for key in study_meta.keys()}
147
+
148
+ # Platform Metadata
149
+ for platform_id, platform in geo_obj.gpls.items():
150
+ platform_metadata = {
151
+ key: platform.metadata[key] for key in platform.metadata.keys()
152
+ }
153
+ platform_metadata["platform_id"] = platform_id # Include platform ID
154
+
155
+ # Sample Metadata
156
+ for sample_id, sample in geo_obj.gsms.items():
157
+ sample_metadata = {
158
+ key: sample.metadata[key] for key in sample.metadata.keys()
159
+ }
160
+ sample_metadata["sample_id"] = sample_id # Include sample ID
161
+ # Combine all metadata into a single dictionary
162
+ combined_meta = {
163
+ "dataset": dataset,
164
+ **{
165
+ k: (
166
+ v[0]
167
+ if isinstance(v, list) and len(v) == 1
168
+ else ", ".join(map(str, v))
169
+ )
170
+ for k, v in study_metadata.items()
171
+ }, # Flatten study metadata
172
+ **platform_metadata, # Unpack platform metadata
173
+ **{
174
+ k: (
175
+ v[0]
176
+ if isinstance(v, list) and len(v) == 1
177
+ else "".join(map(str, v))
178
+ )
179
+ for k, v in sample_metadata.items()
180
+ }, # Flatten sample metadata
181
+ }
182
+
183
+ # Append the combined metadata to the list
184
+ meta_list.append(combined_meta)
185
+
186
+ # Convert the list of dictionaries to a DataFrame
187
+ meta_df = pd.DataFrame(meta_list)
188
+ col_rm = [
189
+ "channel_count",
190
+ "contact_web_link",
191
+ "contact_address",
192
+ "contact_city",
193
+ "contact_country",
194
+ "contact_department",
195
+ "contact_email",
196
+ "contact_institute",
197
+ "contact_laboratory",
198
+ "contact_name",
199
+ "contact_phone",
200
+ "contact_state",
201
+ "contact_zip/postal_code",
202
+ "contributor",
203
+ "manufacture_protocol",
204
+ "taxid",
205
+ "web_link",
206
+ ]
207
+ # remove irrelevant columns
208
+ meta_df = meta_df.drop(columns=[col for col in col_rm if col in meta_df.columns])
209
+ if verbose:
210
+ print(
211
+ f"Meta info columns for dataset '{dataset}': \n{sorted(meta_df.columns.tolist())}"
212
+ )
213
+ display(meta_df[:1].T)
214
+ return meta_df
215
+
216
+
217
+ def get_probe(
218
+ geo: dict, dataset: str = "GSE25097", platform_id: str = None, verbose=True
219
+ ):
220
+ """
221
+ Purpose: Retrieves probe annotation information from GEO platforms
222
+ Principle: Accesses platform annotation tables containing gene symbols, IDs, and probe information
223
+ Key Operations:
224
+ Automatically detects platform IDs from metadata
225
+ Handles multiple platforms within single dataset
226
+ Provides direct links to NCBI platform pages for manual verification
227
+
228
+ df_probe = get_probe(geo, dataset="GSE25097", platform_id: str = "GPL10687")
229
+ """
230
+ # try to find the platform_id from meta
231
+ if platform_id is None:
232
+ df_meta = get_meta(geo=geo, dataset=dataset, verbose=False)
233
+ platform_id = df_meta["platform_id"].unique().tolist()
234
+ print(f"Platform: {platform_id}")
235
+ if len(platform_id) > 1:
236
+ df_probe= geo[dataset].gpls[platform_id[0]].table
237
+ # df_probe=pd.DataFrame()
238
+ # # Iterate over each platform ID and collect the probe tables
239
+ # for platform_id_ in platform_id:
240
+ # if platform_id_ in geo[dataset].gpls:
241
+ # df_probe_ = geo[dataset].gpls[platform_id_].table
242
+ # if not df_probe_.empty:
243
+ # df_probe=pd.concat([df_probe, df_probe_])
244
+ # else:
245
+ # print(f"Warning: Probe table for platform {platform_id_} is empty.")
246
+ # else:
247
+ # print(f"Warning: Platform ID {platform_id_} not found in dataset {dataset}.")
248
+ else:
249
+ df_probe= geo[dataset].gpls[platform_id[0]].table
250
+
251
+ if df_probe.empty:
252
+ print(
253
+ "Warning: cannot find the probe info. Check whether the probe info is provided in a separate file."
254
+ )
255
+ display(f"🔗: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc={platform_id}")
256
+ return get_meta(geo, dataset, verbose=verbose)
257
+ if verbose:
258
+ print(f"columns in the probe table: \n{sorted(df_probe.columns.tolist())}")
259
+ return df_probe
260
+
261
+
262
+ def get_expression_data(geo: dict, dataset: str = "GSE25097") -> pd.DataFrame:
263
+ """
264
+ Purpose: Extracts expression matrix from GEO datasets
265
+ Principle: Pivots sample tables to create gene expression matrix
266
+ Key Operations:
267
+ Handles both pre-pivoted data and individual sample tables
268
+ Maintains sample IDs as columns/rows appropriately
269
+ Output: DataFrame with expression values
270
+
271
+ df_expression = get_expression_data(geo,dataset="GSE25097")
272
+ Only contains the expression values; probe annotations and other metadata are not taken into account.
273
+
274
+ Extracts expression values from GEO datasets and returns it as a DataFrame.
275
+
276
+ Parameters:
277
+ geo (dict): A dictionary containing GEO objects for each dataset.
278
+
279
+ Returns:
280
+ pd.DataFrame: A DataFrame containing expression data from the GEO datasets.
281
+ """
282
+ expression_dataframes = []
283
+ try:
284
+ expression_values = geo[dataset].pivot_samples("VALUE")
285
+ except:
286
+ for sample_id, sample in geo[dataset].gsms.items():
287
+ if hasattr(sample, "table"):
288
+ expression_values = (
289
+ sample.table.T
290
+ ) # Transpose for easier DataFrame creation
291
+ expression_values["dataset"] = dataset
292
+ expression_values["sample_id"] = sample_id
293
+ return expression_values
294
+
295
+
296
+ def get_data(geo: dict, dataset: str = "GSE25097", verbose=False):
297
+ """
298
+ Purpose: Comprehensive data integration - merges expression data with probe annotations and metadata
299
+ Principle: Performs multi-level data integration using pandas merge operations
300
+ Key Operations:
301
+ • Merges probe annotations with expression data
302
+ • Transposes expression matrix to samples-as-rows format
303
+ • Integrates metadata using sample IDs
304
+ • Automatically detects and normalizes raw counts data
+ Output: Complete dataset ready for analysis
305
+ """
306
+ print(f"\n\ndataset: {dataset}\n")
307
+ # get probe info
308
+ df_probe = get_probe(geo, dataset=dataset, verbose=False)
309
+ # get expression values
310
+ df_expression = get_expression_data(geo, dataset=dataset)
311
+ if not df_expression.select_dtypes(include=["number"]).empty:
312
+ # if all the data are raw counts, normalize them with TMM
313
+ if 'counts' in get_data_type(df_expression):
314
+ try:
315
+ df_expression=counts2expression(df_expression.T).T
316
+ print(f"{dataset} contains raw read counts; normalized (transformed) via 'TMM'")
317
+ except Exception as e:
318
+ print("raw counts data")
319
+ if any([df_probe.empty, df_expression.empty]):
320
+ print(
321
+ "got empty values; check the probe info. It may be provided in a separate file."
322
+ )
323
+ return get_meta(geo, dataset, verbose=True)
324
+ print(
325
+ f"\n\tdf_expression.shape: {df_expression.shape} \n\tdf_probe.shape: {df_probe.shape}"
326
+ )
327
+ df_exp = pd.merge(
328
+ df_probe,
329
+ df_expression,
330
+ left_on=df_probe.columns.tolist()[0],
331
+ right_index=True,
332
+ how="outer",
333
+ )
334
+
335
+ # get meta info
336
+ df_meta = get_meta(geo, dataset=dataset, verbose=False)
337
+ col_rm = [
338
+ "channel_count","contact_web_link","contact_address","contact_city","contact_country","contact_department",
339
+ "contact_email","contact_institute","contact_laboratory","contact_name","contact_phone","contact_state",
340
+ "contact_zip/postal_code","contributor","manufacture_protocol","taxid","web_link",
341
+ ]
342
+ # remove irrelevant columns
343
+ df_meta = df_meta.drop(columns=[col for col in col_rm if col in df_meta.columns])
344
+ # sort columns
345
+ df_meta = df_meta.reindex(sorted(df_meta.columns), axis=1)
346
+ # find a proper column
347
+ col_sample_id = ips.strcmp("sample_id", df_meta.columns.tolist())[0]
348
+ df_meta.set_index(col_sample_id, inplace=True) # set sample_id as index
349
+
350
+ col_gene_symbol = ips.strcmp("GeneSymbol", df_exp.columns.tolist())[0]
351
+ # select the 'GSM' columns
352
+ col_gsm = df_exp.columns[df_exp.columns.str.startswith("GSM")].tolist()
353
+ df_exp.set_index(col_gene_symbol, inplace=True)
354
+ df_exp = df_exp[col_gsm].T # transpose, so that could add meta info
355
+
356
+ df_merged = ips.df_merge(df_meta, df_exp,use_index=True)
357
+
358
+ print(
359
+ f"\ndataset:'{dataset}' n_sample = {df_merged.shape[0]}, n_gene={df_exp.shape[1]}"
360
+ )
361
+ if verbose:
362
+ display(df_merged.sample(5))
363
+ return df_merged
364
+
365
+ def get_data_type(data: pd.DataFrame) -> str:
366
+ """
367
+ Purpose: Automatically determines data type (raw counts vs normalized expression)
368
+ Principle: Analyzes numerical characteristics of expression data
369
+ Key Operations:
370
+ Checks data types (integers vs floats)
371
+ Examines value ranges and distributions
372
+ Uses thresholds to classify as counts (>10,000 max) or normalized (<1,000 max)
373
+
374
+ Determine the type of data: 'read counts' or 'normalized expression data'.
375
+ usage:
376
+ get_data_type(df_counts)
377
+ """
378
+ numeric_data = data.select_dtypes(include=["number"])
379
+ if numeric_data.empty:
380
+ raise ValueError("No numeric data found; please convert the data first.")
381
+ # Check if the data contains only integers
382
+ if numeric_data.apply(lambda x: x.dtype == "int").all():
383
+ # Check for values typically found in raw read counts (large integers)
384
+ if numeric_data.max().max() > 10000: # Threshold for raw counts
385
+ return "read counts"
386
+ # Check if all values are floats
387
+ if numeric_data.apply(lambda x: x.dtype == "float").all():
388
+ # If values are small, it's likely normalized data
389
+ if numeric_data.max().max() < 1000: # Threshold for normalized expression
390
+ return "normalized expression data"
391
+ else:
392
+ print(f"max value: {numeric_data.max().max()}; this is likely raw read-count data, but please double-check it")
393
+ return "read counts"
394
+ # If mixed data types or unexpected values
395
+ return "mixed or unknown"
396
+
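A small sketch of how the thresholds above classify data (the values are invented for illustration):

df_counts = pd.DataFrame({"s1": [120, 45000, 3], "s2": [98, 52000, 7]})
print(get_data_type(df_counts))  # integer data with max > 10,000 -> "read counts"
df_norm = pd.DataFrame({"s1": [5.2, 8.9, 0.3], "s2": [4.8, 9.1, 0.2]})
print(get_data_type(df_norm))    # float data with max < 1,000 -> "normalized expression data"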
397
+ def split_at_lower_upper(lst):
398
+ """
399
+ Split a list into two lists at the point where all-lowercase entries are followed by an uppercase entry or NaN.
400
+ """
401
+ for i in range(len(lst) - 1):
402
+ if isinstance(lst[i], str) and lst[i].islower():
403
+ next_item = lst[i + 1]
404
+ if isinstance(next_item, str) and next_item.isupper():
405
+ # Found the split point: lowercase followed by uppercase
406
+ return lst[: i + 1], lst[i + 1 :]
407
+ elif pd.isna(next_item):
408
+ # NaN case after a lowercase string
409
+ return lst[: i + 1], lst[i + 1 :]
410
+ return lst, []
411
+
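For example (illustrative values; GEO metadata fields are typically lowercase, gene symbols uppercase):

meta_cols, gene_cols = split_at_lower_upper(["title", "source_name", "GAPDH", "TP53"])
# -> (["title", "source_name"], ["GAPDH", "TP53"])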
412
+ def find_condition(data:pd.DataFrame, columns=["characteristics_ch1","title"]):
413
+ if data.shape[1]>=data.shape[0]:
414
+ display(data.iloc[:1,:40].T)
415
+ # inspect in detail which categories each column contains; parts containing numbers are removed
416
+ for col in columns:
417
+ print(f"{'='*10} {col} {'='*10}")
418
+ display(ips.flatten([ips.ssplit(i, by="numer")[0] for i in data[col]],verbose=False))
419
+
420
+ def add_condition(
421
+ data: pd.DataFrame,
422
+ column: str = "characteristics_ch1", # which column to classify by
423
+ column_new: str = "condition", # name of the new column
424
+ by: str = "tissue: tumor liver", # substring that assigns 'by_name'
425
+ by_not: str = ": tumor", # selection condition for the healthy group
426
+ by_name: str = "non-tumor", # label for the healthy group
427
+ by_not_name: str = "tumor", # label for the unhealthy group
428
+ inplace: bool = True, # replace the data
429
+ verbose: bool = True,
430
+ ):
431
+ """
432
+ Purpose: Automated sample grouping based on metadata patterns
433
+ Principle: String matching and pattern extraction from metadata columns
434
+ Usage: Rapid experimental design setup for differential analysis
435
+
436
+ Add a new column to the DataFrame based on the presence of a specific substring in another column.
437
+
438
+ Parameters
439
+ ----------
440
+ data : pd.DataFrame
441
+ The input DataFrame containing the data.
442
+ column : str, optional
443
+ The name of the column in which to search for the substring (default is 'characteristics_ch1').
444
+ column_new : str, optional
445
+ The name of the new column to be created (default is 'condition').
446
+ by : str, optional
447
+ The substring to search for in the specified column (default is 'heal').
448
+
449
+ """
450
+ # first check the content in column
451
+ content = data[column].unique().tolist()
452
+ if verbose:
453
+ if len(content) > 10:
454
+ display(content[:10])
455
+ else:
456
+ display(content)
457
+ # 'by' takes priority
458
+ if by:
459
+ data[column_new] = data[column].apply(
460
+ lambda x: by_name if by in x else by_not_name
461
+ )
462
+ elif by_not:
463
+ data[column_new] = data[column].apply(
464
+ lambda x: by_not_name if by_not not in x else by_name
465
+ )
466
+ if verbose:
467
+ display(data.sample(5))
468
+ if not inplace:
469
+ return data
470
+
471
+
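A usage sketch (the metadata string and labels are placeholders): samples whose `characteristics_ch1` entry contains the `by` substring receive `by_name`; all other samples receive `by_not_name`.

add_condition(
    df_merged,
    column="characteristics_ch1",
    by="tissue: tumor liver",
    by_name="tumor",
    by_not_name="non-tumor",
)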
472
+ def add_condition_multi(
473
+ data: pd.DataFrame,
474
+ column: str = "characteristics_ch1", # Column to classify
475
+ column_new: str = "condition", # New column name
476
+ conditions: dict = {
477
+ "low": "low",
478
+ "high": "high",
479
+ "intermediate": "intermediate",
480
+ }, # A dictionary where keys are substrings and values are condition names
481
+ default_name: str = "unknown", # Default name if no condition matches
482
+ inplace: bool = True, # Whether to replace the data
483
+ verbose: bool = True,
484
+ ):
485
+ """
486
+ Add a new column to the DataFrame based on the presence of specific substrings in another column.
487
+
488
+ Parameters
489
+ ----------
490
+ data : pd.DataFrame
491
+ The input DataFrame containing the data.
492
+ column : str, optional
493
+ The name of the column in which to search for the substrings (default is 'characteristics_ch1').
494
+ column_new : str, optional
495
+ The name of the new column to be created (default is 'condition').
496
+ conditions : dict, optional
497
+ A dictionary where keys are substrings to search for and values are the corresponding labels.
498
+ default_name : str, optional
499
+ The name to assign if no condition matches (default is 'unknown').
500
+ inplace : bool, optional
501
+ Whether to modify the original DataFrame (default is True).
502
+ verbose : bool, optional
503
+ Whether to display the unique values and final DataFrame (default is True).
504
+ """
505
+
506
+ # Display the unique values in the column
507
+ content = data[column].unique().tolist()
508
+ if verbose:
509
+ if len(content) > 10:
510
+ display(content[:10])
511
+ else:
512
+ display(content)
513
+
514
+ # Check if conditions are provided
515
+ if conditions is None:
516
+ raise ValueError(
517
+ "Conditions must be provided as a dictionary with substrings and corresponding labels."
518
+ )
519
+
520
+ # Define a helper function to map the conditions
521
+ def map_condition(value):
522
+ for substring, label in conditions.items():
523
+ if substring in value:
524
+ return label
525
+ return default_name # If no condition matches, return the default name
526
+
527
+ # Apply the mapping function to create the new column
528
+ data[column_new] = data[column].apply(map_condition)
529
+
530
+ # Display the updated DataFrame if verbose is True
531
+ if verbose:
532
+ display(data.sample(5))
533
+
534
+ if not inplace:
535
+ return data
536
+
537
+ def clean_dataset(
538
+ data: pd.DataFrame, dataset: str = None, condition: str = "condition",sep="///"
539
+ ):
540
+ """
541
+ Purpose: Standardizes and cleans integrated datasets for analysis
542
+ Principle: Handles multi-mapping genes and data formatting issues
543
+ Key Operations:
544
+ Extends genes with multiple symbols (e.g., "///" separated)
545
+ Removes duplicates and missing values
546
+ Formats sample names with dataset and condition information
547
+ Sets genes as index for downstream analysis
548
+
549
+ #* it has been involved in bio.batch_effects(), but default: False
550
+ 1. clean data set and prepare super_datasets
551
+ 2. if "///" in index, then extend it, or others.
552
+ 3. drop duplicates and dropna()
553
+ 4. add the 'condition' and 'dataset info' to the columns
554
+ 5. set genes as index
555
+ """
556
+ usage_str="""clean_dataset(data: pd.DataFrame, dataset: str = None, condition: str = "condition",sep="///")
557
+ """
558
+ if dataset is None:
559
+ try:
560
+ dataset=data["dataset"][0]
561
+ except:
562
+ print("cannot find 'dataset' name")
563
+ print(f"example\n {usage_str}")
564
+ #! (4.1) clean data set and prepare super_datasets
565
+ # df_data_2: columns on the left are meta info, columns on the right are gene symbols
566
+ col_gene = split_at_lower_upper(data.columns.tolist())[1][0]
567
+ idx = ips.strcmp(col_gene, data.columns.tolist())[1]
568
+ df_gene = data.iloc[:, idx:].T # keep the last 'condition'
569
+
570
+ #! if "///" in index, then extend it, or others.
571
+ print(f"before extend shape: {df_gene.shape}")
572
+ df = df_gene.reset_index()
573
+ df_gene = ips.df_extend(df, column="index", sep=sep)
574
+ # reset 'index' column as index
575
+ # df_gene = df_gene.set_index("index")
576
+ print(f"after extended by '{sep}' shape: {df_gene.shape}")
577
+
578
+ # *alternative:
579
+ # df_unique = df.reset_index().drop_duplicates(subset="index").set_index("index")
580
+ #! 4.2 drop duplicates and dropna()
581
+ df_gene = df_gene.drop_duplicates(subset=["index"]).dropna()
582
+ print(f"drop duplicates and dropna: shape: {df_gene.shape}")
583
+
584
+ #! add the 'condition' and 'dataset info' to the columns
585
+ ds = [data["dataset"][0]] * len(df_gene.columns[1:])
586
+ samp = df_gene.columns.tolist()[1:]
587
+ cond = df_gene[df_gene["index"] == condition].values.tolist()[0][1:]
588
+ df_gene.columns = ["index"] + [
589
+ f"{ds}_{sam}_{cond}" for (ds, sam, cond) in zip(ds, samp, cond)
590
+ ]
591
+ df_gene.drop(df_gene[df_gene["index"] == condition].index, inplace=True)
592
+ #! set genes as index
593
+ df_gene.set_index("index",inplace=True)
594
+ display(df_gene.head())
595
+ return df_gene
596
+
597
+ def batch_effect(
598
+ data: list = "[df_gene_1, df_gene_2, df_gene_3]", # index (genes),columns(samples)
599
+ datasets: list = ["GSE25097", "GSE62232", "GSE65372"],
600
+ clean_data:bool=False, # default, not do data cleaning
601
+ top_genes:int=10,# only for plotting
602
+ plot_=True,
603
+ dir_save="./res/",
604
+ kws_clean_dataset:dict={},
605
+ **kwargs
606
+ ):
607
+ """
608
+ Purpose: Corrects batch effects across multiple datasets using combat algorithm
609
+ Principle: Empirical Bayes framework to adjust for technical variations
610
+ Key Operations:
611
+ Identifies common genes across datasets
612
+ Applies pyComBat normalization
613
+ Provides before/after visualization
614
+ Dependencies: combat.pycombat
615
+ usage 1:
616
+ bio.batch_effect(
617
+ data=[df_gene_1, df_gene_2, df_gene_3],
618
+ datasets=["GSE25097", "GSE62232", "GSE65372"],
619
+ clean_data=False,
620
+ dir_save="./res/")
621
+
622
+ #! # or combine clean_dataset and batch_effect together
623
+ # # data = [bio.clean_dataset(data=dt, dataset=ds) for (dt, ds) in zip(data, datasets)]
624
+ data_common = bio.batch_effect(
625
+ data=[df_data_1, df_data_2, df_data_3],
626
+ datasets=["GSE25097", "GSE62232", "GSE65372"], clean_data=True
627
+ )
628
+ """
629
+ # data = [df_gene_1, df_gene_2, df_gene_3]
630
+ # datasets = ["GSE25097", "GSE62232", "GSE65372"]
631
+ # top_genes = 10 # show top 10 genes
632
+ # plot_ = True
633
+ from combat.pycombat import pycombat
634
+ if clean_data:
635
+ data=[clean_dataset(data=dt,dataset=ds,**kws_clean_dataset) for (dt,ds) in zip(data,datasets)]
636
+ #! prepare data
637
+ # the datasets are dataframes where:
638
+ # the indexes correspond to the gene names
639
+ # the column names correspond to the sample names
640
+ #! merge batchs
641
+ # https://epigenelabs.github.io/pyComBat/
642
+ # we merge all the datasets into one, by keeping the common genes only
643
+ df_expression_common_genes = pd.concat(data, join="inner", axis=1)
644
+ #! convert to float
645
+ ips.df_astype(df_expression_common_genes, astype="float", inplace=True)
646
+
647
+ #!to visualise results, use Mini datasets, only take the first 10 samples of each batch(dataset)
648
+ if plot_:
649
+ col2plot = []
650
+ for ds in datasets:
651
+ # select the first 10 samples to plot, to see the diff
652
+ dat_tmp = df_expression_common_genes.columns[
653
+ df_expression_common_genes.columns.str.startswith(ds)
654
+ ][:top_genes].tolist()
655
+ col2plot.extend(dat_tmp)
656
+ # visualise results
657
+ _, axs = plt.subplots(2, 1, figsize=(15, 10))
658
+ plot.plotxy(
659
+ ax=axs[0],
660
+ data=df_expression_common_genes.loc[:, col2plot],
661
+ kind_="bar",
662
+ figsets=dict(
663
+ title="Samples expression distribution (non-correction)",
664
+ ylabel="Observations",
665
+ xangle=90,
666
+ ),
667
+ )
668
+ # prepare batch list
669
+ batch = [
670
+ ips.ssplit(i, by="_")[0] for i in df_expression_common_genes.columns.tolist()
671
+ ]
672
+ # run pyComBat
673
+ df_corrected = pycombat(df_expression_common_genes, batch, **kwargs)
674
+ print(f"df_corrected.shape: {df_corrected.shape}")
675
+ display(df_corrected.head())
676
+ # visualise results again
677
+ if plot_:
678
+
679
+ plot.plotxy(
680
+ ax=axs[1],
681
+ data=df_corrected.loc[:, col2plot],
682
+ kind_="bar",
683
+ figsets=dict(
684
+ title="Samples expression distribution (corrected)",
685
+ ylabel="Observations",
686
+ xangle=90,
687
+ ),
688
+ )
689
+ if dir_save is not None:
690
+ ips.figsave(dir_save + "batch_sample_exp_distri.pdf")
691
+ return df_corrected
692
+
693
+ def get_common_genes(element1, element2):
694
+ """
695
+ Purpose: Identifies shared genes between datasets or gene lists
696
+ Principle: Set intersection operation with informative output
697
+ Usage: Essential for cross-dataset integration and comparison
698
+ """
699
+ common_genes = ips.shared(element1, element2, verbose=False)
700
+ return common_genes
701
+
702
+ def counts2expression(
703
+ counts: pd.DataFrame,# index(samples); columns(genes)
704
+ method: str = "TMM", # 'CPM', 'FPKM', 'TPM', 'UQ', 'TMM', 'CUF', 'CTF'
705
+ length: list = None,
706
+ uq_factors: pd.Series = None,
707
+ verbose: bool = False,
708
+ ) -> pd.DataFrame:
709
+ """
710
+ Purpose: Converts raw RNA-seq counts to normalized expression values
711
+ Principle: Implements multiple normalization methods for cross-dataset compatibility
712
+ Supported Methods:
713
+ TMM: Trimmed Mean of M-values - robust against compositional biases
714
+ TPM: Transcripts Per Million - length-normalized for cross-comparison
715
+ CPM: Counts Per Million - simple library size normalization
716
+ FPKM: Fragments Per Kilobase Million - length and library size normalized
717
+ UQ: Upper Quartile - uses 75th percentile for scaling
718
+ Recommendations: TMM for cross-datasets, TPM for single datasets
719
+
720
+ https://www.linkedin.com/pulse/snippet-corner-raw-read-count-normalization-python-mazzalab-gzzyf?trk=public_post
721
+ Convert raw RNA-seq read counts to expression values
722
+ counts: pd.DataFrame
723
+ index: samples
724
+ columns: genes
725
+ usage:
726
+ df_normalized = counts2expression(df_counts, method='TMM', verbose=True)
727
+ recommend cross datasets:
728
+ cross-datasets:
729
+ TMM (Trimmed Mean of M-values); Very suitable for merging datasets, especially
730
+ for cross-sample and cross-dataset comparisons; commonly used in
731
+ differential expression analysis
732
+ CTF (Counts adjusted with TMM factors); Suitable for merging datasets, as
733
+ TMM-based normalization. Typically used as input for downstream analyses
734
+ like differential expression
735
+ TPM (Transcripts Per Million); Good for merging datasets. TPM is often more
736
+ suitable for cross-dataset comparisons because it adjusts for gene length
737
+ and ensures that the expression levels sum to the same total in each sample
738
+ UQ (Upper Quartile); less commonly used than TPM or TMM
739
+ CUF (Counts adjusted with UQ factors); Can be used, but UQ normalization is
740
+ generally not as standardized as TPM or TMM for merging datasets.
741
+ within-datasets:
742
+ CPM(Counts Per Million); it doesn’t adjust for gene length or other
743
+ variables that could vary across datasets
744
+ FPKM(Fragments Per Kilobase Million); FPKM has been known to be inconsistent
745
+ across different experiments
746
+ Parameters:
747
+ - counts: pd.DataFrame
748
+ Raw read counts with genes as rows and samples as columns.
749
+ - method: str, default='TMM'
750
+ CPM (Counts per Million): Scales counts by total library size.
751
+ FPKM (Fragments per Kilobase Million): Requires gene length; scales by both library size and gene length.
752
+ TPM (Transcripts per Million): Scales by gene length and total transcript abundance.
753
+ UQ (Upper Quartile): Normalizes based on the upper quartile of the counts.
754
+ TMM (Trimmed Mean of M-values): Adjusts for compositional biases.
755
+ CUF (Counts adjusted with Upper Quartile factors): Counts adjusted based on UQ factors.
756
+ CTF (Counts adjusted with TMM factors): Counts adjusted based on TMM factors.
757
+ - gene_lengths: pd.Series, optional
758
+ Gene lengths (e.g., in kilobases) for FPKM/TPM normalization. Required for FPKM/TPM.
759
+ - verbose: bool, default=False
760
+ If True, provides detailed logging information.
761
+ - uq_factors: pd.Series, optional
762
+ Precomputed Upper Quartile factors, required for UQ and CUF normalization.
763
+
764
+
765
+ Returns:
766
+ - normalized_counts: pd.DataFrame
767
+ Normalized expression values.
768
+ """
769
+ import rnanorm
770
+ print(f"INFO: 'counts' data should be: index(samples); columns(genes)")
771
+ if "length" in method: # alias: it is easy to forget the many different method names
772
+ method="FPKM"
773
+ methods = ["CPM", "FPKM", "TPM", "UQ", "TMM", "CUF", "CTF"]
774
+ method = ips.strcmp(method, methods)[0]
775
+ if verbose:
776
+ print(
777
+ f"Starting normalization using method: {method}; supported methods: {methods}"
778
+ )
779
+ columns_org = counts.columns.tolist()
780
+ # Check if gene lengths are provided when necessary
781
+ if method in ["FPKM", "TPM"]:
782
+ if length is None:
783
+ raise ValueError(f"Gene lengths must be provided for {method} normalization.")
784
+ if isinstance(length, list):
785
+ df_genelength = pd.DataFrame({"gene_length": length})
786
+ df_genelength.index = counts.columns # set gene_id as index
787
+ df_genelength.index = df_genelength.index.astype(str).str.strip()
788
+ # length = np.array(df_genelength["gene_length"]).reshape(1,-1)
789
+ length = df_genelength["gene_length"]
790
+ counts.index = counts.index.astype(str).str.strip()
791
+ elif isinstance(length, pd.Series):
792
+
793
+ length.index=length.index.astype(str).str.strip()
794
+ counts.columns = counts.columns.astype(str).str.strip()
795
+ shared_genes=ips.shared(length.index, counts.columns,verbose=False)
796
+ length=length.loc[shared_genes]
797
+ counts=counts.loc[:,shared_genes]
798
+ columns_org = counts.columns.tolist()
799
+
800
+
801
+ # # Ensure gene lengths are aligned with counts if provided
802
+ # if length is not None:
803
+ # length = length[counts.index]
804
+
805
+ # Start the normalization based on the chosen method
806
+ if method == "CPM":
807
+ normalized_counts = (
808
+ rnanorm.CPM().set_output(transform="pandas").fit_transform(counts)
809
+ )
810
+
811
+ elif method == "FPKM":
812
+ if verbose:
813
+ print("Performing FPKM normalization using gene lengths.")
814
+ normalized_counts = (
815
+ rnanorm.CPM().set_output(transform="pandas").fit_transform(counts)
816
+ )
817
+ # convert CPM to FPKM: divide each row by the gene length and multiply by 1e3 (per-kilobase scaling)
818
+ normalized_counts=normalized_counts.div(length.values,axis=1)*1e3
819
+
820
+ elif method == "TPM":
821
+ if verbose:
822
+ print("Performing TPM normalization using gene lengths.")
823
+ normalized_counts = (
824
+ rnanorm.TPM(gene_lengths=length)
825
+ .set_output(transform="pandas")
826
+ .fit_transform(counts)
827
+ )
828
+
829
+ elif method == "UQ":
830
+ if verbose:
831
+ print("Performing Upper Quartile (UQ) normalization.")
832
+ if uq_factors is None:
833
+ uq_factors = rnanorm.upper_quartile_factors(counts)
834
+ normalized_counts = (
835
+ rnanorm.UQ(factors=uq_factors)
836
+ .set_output(transform="pandas")
837
+ .fit_transform(counts)
838
+ )
839
+
840
+ elif method == "TMM":
841
+ if verbose:
842
+ print("Performing TMM normalization (Trimmed Mean of M-values).")
843
+ normalized_counts = (
844
+ rnanorm.TMM().set_output(transform="pandas").fit_transform(counts)
845
+ )
846
+
847
+ elif method == "CUF":
848
+ if verbose:
849
+ print("Performing Counts adjusted with UQ factors (CUF).")
850
+ if uq_factors is None:
851
+ uq_factors = rnanorm.upper_quartile_factors(counts)
852
+ normalized_counts = (
853
+ rnanorm.CUF(factors=uq_factors)
854
+ .set_output(transform="pandas")
855
+ .fit_transform(counts)
856
+ )
857
+
858
+ elif method == "CTF":
859
+ if verbose:
860
+ print("Performing Counts adjusted with TMM factors (CTF).")
861
+ normalized_counts = (rnanorm.CTF().set_output(transform="pandas").fit_transform(counts))
862
+
863
+ else:
864
+ raise ValueError(f"Unknown normalization method: {method}")
865
+ normalized_counts.columns=columns_org
866
+ if verbose:
867
+ print(f"Normalization complete using method: {method}")
868
+
869
+ return normalized_counts
870
+
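A sketch of a TMM call on a toy counts table (samples as rows, genes as columns, as the function expects; the values are invented):

df_counts = pd.DataFrame(
    {"geneA": [100, 200, 150], "geneB": [3000, 2500, 4000], "geneC": [10, 1, 5]},
    index=["sample1", "sample2", "sample3"],
)
df_tmm = counts2expression(df_counts, method="TMM", verbose=True)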
871
+ def counts_deseq(counts_sam_gene: pd.DataFrame,
872
+ meta_sam_cond: pd.DataFrame,
873
+ design_factors:list=None,
874
+ kws_DeseqDataSet:dict={},
875
+ kws_DeseqStats:dict={}):
876
+ """
877
+ Purpose: Performs differential expression analysis using DESeq2 methodology
878
+ Principle: Negative binomial distribution modeling with shrinkage estimation
879
+ Key Operations:
880
+ Creates DeseqDataSet object with design formula
881
+ Estimates size factors and dispersions
882
+ Fits negative binomial models
883
+ Performs Wald tests for significance
884
+ Applies multiple testing correction (Benjamini-Hochberg)
885
+ Output Components:
886
+ dds: Complete DESeq2 dataset object
887
+ diff: Results dataframe with log2FC, p-values, FDR
888
+ stat_res: Statistical results object
889
+ df_norm: Normalized count data
890
+
891
+ https://pydeseq2.readthedocs.io/en/latest/api/docstrings/pydeseq2.ds.DeseqStats.html
892
+ Note: Using normalized expression data in a DeseqDataSet object is generally not recommended
893
+ because the DESeq2 framework is designed to work with raw count data.
894
+ baseMean:
895
+ - This value represents the average normalized count (or expression level) of a
896
+ gene across all samples in dataset.
897
+ - For example, a baseMean of 0.287 for 4933401J01Rik indicates that this gene has
898
+ low expression levels in the samples compared to others with higher baseMean
899
+ values like Xkr4 (591.015).
900
+ log2FoldChange: the magnitude and direction of change in expression between conditions.
901
+ lfcSE (Log Fold Change Standard Error): standard error of the log2FoldChange. It
902
+ indicates the uncertainty in the estimate of the fold change.A lower value indicates
903
+ more confidence in the fold change estimate.
904
+ padj: This value accounts for multiple testing corrections (e.g., Benjamini-Hochberg).
905
+ Log10transforming: The columns -log10(pvalue) and -log10(FDR) are transformations of
906
+ the p-values and adjusted p-values, respectively
907
+ """
908
+ from pydeseq2.dds import DeseqDataSet
909
+ from pydeseq2.ds import DeseqStats
910
+ from pydeseq2.default_inference import DefaultInference
911
+
912
+ # data filtering
913
+ # counts_sam_gene = counts_sam_gene.loc[:, ~(counts_sam_gene.sum(axis=0) < 10)]
914
+ if design_factors is None:
915
+ design_factors=meta_sam_cond.columns.tolist()
916
+
917
+ kws_DeseqDataSet.pop("design_factors",{})
918
+ refit_cooks=kws_DeseqDataSet.pop("refit_cooks",True)
919
+
920
+ #! DeseqDataSet
921
+ inference = DefaultInference(n_cpus=8)
922
+ dds = DeseqDataSet(
923
+ counts=counts_sam_gene,
924
+ metadata=meta_sam_cond,
925
+ design_factors=design_factors,
926
+ refit_cooks=refit_cooks,
927
+ inference=inference,
928
+ **kws_DeseqDataSet
929
+ )
930
+ dds.deseq2()
931
+ #* results
932
+ dds_explain="""
933
+ res[0]:
934
+ # X stores the count data,
935
+ # obs stores design factors,
936
+ # obsm stores sample-level data, such as "design_matrix" and "size_factors",
937
+ # varm stores gene-level data, such as "dispersions" and "LFC"."""
938
+ print(dds_explain)
939
+ #! DeseqStats
940
+ stat_res = DeseqStats(dds,**kws_DeseqStats)
941
+ stat_res.summary()
942
+ diff = stat_res.results_df.assign(padj=lambda x: x.padj.fillna(1))
943
+
944
+ # handle the '0' issue, which would cause inf in later calculations (e.g., log10)
945
+ diff["padj"] = diff["padj"].replace(0, 1e-10)
946
+ diff["pvalue"] = diff["pvalue"].replace(0, 1e-10)
947
+
948
+ diff["-log10(pvalue)"] = diff["pvalue"].apply(lambda x: -np.log10(x))
949
+ diff["-log10(FDR)"] = diff["padj"].apply(lambda x: -np.log10(x))
950
+ diff=diff.reset_index().rename(columns={"index": "gene"})
951
+ # sig_diff = (
952
+ # diff.query("log2FoldChange.abs()>0.585 & padj<0.05")
953
+ # .reset_index()
954
+ # .rename(columns={"index": "gene"})
955
+ # )
956
+ df_norm=pd.DataFrame(dds.layers['normed_counts'])
957
+ df_norm.index=counts_sam_gene.index
958
+ df_norm.columns=counts_sam_gene.columns
959
+ print("res[0]: dds\nres[1]:diff\nres[2]:stat_res\nres[3]:df_normalized")
960
+ return dds, diff, stat_res,df_norm
961
+
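A sketch of a typical call (the raw-count table `df_counts`, the GSM sample names, and the single 'condition' design factor are illustrative assumptions):

meta = pd.DataFrame(
    {"condition": ["tumor", "tumor", "non-tumor", "non-tumor"]},
    index=["GSM1", "GSM2", "GSM3", "GSM4"],
)
dds, diff, stat_res, df_norm = counts_deseq(df_counts, meta)
sig = diff[(diff["log2FoldChange"].abs() > 0.585) & (diff["padj"] < 0.05)]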
962
+ def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", species="human"):
963
+ """
964
+ Purpose: Converts gene identifiers using MyGene.info service
965
+ Principle: Batch query to MyGene.info API for ID conversion and annotation
966
+ Supported: 30+ identifier types and multiple species
967
+
968
+ usage:
969
+ scope_genes(df_counts.columns.tolist()[:1000], species="mouse")
970
+ """
971
+ import mygene
972
+
973
+ if scopes is None:
974
+ # copy from: https://docs.mygene.info/en/latest/doc/query_service.html#scopes
975
+ scopes = ips.fload(
976
+ os.path.join(os.path.dirname(__file__), "data", "mygenes_fields_241022.txt"),
977
+ kind="csv",
978
+ verbose=False,
979
+ )
980
+ scopes = ",".join([i.strip() for i in scopes.iloc[:, 0]])
981
+ mg = mygene.MyGeneInfo()
982
+ results = mg.querymany(
983
+ gene_list,
984
+ scopes=scopes,
985
+ fields=fields,
986
+ species=species,
987
+ )
988
+ return pd.DataFrame(results)
989
+
990
+ def get_enrichr(gene_symbol_list,
991
+ gene_sets:str,
992
+ download:bool = False,
993
+ species='Human',
994
+ dir_save="./",
995
+ plot_=False,
996
+ n_top=30,
997
+ palette=None,
998
+ check_shared=True,
999
+ figsize=(5,8),
1000
+ show_ring=False,
1001
+ xticklabels_rot=0,
1002
+ title=None,# 'KEGG'
1003
+ cutoff=0.05,
1004
+ cmap="coolwarm",
1005
+ size=5,
1006
+ **kwargs):
1007
+ """
1008
+ Purpose: Performs over-representation analysis using Enrichr database
1009
+ Principle: Hypergeometric test for gene set enrichment
1010
+ Key Operations:
1011
+ Interfaces with gseapy Enrichr API
1012
+ Supports 180+ predefined gene sets
1013
+ Provides multiple visualization options (barplot, dotplot)
1014
+ Handles species-specific gene symbols
1015
+ Visualization: Ranked bar plots and dot plots showing significance and effect size
1016
+
1017
+ Note: Enrichr uses a list of Entrez gene symbols as input.
1018
+
1019
+ """
1020
+ kws_figsets = {}
1021
+ for k_arg, v_arg in kwargs.items():
1022
+ if "figset" in k_arg:
1023
+ kws_figsets = v_arg
1024
+ kwargs.pop(k_arg, None)
1025
+ break
1026
+ species_org=species
1027
+ # organism (str) – Select one from { ‘Human’, ‘Mouse’, ‘Yeast’, ‘Fly’, ‘Fish’, ‘Worm’ }
1028
+ organisms=['Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm']
1029
+ species=ips.strcmp(species,organisms)[0]
1030
+ if species_org.lower()!= species.lower():
1031
+ print(f"species was corrected to {species}, because only {organisms} are supported")
1032
+ if os.path.isfile(gene_sets):
1033
+ gene_sets_name=os.path.basename(gene_sets)
1034
+ gene_sets = ips.fload(gene_sets)
1035
+ else:
1036
+ lib_support_names = gp.get_library_name()
1037
+ # correct input gene_set name
1038
+ gene_sets_name=ips.strcmp(gene_sets,lib_support_names)[0]
1039
+
1040
+ # download it
1041
+ if download:
1042
+ gene_sets = gp.get_library(name=gene_sets_name, organism=species)
1043
+ else:
1044
+ gene_sets = gene_sets_name # avoid downloading it again
1045
+ print(f"\ngene_sets get ready: {gene_sets_name}")
1046
+
1047
+ # gene symbols are uppercase
1048
+ gene_symbol_list=[str(i).upper() for i in gene_symbol_list]
1049
+
1050
+ # # check how many genes are shared
1051
+ if check_shared and isinstance(gene_sets, dict):
1052
+ shared_genes=ips.shared(ips.flatten(gene_symbol_list,verbose=False),
1053
+ ips.flatten(gene_sets,verbose=False),
1054
+ verbose=False)
1055
+
1056
+ #! enrichr
1057
+ try:
1058
+ enr = gp.enrichr(
1059
+ gene_list=gene_symbol_list,
1060
+ gene_sets=gene_sets,
1061
+ organism=species,
1062
+ outdir=None, # don't write to disk
1063
+ **kwargs
1064
+ )
1065
+ except ValueError as e:
1066
+ print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}{e}\n{'!'*10} Error {'!'*10}")
1067
+ return None
1068
+
1069
+ results_df = enr.results
1070
+ print(f"got enrichr results; shape: {results_df.shape}\n")
1071
+ results_df["-log10(Adjusted P-value)"] = -np.log10(results_df["Adjusted P-value"])
1072
+ results_df.sort_values("-log10(Adjusted P-value)", inplace=True, ascending=False)
1073
+
1074
+ if plot_:
1075
+ if palette is None:
1076
+ palette=plot.get_color(n_top, cmap=cmap)[::-1]
1077
+ #! barplot
1078
+ if n_top<5:
1079
+ height_=4
1080
+ elif 5<=n_top<10:
1081
+ height_=5
1084
+ elif 10<=n_top<15:
1085
+ height_=7
1086
+ elif 15<=n_top<20:
1087
+ height_=8
1088
+ elif 20<=n_top<30:
1089
+ height_=9
1090
+ else:
1091
+ height_=int(n_top/3)
1092
+ plt.figure(figsize=[10, height_])
1093
+
1094
+ ax1=plot.plotxy(
1095
+ data=results_df.head(n_top),
1096
+ kind_="barplot",
1097
+ x="-log10(Adjusted P-value)",
1098
+ y="Term",
1099
+ hue="Term",
1100
+ palette=palette,
1101
+ legend=None,
1102
+ )
1103
+ plot.figsets(ax=ax1, **kws_figsets)
1104
+ if dir_save:
1105
+ ips.figsave(f"{dir_save}enr_barplot.pdf")
1106
+ plt.show()
1107
+
1108
+ #! dotplot
1109
+ cutoff_curr = cutoff
1110
+ step=0.05
1111
+ cutoff_stop = 0.5
1112
+ while cutoff_curr <= cutoff_stop:
1113
+ try:
1114
+ if cutoff_curr!=cutoff:
1115
+ plt.clf()
1116
+ ax2 = gp.dotplot(enr.res2d,
1117
+ column="Adjusted P-value",
1118
+ show_ring=show_ring,
1119
+ xticklabels_rot=xticklabels_rot,
1120
+ title=title,
1121
+ cmap=cmap,
1122
+ cutoff=cutoff_curr,
1123
+ top_term=n_top,
1124
+ size=size,
1125
+ figsize=[10, height_])
1126
+ if len(ax2.collections)>=n_top:
1127
+ print(f"cutoff={cutoff_curr} done! ")
1128
+ break
1129
+ if cutoff_curr==cutoff_stop:
1130
+ break
1131
+ cutoff_curr+=step
1132
+ except Exception as e:
1133
+ cutoff_curr+=step
1134
+ print(f"Warning: cutoff={cutoff_curr-step} failed ({e}); retrying with cutoff={cutoff_curr}")
1135
+ ax = plt.gca()
1136
+ plot.figsets(ax=ax,**kws_figsets)
1137
+
1138
+ if dir_save:
1139
+ ips.figsave(f"{dir_save}enr_dotplot.pdf")
1140
+
1141
+ return results_df
1142
+
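A sketch that chains enrichment and plotting (the significant-gene list is a placeholder; 'KEGG_2021_Human' is the library already used in the pipeline comments at the top of this file):

res = get_enrichr(sig_genes, gene_sets="KEGG_2021_Human", species="Human", plot_=False)
ax, res_sorted = plot_enrichr(res, kind="dot", n_top=15, title="KEGG")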
1143
+ def plot_enrichr(results_df,
1144
+ kind="bar",# 'barplot', 'dotplot'
1145
+ cutoff=0.05,
1146
+ show_ring=False,
1147
+ xticklabels_rot=0,
1148
+ title=None,# 'KEGG'
1149
+ cmap="coolwarm",
1150
+ n_top=10,
1151
+ size=5,
1152
+ ax=None,
1153
+ **kwargs):
1154
+ """
1155
+ Purpose: Flexible visualization of enrichment results
1156
+ Plot Types:
1157
+ Bar plots: -log10(p-value) for top terms
1158
+ Dot plots: Combined visualization of p-value and gene ratio
1159
+ Count plots: Number of overlapping genes
1160
+ Customization: Color schemes, term number, significance thresholds
1161
+ """
1162
+ kws_figsets = {}
1163
+ for k_arg, v_arg in kwargs.items():
1164
+ if "figset" in k_arg:
1165
+ kws_figsets = v_arg
1166
+ kwargs.pop(k_arg, None)
1167
+ break
1168
+ if isinstance(cmap,str):
1169
+ palette = plot.get_color(n_top, cmap=cmap)[::-1]
1170
+ elif isinstance(cmap,list):
1171
+ palette=cmap
1172
+ if n_top < 5:
1173
+ height_ = 3
1174
+ elif 5 <= n_top < 10:
1175
+ height_ = 3
1176
+ elif 10 <= n_top < 15:
1177
+ height_ = 3
1178
+ elif 15 <= n_top < 20:
1179
+ height_ =4
1180
+ elif 20 <= n_top < 30:
1181
+ height_ = 5
1182
+ elif 30 <= n_top < 40:
1183
+ height_ = int(n_top / 6)
1184
+ else:
1185
+ height_ = int(n_top / 8)
1186
+
1187
+ #! barplot
1188
+ if 'bar' in kind.lower():
1189
+ if ax is None:
1190
+ _,ax=plt.subplots(1,1,figsize=[10, height_])
1191
+ ax=plot.plotxy(
1192
+ data=results_df.head(n_top),
1193
+ kind_="barplot",
1194
+ x="-log10(Adjusted P-value)",
1195
+ y="Term",
1196
+ hue="Term",
1197
+ palette=palette,
1198
+ legend=None,
1199
+ )
1200
+ plot.figsets(ax=ax, **kws_figsets)
1201
+ return ax,results_df
1202
+
1203
+ #! dotplot
1204
+ elif 'dot' in kind.lower():
1205
+ #! dotplot
1206
+ cutoff_curr = cutoff
1207
+ step=0.05
1208
+ cutoff_stop = 0.5
1209
+ while cutoff_curr <= cutoff_stop:
1210
+ try:
1211
+ if cutoff_curr!=cutoff:
1212
+ plt.clf()
1213
+ ax = gp.dotplot(results_df,
1214
+ column="Adjusted P-value",
1215
+ show_ring=show_ring,
1216
+ xticklabels_rot=xticklabels_rot,
1217
+ title=title,
1218
+ cmap=cmap,
1219
+ cutoff=cutoff_curr,
1220
+ top_term=n_top,
1221
+ size=size,
1222
+ figsize=[10, height_])
1223
+ if len(ax.collections)>=n_top:
1224
+ print(f"cutoff={cutoff_curr} done! ")
1225
+ break
1226
+ if cutoff_curr==cutoff_stop:
1227
+ break
1228
+ cutoff_curr+=step
1229
+ except Exception as e:
1230
+ cutoff_curr+=step
1231
+ print(f"Warning: cutoff={cutoff_curr-step} failed ({e}); retrying with cutoff={cutoff_curr}")
1232
+ plot.figsets(ax=ax, **kws_figsets)
1233
+ return ax,results_df
1234
+
1235
+ #! barplot with counts
1236
+ elif 'count' in kind.lower():
1237
+ if ax is None:
1238
+ _,ax=plt.subplots(1,1,figsize=[10, height_])
1239
+ # extract the gene count from the 'Overlap' column
1240
+ results_df["Count"] = results_df["Overlap"].apply(
1241
+ lambda x: int(x.split("/")[0]) if isinstance(x, str) else x)
1242
+ df_=results_df.sort_values(by="Count", ascending=False)
1243
+
1244
+ ax=plot.plotxy(
1245
+ data=df_.head(n_top),
1246
+ kind_="barplot",
1247
+ x="Count",
1248
+ y="Term",
1249
+ hue="Term",
1250
+ palette=palette,
1251
+ legend=None,
1252
+ ax=ax
1253
+ )
1254
+
1255
+ plot.figsets(ax=ax, **kws_figsets)
1256
+ return ax,df_
1257
+
1258
+ def plot_bp_cc_mf(
1259
+ deg_gene_list,
1260
+ gene_sets=[
1261
+ "GO_Biological_Process_2023",
1262
+ "GO_Cellular_Component_2023",
1263
+ "GO_Molecular_Function_2023",
1264
+ ],
1265
+ species="human",
1266
+ download=False,
1267
+ n_top=10,
1268
+ plot_=True,
1269
+ ax=None,
1270
+ palette=plot.get_color(3,"colorblind6"),
1271
+ **kwargs,
1272
+ ):
1273
+ """
1274
+ Purpose: Integrated visualization of Gene Ontology (BP, CC, MF) enrichment
1275
+ Principle: Combines results from three GO domains into unified plot
1276
+ Usage: Comprehensive functional profiling of gene lists
1277
+ """
1278
+ def res_enrichr_2_count(res_enrichr, n_top=10):
1279
+ """Extract the Count from the enrichr results and sort by it."""
1280
+ res_enrichr["Count"] = res_enrichr["Overlap"].apply(
1281
+ lambda x: int(x.split("/")[0]) if isinstance(x, str) else x
1282
+ )
1283
+ res_enrichr.sort_values(by="Count", ascending=False, inplace=True)
1284
+
1285
+ return res_enrichr.head(n_top)#[["Term", "Count"]]
1286
+
1287
+ res_enrichr_BP = get_enrichr(
1288
+ deg_gene_list, gene_sets[0], species=species, plot_=False,download=download
1289
+ )
1290
+ res_enrichr_CC = get_enrichr(
1291
+ deg_gene_list, gene_sets[1], species=species, plot_=False,download=download
1292
+ )
1293
+ res_enrichr_MF = get_enrichr(
1294
+ deg_gene_list, gene_sets[2], species=species, plot_=False,download=download
1295
+ )
1296
+
1297
+ df_BP = res_enrichr_2_count(res_enrichr_BP, n_top=n_top)
1298
+ df_BP["Ontology"] = ["Biological Process"] * n_top
1299
+
1300
+ df_CC = res_enrichr_2_count(res_enrichr_CC, n_top=n_top)
1301
+ df_CC["Ontology"] = ["Cellular Component"] * n_top
1302
+
1303
+ df_MF = res_enrichr_2_count(res_enrichr_MF, n_top=n_top)
1304
+ df_MF["Ontology"] = ["Molecular Function"] * n_top
1305
+
1306
+ # combine the three ontologies
1307
+ df2plot = pd.concat([df_BP, df_CC, df_MF])
1308
+ n_top=n_top*3
1309
+ if n_top < 5:
1310
+ height_ = 4
1311
+ elif 5 <= n_top < 10:
1312
+ height_ = 5
1313
+ elif 10 <= n_top < 15:
1314
+ height_ = 6
1315
+ elif 15 <= n_top < 20:
1316
+ height_ = 7
1317
+ elif 20 <= n_top < 30:
1318
+ height_ = 8
1319
+ elif 30 <= n_top < 40:
1320
+ height_ = int(n_top / 4)
1321
+ else:
1322
+ height_ = int(n_top / 5)
1323
+ if ax is None:
1324
+ _,ax=plt.subplots(1,1,figsize=[10, height_])
1325
+ # plot
1326
+ display(df2plot)
1327
+ if df2plot["Term"].tolist()[0].endswith(")"):
1328
+ df2plot["Term"] = df2plot["Term"].apply(lambda x: x.split("(")[0][:-1])
1329
+ if plot_:
1330
+ ax = plot.plotxy(
1331
+ data=df2plot,
1332
+ x="Count",
1333
+ y="Term",
1334
+ hue="Ontology",
1335
+ kind_="bar",
1336
+ palette=palette,
1337
+ ax=ax,
1338
+ **kwargs
1339
+ )
1340
+ return ax, df2plot
1341
+
1342
+ def get_library_name(by=None, verbose=False):
1343
+ """
1344
+ Purpose: Retrieves available gene set libraries from Enrichr
1345
+ Principle: Queries gseapy for current library availability
1346
+ Usage: Discovery of available pathway databases
1347
+ """
1348
+ lib_names=gp.get_library_name()
1349
+ if by is None:
1350
+ if verbose:
1351
+ [print(i) for i in lib_names]
1352
+ return lib_names
1353
+ else:
1354
+ return ips.flatten(ips.strcmp(by, lib_names, get_rank=True,verbose=verbose),verbose=verbose)
1355
+
1356
+
1357
+ def get_gsva(
1358
+ data_gene_samples: pd.DataFrame, # index(gene),columns(samples)
1359
+ gene_sets: str,
1360
+ species:str="Human",
1361
+ dir_save:str="./",
1362
+ plot_:bool=False,
1363
+ n_top:int=30,
1364
+ check_shared:bool=True,
1365
+ cmap="coolwarm",
1366
+ min_size=1,
1367
+ max_size=1000,
1368
+ kcdf="Gaussian",# 'Gaussian' for continuous data
1369
+ method='gsva',
1370
+ seed=1,
1371
+ **kwargs,
1372
+ ):
1373
+ """
1374
+ Purpose: Gene Set Variation Analysis - estimates pathway activity per sample
+ Principle: Non-parametric unsupervised method for estimating pathway enrichment
+ Key Operations:
+ • Calculates enrichment scores for each sample
+ • Handles continuous expression data (Gaussian kernel)
+ • Supports custom and predefined gene sets
+ Output: Sample-by-pathway activity matrix
1378
+ """
1379
+ kws_figsets = {}
1380
+ for k_arg, v_arg in kwargs.items():
1381
+ if "figset" in k_arg:
1382
+ kws_figsets = v_arg
1383
+ kwargs.pop(k_arg, None)
1384
+ break
1385
+ species_org = species
1386
+ # organism (str) – Select one from { ‘Human’, ‘Mouse’, ‘Yeast’, ‘Fly’, ‘Fish’, ‘Worm’ }
1387
+ organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
1388
+ species = ips.strcmp(species, organisms)[0]
1389
+ if species_org.lower() != species.lower():
1390
+ print(f"species was corrected to {species}, because only {organisms} are supported")
1391
+ if os.path.isfile(gene_sets):
1392
+ gene_sets_name = os.path.basename(gene_sets)
1393
+ gene_sets = ips.fload(gene_sets)
1394
+ else:
1395
+ lib_support_names = gp.get_library_name()
1396
+ # correct input gene_set name
1397
+ gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
1398
+ # download it
1399
+ gene_sets = gp.get_library(name=gene_sets_name, organism=species)
1400
+ print(f"gene_sets get ready: {gene_sets_name}")
1401
+
1402
+ # gene symbols are uppercase
1403
+ gene_symbol_list = [str(i).upper() for i in data_gene_samples.index]
1404
+ data_gene_samples.index=gene_symbol_list
1405
+ # display(data_gene_samples.head(3))
1406
+ # # check how many genes are shared
1407
+ if check_shared:
1408
+ ips.shared(
1409
+ ips.flatten(gene_symbol_list, verbose=False),
1410
+ ips.flatten(gene_sets, verbose=False),
1411
+ verbose=False
1412
+ )
1413
+ gsva_results = gp.gsva(
1414
+ data=data_gene_samples, # matrix should have genes as rows and samples as columns
1415
+ gene_sets=gene_sets,
1416
+ outdir=None,
1417
+ kcdf=kcdf, # 'Gaussian' for continuous data
1418
+ min_size=min_size,
1419
+ method=method,
1420
+ max_size=max_size,
1421
+ verbose=True,
1422
+ seed=seed,
1423
+ # no_plot=False,
1424
+ )
1425
+ gsva_res = gsva_results.res2d.copy()
1426
+ gsva_res["ES_abs"] = gsva_res["ES"].apply(np.abs)
1427
+ gsva_res = gsva_res.sort_values(by="ES_abs", ascending=False)
1428
+ gsva_res = (
1429
+ gsva_res.drop_duplicates(subset="Term").drop(columns="ES_abs")
1430
+ # .iloc[:80, :]
1431
+ .reset_index(drop=True)
1432
+ )
1433
+ gsva_res = gsva_res.sort_values(by="ES", ascending=False)
1434
+ if plot_:
1435
+ if gsva_res.shape[0]>=2*n_top:
1436
+ gsva_res_plot=pd.concat([gsva_res.head(n_top),gsva_res.tail(n_top)])
1437
+ else:
1438
+ gsva_res_plot = gsva_res
1439
+ if isinstance(cmap,str):
1440
+ palette = plot.get_color(n_top*2, cmap=cmap)[::-1]
1441
+ elif isinstance(cmap,list):
1442
+ if len(cmap)==2:
1443
+ palette = [cmap[0]]*n_top+[cmap[1]]*n_top
1444
+ else:
1445
+ palette=cmap
1446
+ # ! barplot
1447
+ if n_top < 5:
1448
+ height_ = 3
1449
+ elif 5 <= n_top < 10:
1450
+ height_ = 4
1451
+ elif 10 <= n_top < 15:
1452
+ height_ = 5
1453
+ elif 15 <= n_top < 20:
1454
+ height_ = 6
1455
+ elif 20 <= n_top < 30:
1456
+ height_ = 7
1457
+ elif 30 <= n_top < 40:
1458
+ height_ = int(n_top / 3.5)
1459
+ else:
1460
+ height_ = int(n_top / 3)
1461
+ plt.figure(figsize=[10, height_])
1462
+ ax2 = plot.plotxy(
1463
+ data=gsva_res_plot,
1464
+ x="ES",
1465
+ y="Term",
1466
+ hue="Term",
1467
+ palette=palette,
1468
+ kind_=["bar"],
1469
+ figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
1470
+ )
1471
+ # adjust the label positions
1472
+ for i, bar in enumerate(ax2.patches):
1473
+ term = gsva_res_plot.iloc[i]["Term"]
1474
+ es_value = gsva_res_plot.iloc[i]["ES"]
1475
+
1476
+ # Positive ES values: Align y-labels to the left
1477
+ if es_value > 0:
1478
+ ax2.annotate(
1479
+ term,
1480
+ xy=(0, bar.get_y() + bar.get_height() / 2),
1481
+ xytext=(-5, 0), # Move to the left
1482
+ textcoords="offset points",
1483
+ ha="right",
1484
+ va="center", # Align labels to the right
1485
+ fontsize=10,
1486
+ color="black",
1487
+ )
1488
+ # Negative ES values: Align y-labels to the right
1489
+ else:
1490
+ ax2.annotate(
1491
+ term,
1492
+ xy=(0, bar.get_y() + bar.get_height() / 2),
1493
+ xytext=(5, 0), # Move to the right
1494
+ textcoords="offset points",
1495
+ ha="left",
1496
+ va="center", # Align labels to the left
1497
+ fontsize=10,
1498
+ color="black",
1499
+ )
1500
+ plot.figsets(ax=ax2, **kws_figsets)
1501
+ if dir_save:
1502
+ ips.figsave(dir_save + f"GSVA_{gene_sets_name}.pdf")
1503
+ plt.show()
1504
+ return gsva_res.reset_index(drop=True)
1505
+
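+ # Usage sketch (illustrative; the expression matrix below is made-up toy data):
+ # expr = pd.DataFrame(np.random.rand(100, 6),
+ #                     index=[f"GENE{i}" for i in range(100)],
+ #                     columns=[f"S{i}" for i in range(6)])
+ # gsva_res = get_gsva(expr, gene_sets="KEGG_2021_Human", species="Human", plot_=False)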
1506
+ def plot_gsva(gsva_res, # output from bio.get_gsva()
1507
+ n_top=10,
1508
+ ax=None,
1509
+ x="ES",
1510
+ y="Term",
1511
+ hue="Term",
1512
+ cmap="coolwarm",
1513
+ **kwargs
1514
+ ):
1515
+ kws_figsets = {}
1516
+ for k_arg, v_arg in kwargs.items():
1517
+ if "figset" in k_arg:
1518
+ kws_figsets = v_arg
1519
+ kwargs.pop(k_arg, None)
1520
+ break
1521
+ # ! barplot
1522
+ if n_top < 5:
1523
+ height_ = 4
1524
+ elif 5 <= n_top < 10:
1525
+ height_ = 5
1526
+ elif 10 <= n_top < 15:
1527
+ height_ = 6
1528
+ elif 15 <= n_top < 20:
1529
+ height_ = 7
1530
+ elif 20 <= n_top < 30:
1531
+ height_ = 8
1532
+ elif 30 <= n_top < 40:
1533
+ height_ = int(n_top / 3.5)
1534
+ else:
1535
+ height_ = int(n_top / 3)
1536
+ if ax is None:
1537
+ _,ax=plt.subplots(1,1,figsize=[10, height_])
1538
+ gsva_res = gsva_res.sort_values(by=x, ascending=False)
1539
+
1540
+ if gsva_res.shape[0]>=2*n_top:
1541
+ gsva_res_plot=pd.concat([gsva_res.head(n_top),gsva_res.tail(n_top)])
1542
+ else:
1543
+ gsva_res_plot = gsva_res
1544
+ if isinstance(cmap,str):
1545
+ palette = plot.get_color(n_top*2, cmap=cmap)[::-1]
1546
+ elif isinstance(cmap,list):
1547
+ if len(cmap)==2:
1548
+ palette = [cmap[0]]*n_top+[cmap[1]]*n_top
1549
+ else:
1550
+ palette=cmap
1551
+
1552
+ ax = plot.plotxy(
1553
+ ax=ax,
1554
+ data=gsva_res_plot,
1555
+ x=x,
1556
+ y=y,
1557
+ hue=hue,
1558
+ palette=palette,
1559
+ kind_=["bar"],
1560
+ figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
1561
+ )
1562
+ # adjust the label positions
1563
+ for i, bar in enumerate(ax.patches):
1564
+ term = gsva_res_plot.iloc[i]["Term"]
1565
+ es_value = gsva_res_plot.iloc[i]["ES"]
1566
+
1567
+ # Positive ES values: Align y-labels to the left
1568
+ if es_value > 0:
1569
+ ax.annotate(
1570
+ term,
1571
+ xy=(0, bar.get_y() + bar.get_height() / 2),
1572
+ xytext=(-5, 0), # Move to the left
1573
+ textcoords="offset points",
1574
+ ha="right",
1575
+ va="center", # Align labels to the right
1576
+ fontsize=10,
1577
+ color="black",
1578
+ )
1579
+ # Negative ES values: Align y-labels to the right
1580
+ else:
1581
+ ax.annotate(
1582
+ term,
1583
+ xy=(0, bar.get_y() + bar.get_height() / 2),
1584
+ xytext=(5, 0), # Move to the right
1585
+ textcoords="offset points",
1586
+ ha="left",
1587
+ va="center", # Align labels to the left
1588
+ fontsize=10,
1589
+ color="black",
1590
+ )
1591
+ plot.figsets(ax=ax, **kws_figsets)
1592
+ return ax
1593
+
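+ # Usage sketch (illustrative): re-plot a stored GSVA result on an existing axes.
+ # fig, ax = plt.subplots(figsize=(10, 6))
+ # plot_gsva(gsva_res, n_top=10, ax=ax, cmap=["#3b4cc0", "#b40426"])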
1594
+ def get_prerank(
1595
+ rnk: pd.DataFrame,
1596
+ gene_sets: str,
1597
+ download: bool = False,
1598
+ species="Human",
1599
+ threads=8, # Number of CPU cores to use
1600
+ permutation_num=1000, # Number of permutations for significance
1601
+ min_size=15, # Minimum allowed number of genes from gene set also the data set. Default: 15
1602
+ max_size=500, # Maximum allowed number of genes from gene set also the data set. Defaults: 500.
1603
+ weight=1.0,# – Refer to algorithm.enrichment_score(). Default:1.
1604
+ ascending=False, #Sorting order of rankings. Default: False for descending. If None, do not sort the ranking.
1605
+ seed=1, # Seed for reproducibility
1606
+ verbose=True, # Verbosity
1607
+ dir_save="./",
1608
+ plot_=False,
1609
+ n_top=7,# only for plot
1610
+ size=5,
1611
+ figsize=(3,4),
1612
+ cutoff=0.25,
1613
+ show_ring=False,
1614
+ cmap="coolwarm",
1615
+ check_shared=True,
1616
+ **kwargs,
1617
+ ):
1618
+ """
1619
+ Purpose: Pre-ranked Gene Set Enrichment Analysis (GSEA)
1620
+ Principle: Kolmogorov-Smirnov like statistic applied to pre-ranked gene list
1621
+ Key Operations:
1622
+ • Uses precomputed rankings (e.g., from DESeq2)
+ • Permutation testing for significance
+ • Identifies enriched gene sets at the top and bottom of the ranking
1625
+ Visualization: Enrichment plots, network diagrams, dot plots
1626
+
1627
+ Note: Enrichr uses a list of Entrez gene symbols as input.
1628
+ """
1629
+ kws_figsets = {}
1630
+ for k_arg, v_arg in kwargs.items():
1631
+ if "figset" in k_arg:
1632
+ kws_figsets = kwargs.pop(k_arg)
1633
+ break
1634
+ species_org = species
1635
+ # organism (str) – Select one from { ‘Human’, ‘Mouse’, ‘Yeast’, ‘Fly’, ‘Fish’, ‘Worm’ }
1636
+ organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
1637
+ species = ips.strcmp(species, organisms)[0]
1638
+ print(f"Please confirm sample species = '{species}', if not, select one from {organisms}")
1639
+ if isinstance(gene_sets, str) and os.path.isfile(gene_sets) :
1640
+ gene_sets_name = os.path.basename(gene_sets)
1641
+ gene_sets = ips.fload(gene_sets)
1642
+ else:
1643
+ lib_support_names = gp.get_library_name()
1644
+ # correct input gene_set name
1645
+ gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
1646
+
1647
+ # download it
1648
+ if download:
1649
+ gene_sets = gp.get_library(name=gene_sets_name, organism=species)
1650
+ else:
1651
+ gene_sets = gene_sets_name # avoid downloading it again
1652
+ print(f"\ngene_sets get ready: {gene_sets_name}")
1653
+
1654
+ #! prerank
1655
+ try:
1656
+ # https://gseapy.readthedocs.io/en/latest/_modules/gseapy.html#prerank
1657
+ pre_res = gp.prerank(
1658
+ rnk=rnk,
1659
+ gene_sets=gene_sets,
1660
+ threads=threads, # Number of CPU cores to use
1661
+ permutation_num=permutation_num, # Number of permutations for significance
1662
+ min_size=min_size, # Minimum gene set size
1663
+ max_size=max_size, # Maximum gene set size
1664
+ weight=weight,# – Refer to algorithm.enrichment_score(). Default:1.
1665
+ ascending=ascending, #Sorting order of rankings. Default: False for descending. If None, do not sort the ranking.
1666
+ seed=seed, # Seed for reproducibility
1667
+ verbose=verbose, # Verbosity
1668
+ )
1669
+ except ValueError as e:
1670
+ print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}Jeff,check the rnk format; set 'gene name' as index, and only keep the 'score' column. This is the error message: \n{e}\n{'!'*10} Error {'!'*10}")
1671
+ return None
1672
+ df_prerank = pre_res.res2d
1673
+ if plot_:
1674
+ #! gseaplot
1675
+ # # (1) easy way
1676
+ # terms = df_prerank.Term
1677
+ # axs = pre_res.plot(terms=terms[0])
1678
+ # (2) # to make more control on the plot, use
1679
+ terms = df_prerank.Term
1680
+ axs = pre_res.plot(
1681
+ terms=terms[:n_top],
1682
+ # legend_kws={"loc": (1.2, 0)}, # set the legend loc
1683
+ # show_ranking=True, # whether to show the second yaxis
1684
+ figsize=(min(figsize),max(figsize)),
1685
+ )
1686
+ f_name_tmp=str(gene_sets)[:20] if len(str(gene_sets))>=20 else str(gene_sets)
1687
+ ips.figsave(dir_save + f"prerank_gseaplot_{f_name_tmp}.pdf")
1688
+
1689
+ ## plot single prerank
1690
+ terms_ = pre_res.res2d.Term
1691
+ try:
1692
+ for i in range(n_top*2):
1693
+ axs_ = pre_res.plot(terms=terms_[i])
1694
+ ips.figsave(os.path.join(ips.mkdir(dir_save, "fig_prerank_single"), f"Top_{str(i+1)}_{terms_[i].replace('/', '_')}.pdf"))
1695
+ except Exception as e:
1696
+ print(e)
1697
+
1698
+ #!dotplot
1699
+ from gseapy import dotplot
1700
+
1701
+ # to save figure, make sure that ``ofname`` is not None
1702
+ ax = dotplot(
1703
+ df_prerank,
1704
+ column="NOM p-val", # FDR q-val",
1705
+ cmap=cmap,
1706
+ size=size,
1707
+ figsize=(max(figsize),min(figsize)),
1708
+ cutoff=cutoff,
1709
+ show_ring=show_ring,
1710
+ )
1711
+ ips.figsave(dir_save + f"prerank_dotplot_{f_name_tmp}.pdf")
1712
+
1713
+ #! network plot
1714
+ from gseapy import enrichment_map
1715
+ import networkx as nx
1716
+
1717
+ for top_term in range(5, 50):
1718
+ try:
1719
+ # return two dataframe
1720
+ nodes, edges = enrichment_map(
1721
+ df=df_prerank,
1722
+ columns="FDR q-val",
1723
+ cutoff=0.25, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
1724
+ top_term=top_term,
1725
+ )
1726
+ # build graph
1727
+ G = nx.from_pandas_edgelist(
1728
+ edges,
1729
+ source="src_idx",
1730
+ target="targ_idx",
1731
+ edge_attr=["jaccard_coef", "overlap_coef", "overlap_genes"],
1732
+ )
1733
+ # check whether the length of nodes.Hits_ratio / nodes.NES matches the number of nodes
1734
+ if len(list(nodes.Hits_ratio)) == len(G.nodes):
1735
+ node_sizes = list(nodes.Hits_ratio * 1000)
1736
+ else:
1737
+ raise ValueError(
1738
+ "The size of node_size list does not match the number of nodes in the graph."
1739
+ )
1740
+
1741
+ layout = "circular"
1742
+ fig, ax = plt.subplots(figsize=(max(figsize),max(figsize)))
1743
+ if layout == "spring":
1744
+ pos = nx.layout.spring_layout(G)
1745
+ elif layout == "circular":
1746
+ pos = nx.layout.circular_layout(G)
1747
+ elif layout == "shell":
1748
+ pos = nx.layout.shell_layout(G)
1749
+ elif layout == "spectral":
1750
+ pos = nx.layout.spectral_layout(G)
1751
+
1752
+ # node_size = nx.get_node_attributes()
1753
+ # draw node
1754
+ nx.draw_networkx_nodes(
1755
+ G,
1756
+ pos=pos,
1757
+ cmap=plt.cm.RdYlBu,
1758
+ node_color=list(nodes.NES),
1759
+ node_size=list(nodes.Hits_ratio * 1000),
1760
+ )
1761
+ # draw node label
1762
+ nx.draw_networkx_labels(
1763
+ G,
1764
+ pos=pos,
1765
+ labels=nodes.Term.to_dict(),
1766
+ font_size=8,
1767
+ verticalalignment="bottom",
1768
+ )
1769
+ # draw edge
1770
+ edge_weight = nx.get_edge_attributes(G, "jaccard_coef").values()
1771
+ nx.draw_networkx_edges(
1772
+ G,
1773
+ pos=pos,
1774
+ width=list(map(lambda x: x * 10, edge_weight)),
1775
+ edge_color="#CDDBD4",
1776
+ )
1777
+ ax.set_axis_off()
1778
+ print(f"{gene_sets}(top_term={top_term})")
1779
+ plot.figsets(title=f"{gene_sets}(top_term={top_term})")
1780
+ ips.figsave(dir_save + f"prerank_network_{gene_sets}.pdf")
1781
+ break
1782
+ except Exception as e:
+ print(f"top_term={top_term} did not work: {e}")
1784
+ return df_prerank, pre_res
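+ # Usage sketch (illustrative): `rnk` is a one-column ranking with gene symbols as the index,
+ # e.g. built from a hypothetical DESeq2 result table `deg_results`:
+ # rnk = deg_results.set_index("gene")[["stat"]]
+ # df_prerank, pre_res = get_prerank(rnk, gene_sets="KEGG_2021_Human", species="Human")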
1785
+ def plot_prerank(
1786
+ results_df,
1787
+ kind="bar", # 'barplot', 'dotplot'
1788
+ cutoff=0.25,
1789
+ show_ring=False,
1790
+ xticklabels_rot=0,
1791
+ title=None, # 'KEGG'
1792
+ cmap="coolwarm",
1793
+ n_top=10,
1794
+ size=5, # when size is None in network, by "NES"
1795
+ facecolor=None,# default by "NES"
1796
+ linewidth=None,# default by "NES"
1797
+ linecolor=None,# default by "NES"
1798
+ linealpha=None, # default by "NES"
1799
+ alpha=None,# default by "NES"
1800
+ ax=None,
1801
+ **kwargs,
1802
+ ):
1803
+ kws_figsets = {}
1804
+ for k_arg, v_arg in kwargs.items():
1805
+ if "figset" in k_arg:
1806
+ kws_figsets = v_arg
1807
+ kwargs.pop(k_arg, None)
1808
+ break
1809
+ if isinstance(cmap, str):
1810
+ palette = plot.get_color(n_top, cmap=cmap)[::-1]
1811
+ elif isinstance(cmap, list):
1812
+ palette = cmap
1813
+ if n_top < 5:
1814
+ height_ = 4
1815
+ elif 5 <= n_top < 10:
1816
+ height_ = 5
1817
+ elif 10 <= n_top < 15:
1818
+ height_ = 6
1819
+ elif 15 <= n_top < 20:
1820
+ height_ = 7
1821
+ elif 20 <= n_top < 30:
1822
+ height_ = 8
1823
+ elif 30 <= n_top < 40:
1824
+ height_ = int(n_top / 5)
1825
+ else:
1826
+ height_ = int(n_top / 6)
1827
+ results_df["-log10(Adjusted P-value)"]=results_df["FDR q-val"].apply(lambda x : -np.log10(x))
1828
+ results_df["Count"] = results_df["Lead_genes"].apply(lambda x: len(x.split(";")))
1829
+ #! barplot
1830
+ if "bar" in kind.lower():
1831
+ df_=results_df.sort_values(by="-log10(Adjusted P-value)",ascending=False)
1832
+ if ax is None:
1833
+ _, ax = plt.subplots(1, 1, figsize=[10, height_])
1834
+ ax = plot.plotxy(
1835
+ data=df_.head(n_top),
1836
+ kind_="barplot",
1837
+ x="-log10(Adjusted P-value)",
1838
+ y="Term",
1839
+ hue="Term",
1840
+ palette=palette,
1841
+ legend=None,
1842
+ )
1843
+ plot.figsets(ax=ax, **kws_figsets)
1844
+ return ax, df_
1845
+
1846
+ #! dotplot
1847
+ elif "dot" in kind.lower():
1848
+ #! dotplot
1849
+ cutoff_curr = cutoff
1850
+ step = 0.05
1851
+ cutoff_stop = 0.5
1852
+ while cutoff_curr <= cutoff_stop:
1853
+ try:
1854
+ if cutoff_curr != cutoff:
1855
+ plt.clf()
1856
+ ax = gp.dotplot(
1857
+ results_df,
1858
+ column="NOM p-val",
1859
+ show_ring=show_ring,
1860
+ xticklabels_rot=xticklabels_rot,
1861
+ title=title,
1862
+ cmap=cmap,
1863
+ cutoff=cutoff_curr,
1864
+ top_term=n_top,
1865
+ size=size,
1866
+ figsize=[10, height_],
1867
+ )
1868
+ if len(ax.collections) >= n_top:
1869
+ print(f"cutoff={cutoff_curr} done! ")
1870
+ break
1871
+ if cutoff_curr == cutoff_stop:
1872
+ break
1873
+ cutoff_curr += step
1874
+ except Exception as e:
1875
+ cutoff_curr += step
1876
+ print(
1877
+ f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} "
1878
+ )
1879
+ plot.figsets(ax=ax, **kws_figsets)
1880
+ return ax, results_df
1881
+
1882
+ #! barplot with counts
1883
+ elif "co" in kind.lower():
1884
+ if ax is None:
1885
+ _, ax = plt.subplots(1, 1, figsize=[10, height_])
1886
+ # extract the gene counts from the overlap
1887
+ df_ = results_df.sort_values(by="Count", ascending=False)
1888
+ ax = plot.plotxy(
1889
+ data=df_.head(n_top),
1890
+ kind_="barplot",
1891
+ x="Count",
1892
+ y="Term",
1893
+ hue="Term",
1894
+ palette=palette,
1895
+ legend=None,
1896
+ ax=ax,
1897
+ **kwargs,
1898
+ )
1899
+
1900
+ plot.figsets(ax=ax, **kws_figsets)
1901
+ return ax, df_
1902
+ #! scatter with counts
1903
+ elif "sca" in kind.lower():
1904
+ if isinstance(cmap, str):
1905
+ palette = plot.get_color(n_top, cmap=cmap)
1906
+ elif isinstance(cmap, list):
1907
+ palette = cmap
1908
+ if ax is None:
1909
+ _, ax = plt.subplots(1, 1, figsize=[10, height_])
1910
+ # extract the gene counts from the overlap
1911
+ df_ = results_df.sort_values(by="Count", ascending=False)
1912
+ ax = plot.plotxy(
1913
+ data=df_.head(n_top),
1914
+ kind_="scatter",
1915
+ x="Count",
1916
+ y="Term",
1917
+ hue="Count",
1918
+ size="Count",
1919
+ sizes=[10,50],
1920
+ palette=palette,
1921
+ legend=None,
1922
+ ax=ax,
1923
+ **kwargs,
1924
+ )
1925
+
1926
+ plot.figsets(ax=ax, **kws_figsets)
1927
+ return ax, df_
1928
+ elif "net" in kind.lower():
1929
+ #! network plot
1930
+ from gseapy import enrichment_map
1931
+ import networkx as nx
1932
+ from matplotlib import cm
1933
+ # try:
1934
+ if cutoff>=1 or cutoff is None:
1935
+ print(f"cutoff is {cutoff} => Without applying filter")
1936
+ nodes, edges = enrichment_map(
1937
+ df=results_df,
1938
+ columns="NOM p-val",
1939
+ cutoff=1.1, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
1940
+ top_term=n_top,
1941
+ )
1942
+ else:
1943
+ cutoff_curr = cutoff
1944
+ step = 0.05
1945
+ cutoff_stop = 1.0
1946
+ while cutoff_curr <= cutoff_stop:
1947
+ try:
1948
+ # return two dataframe
1949
+ nodes, edges = enrichment_map(
1950
+ df=results_df,
1951
+ columns="NOM p-val",
1952
+ cutoff=cutoff_curr, # 0.25 when "FDR q-val"; 0.05 when "Nom p-value"
1953
+ top_term=n_top,
1954
+ )
1955
+
1956
+ if nodes.shape[0] >= n_top:
1957
+ print(f"cutoff={cutoff_curr} done! ")
1958
+ break
1959
+ if cutoff_curr == cutoff_stop:
1960
+ break
1961
+ cutoff_curr += step
1962
+ except Exception as e:
1963
+ cutoff_curr += step
1964
+ print(
1965
+ f"{e}: trying cutoff={cutoff_curr}"
1966
+ )
1967
+
1968
+ print("size: by 'NES'") if size is None else print("")
1969
+ print("linewidth: by 'NES'") if linewidth is None else print("")
1970
+ print("linecolor: by 'NES'") if linecolor is None else print("")
1971
+ print("linealpha: by 'NES'") if linealpha is None else print("")
1972
+ print("facecolor: by 'NES'") if facecolor is None else print("")
1973
+ print("alpha: by '-log10(Adjusted P-value)'") if alpha is None else print("")
1974
+ edges.sort_values(by="jaccard_coef", ascending=False,inplace=True)
1975
+ colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
1976
+ G,ax=plot_ppi(
1977
+ interactions=edges,
1978
+ player1="src_name",
1979
+ player2="targ_name",
1980
+ weight="jaccard_coef",
1981
+ size=[
1982
+ node["NES"] * 300 for _, node in nodes.iterrows()
1983
+ ] if size is None else size, # size nodes by NES
1984
+ facecolor=[colormap(node["NES"]) for _, node in nodes.iterrows()] if facecolor is None else facecolor, # Color by FDR q-val
1985
+ linewidth=[node["NES"] * 300 for _, node in nodes.iterrows()] if linewidth is None else linewidth,
1986
+ linecolor=[node["NES"] * 300 for _, node in nodes.iterrows()] if linecolor is None else linecolor,
1987
+ linealpha=[node["NES"] * 300 for _, node in nodes.iterrows()] if linealpha is None else linealpha,
1988
+ alpha=[node["NES"] * 300 for _, node in nodes.iterrows()] if alpha is None else alpha,
1989
+ **kwargs
1990
+ )
1991
+ # except Exception as e:
1992
+ # print(f"not work {n_top},{e}")
1993
+ return ax, G, nodes, edges
1994
+
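+ # Usage sketch (illustrative): different views of the prerank result from get_prerank().
+ # ax, df_ = plot_prerank(df_prerank, kind="bar", n_top=10)
+ # ax, df_ = plot_prerank(df_prerank, kind="dot", cutoff=0.25)
+ # ax, G, nodes, edges = plot_prerank(df_prerank, kind="network", n_top=15)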
1995
+
1996
+ #! https://string-db.org/help/api/
1997
+
1998
+ import pandas as pd
1999
+ import requests
2000
+ import networkx as nx
2001
+ import matplotlib.pyplot as plt
2002
+ from io import StringIO
2003
+ from py2ls import ips
2004
+
2005
+
2006
+ def get_ppi(
2007
+ target_genes:list,
2008
+ species:int=9606, # "human"
2009
+ ci:float=0.1, # confidence score threshold in 0~1; converted to STRING's required_score (0~1000)
2010
+ max_nodes:int=50,
2011
+ base_url:str="https://string-db.org",
2012
+ gene_mapping_api:str="/api/json/get_string_ids?",
2013
+ interaction_api:str="/api/tsv/network?",
2014
+ ):
2015
+ """
2016
+ Purpose: Retrieves protein-protein interaction data from the STRING database
+ Principle: API-based query to STRINGdb for experimentally validated and predicted interactions
+ Key Operations:
+ • Maps gene symbols to STRING identifiers
+ • Filters by confidence score and species
+ • Returns comprehensive interaction data with multiple evidence types
+ Evidence Scores: Neighborhood, fusion, coexpression, experimental, database, textmining
2020
+
2021
+ Generate a Protein-Protein Interaction (PPI) network using STRINGdb data.
2022
+
2023
+ return:
2024
+ the STRING protein-protein interaction (PPI) data, which contains information about
2025
+ predicted and experimentally validated associations between proteins.
2026
+
2027
+ stringId_A and stringId_B: Unique identifiers for the interacting proteins based on the
2028
+ STRING database.
2029
+ preferredName_A and preferredName_B: Standard gene names for the interacting proteins.
2030
+ ncbiTaxonId: The taxon ID (9606 for humans).
2031
+ score: A combined score reflecting the overall confidence of the interaction, which aggregates different sources of evidence.
2032
+
2033
+ nscore, fscore, pscore, ascore, escore, dscore, tscore: These are sub-scores representing the confidence in the interaction based on various evidence types:
2034
+ - nscore: Neighborhood score, based on genes located near each other in the genome.
2035
+ - fscore: Fusion score, based on gene fusions in other genomes.
2036
+ - pscore: Phylogenetic profile score, based on co-occurrence across different species.
2037
+ - ascore: Coexpression score, reflecting the likelihood of coexpression.
2038
+ - escore: Experimental score, based on experimental evidence.
2039
+ - dscore: Database score, from curated databases.
2040
+ - tscore: Text-mining score, from literature co-occurrence.
2041
+
2042
+ Higher score values (closer to 1) indicate stronger evidence for an interaction.
2043
+ - Combined score: Useful for ranking interactions based on overall confidence. A score >0.7 is typically considered high-confidence.
2044
+ - Sub-scores: Interpret the types of evidence supporting the interaction. For instance:
2045
+ - High ascore indicates strong evidence of coexpression.
2046
+ - High escore suggests experimental validation.
2047
+
2048
+ """
2049
+ print("check api: https://string-db.org/help/api/")
2050
+
2051
+ # convert a species name into its NCBI taxon id
+ if isinstance(species,str):
+ print(f"species: {species}")
+ species=list(get_taxon_id(species).values())[0]
+ print(f"taxon id: {species}")
2056
+
2057
+
2058
+ string_api_url = base_url + gene_mapping_api
2059
+ interaction_api_url = base_url + interaction_api
2060
+ # Map gene symbols to STRING IDs
2061
+ mapped_genes = {}
2062
+ for gene in target_genes:
2063
+ params = {"identifiers": gene, "species": species, "limit": 1}
2064
+ response = requests.get(string_api_url, params=params)
2065
+ if response.status_code == 200:
2066
+ try:
2067
+ json_data = response.json()
2068
+ if json_data:
2069
+ mapped_genes[gene] = json_data[0]["stringId"]
2070
+ except ValueError:
2071
+ print(
2072
+ f"Failed to decode JSON for gene {gene}. Response: {response.text}"
2073
+ )
2074
+ else:
2075
+ print(
2076
+ f"Failed to fetch data for gene {gene}. Status code: {response.status_code}"
2077
+ )
2078
+ if not mapped_genes:
2079
+ print("No mapped genes found in STRING database.")
2080
+ return None
2081
+
2082
+ # Retrieve PPI data from STRING API
2083
+ string_ids = "%0d".join(mapped_genes.values())
2084
+ params = {
2085
+ "identifiers": string_ids,
2086
+ "species": species,
2087
+ "required_score": int(ci * 1000),
2088
+ "limit": max_nodes,
2089
+ }
2090
+ response = requests.get(interaction_api_url, params=params)
2091
+
2092
+ if response.status_code == 200:
2093
+ try:
2094
+ interactions = pd.read_csv(StringIO(response.text), sep="\t")
2095
+ except Exception as e:
2096
+ print("Error reading the interaction data:", e)
2097
+ print("Response content:", response.text)
2098
+ return None
2099
+ else:
2100
+ print(
2101
+ f"Failed to retrieve interaction data. Status code: {response.status_code}"
2102
+ )
2103
+ print("Response content:", response.text)
2104
+ return None
2105
+ display(interactions.head())
2106
+ # Filter interactions by ci score
2107
+ if "score" in interactions.columns:
2108
+ interactions = interactions[interactions["score"] >= ci]
2109
+ if interactions.empty:
2110
+ print("No interactions found with the specified confidence.")
2111
+ return None
2112
+ else:
2113
+ print("The 'score' column is missing from the retrieved data. Unable to filter by confidence interval.")
2114
+ if "fdr" in interactions.columns:
2115
+ interactions=interactions.sort_values(by="fdr",ascending=False)
2116
+ return interactions
2117
+ # * usage
2118
+ # interactions = get_ppi(target_genes, ci=0.0001)
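+ # A slightly fuller sketch (illustrative; the gene list is hypothetical):
+ # target_genes = ["TP53", "BRCA1", "EGFR", "MYC", "AKT1"]
+ # interactions = get_ppi(target_genes, species=9606, ci=0.4, max_nodes=50)
+ # high_conf = interactions[interactions["score"] > 0.7]  # keep only high-confidence edges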
2119
+
2120
+ def plot_ppi(
2121
+ interactions,
2122
+ player1="preferredName_A",
2123
+ player2="preferredName_B",
2124
+ weight="score",
2125
+ n_layers=None, # Number of concentric layers
2126
+ n_rank=[5, 10], # Nodes in each rank for the concentric layout
2127
+ dist_node = 10, # Distance between each rank of circles
2128
+ layout="degree",
2129
+ size=None,#700,
2130
+ sizes=(50,500),# min and max of size
2131
+ facecolor="skyblue",
2132
+ cmap='coolwarm',
2133
+ edgecolor="k",
2134
+ edgelinewidth=1.5,
2135
+ alpha=.5,
2136
+ alphas=(0.1, 1.0),# min and max of alpha
2137
+ marker="o",
2138
+ node_hideticks=True,
2139
+ linecolor="gray",
2140
+ line_cmap='coolwarm',
2141
+ linewidth=1.5,
2142
+ linewidths=(0.5,5),# min and max of linewidth
2143
+ linealpha=1.0,
2144
+ linealphas=(0.1,1.0),# min and max of linealpha
2145
+ linestyle="-",
2146
+ line_arrowstyle='-',
2147
+ fontsize=10,
2148
+ fontcolor="k",
2149
+ ha:str="center",
2150
+ va:str="center",
2151
+ figsize=(12, 10),
2152
+ k_value=0.3,
2153
+ bgcolor="w",
2154
+ dir_save="./ppi_network.html",
2155
+ physics=True,
2156
+ notebook=False,
2157
+ scale=1,
2158
+ ax=None,
2159
+ **kwargs
2160
+ ):
2161
+ """
2162
+ Purpose: Network visualization of protein-protein interactions
2163
+ Principle: NetworkX and PyVis for interactive and static network visualization
2164
+ Layout Options:
+ • Spring: Force-directed layout
+ • Circular: Concentric circles
+ • Degree-based: Nodes positioned by connectivity
+ Customization: Node size/color by centrality, edge thickness by confidence
2170
+
2171
+ Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.
2172
+ """
2173
+ from pyvis.network import Network
2174
+ import networkx as nx
2175
+ from IPython.display import IFrame
2176
+ from matplotlib.colors import Normalize
2177
+ from matplotlib import cm
2178
+ # Check for required columns in the DataFrame
2179
+ for col in [player1, player2, weight]:
2180
+ if col not in interactions.columns:
2181
+ raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
2182
+ interactions.sort_values(by=[weight], inplace=True)
2183
+ # Initialize Pyvis network
2184
+ net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
2185
+ net.force_atlas_2based(
2186
+ gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
2187
+ )
2188
+ net.toggle_physics(physics)
2189
+
2190
+ kws_figsets = {}
2191
+ for k_arg, v_arg in kwargs.items():
2192
+ if "figset" in k_arg:
2193
+ kws_figsets = v_arg
2194
+ kwargs.pop(k_arg, None)
2195
+ break
2196
+
2197
+ # Create a NetworkX graph from the interaction data
2198
+ G = nx.Graph()
2199
+ for _, row in interactions.iterrows():
2200
+ G.add_edge(row[player1], row[player2], weight=row[weight])
2201
+ # G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)
2202
+
2203
+
2204
+ # Calculate node degrees
2205
+ degrees = dict(G.degree())
2206
+ norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
2207
+ colormap = cm.get_cmap(cmap) # Get the 'coolwarm' colormap
2208
+
2209
+ if not ips.isa(facecolor, 'color'):
2210
+ print("facecolor: based on degrees")
2211
+ facecolor = [colormap(norm(deg)) for deg in degrees.values()] # Use colormap
2212
+ num_nodes = G.number_of_nodes()
2213
+ #* size
2214
+ # Set properties based on degrees
2215
+ if not isinstance(size, (int,float,list)):
2216
+ print("size: based on degrees")
2217
+ size = [deg * 50 for deg in degrees.values()] # Scale sizes
2218
+ size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
2219
+ if isinstance(size, list) and len(ips.flatten(size,verbose=False))!=1:
2220
+ # Normalize sizes
2221
+ min_size, max_size = sizes # Use sizes tuple for min and max values
2222
+ min_degree, max_degree = min(size), max(size)
2223
+ if max_degree > min_degree: # Avoid division by zero
2224
+ size = [
2225
+ min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
2226
+ for sz in size
2227
+ ]
2228
+ else:
2229
+ # If all values are the same, set them to a default of the midpoint
2230
+ size = [(min_size + max_size) / 2] * len(size)
2231
+
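+ # Worked example of the rescale above (illustrative): with sizes=(50, 500) and degree-based
+ # raw sizes [100, 250, 500], each value v maps to 50 + (500 - 50) * (v - 100) / (500 - 100),
+ # i.e. [50.0, 218.75, 500.0].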
2232
+ #* facecolor
2233
+ facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
2234
+ # * facealpha
2235
+ if isinstance(alpha, list):
2236
+ alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
2237
+ min_alphas, max_alphas = alphas # Use alphas tuple for min and max values
2238
+ if len(alpha) > 0:
2239
+ # Normalize alpha based on the specified min and max
2240
+ min_alpha, max_alpha = min(alpha), max(alpha)
2241
+ if max_alpha > min_alpha: # Avoid division by zero
2242
+ alpha = [
2243
+ min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
2244
+ for ea in alpha
2245
+ ]
2246
+ else:
2247
+ # If all alpha values are the same, set them to the average of min and max
2248
+ alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
2249
+ else:
2250
+ # Default to a full opacity if no edges are provided
2251
+ alpha = [1.0] * num_nodes
2252
+ else:
2253
+ # If alpha is a single value, convert it to a list and normalize it
2254
+ alpha = [alpha] * num_nodes # Adjust based on alphas
2255
+
2256
+ for i, node in enumerate(G.nodes()):
2257
+ net.add_node(
2258
+ node,
2259
+ label=node,
2260
+ size=size[i],
2261
+ color=facecolor[i],
2262
+ alpha=alpha[i],
2263
+ font={"size": fontsize, "color": fontcolor},
2264
+ )
2265
+ print(f'nodes number: {i+1}')
2266
+
2267
+ for edge in G.edges(data=True):
2268
+ net.add_edge(
2269
+ edge[0],
2270
+ edge[1],
2271
+ weight=edge[2]["weight"],
2272
+ color=edgecolor,
2273
+ width=edgelinewidth * edge[2]["weight"],
2274
+ )
2275
+
2276
+ layouts = [
2277
+ "spring",
2278
+ "circular",
2279
+ "kamada_kawai",
2280
+ "random",
2281
+ "shell",
2282
+ "planar",
2283
+ "spiral",
2284
+ "degree"
2285
+ ]
2286
+ layout = ips.strcmp(layout, layouts)[0]
2287
+ print(f"layout:{layout}, or select one in {layouts}")
2288
+
2289
+ # Choose layout
2290
+ if layout == "spring":
2291
+ pos = nx.spring_layout(G, k=k_value)
2292
+ elif layout == "circular":
2293
+ pos = nx.circular_layout(G)
2294
+ elif layout == "kamada_kawai":
2295
+ pos = nx.kamada_kawai_layout(G)
2296
+ elif layout == "spectral":
2297
+ pos = nx.spectral_layout(G)
2298
+ elif layout == "random":
2299
+ pos = nx.random_layout(G)
2300
+ elif layout == "shell":
2301
+ pos = nx.shell_layout(G)
2302
+ elif layout == "planar":
2303
+ if nx.check_planarity(G)[0]:
2304
+ pos = nx.planar_layout(G)
2305
+ else:
2306
+ print("Graph is not planar; switching to spring layout.")
2307
+ pos = nx.spring_layout(G, k=k_value)
2308
+ elif layout == "spiral":
2309
+ pos = nx.spiral_layout(G)
2310
+ elif layout=='degree':
2311
+ # Calculate node degrees and sort nodes by degree
2312
+ degrees = dict(G.degree())
2313
+ sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
2314
+ norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
2315
+ colormap = cm.get_cmap(cmap)
2316
+
2317
+ # Create positions for concentric circles based on n_layers and n_rank
2318
+ pos = {}
2319
+ n_layers=len(n_rank)+1 if n_layers is None else n_layers
2320
+ for rank_index in range(n_layers):
2321
+ if rank_index < len(n_rank):
2322
+ nodes_per_rank = n_rank[rank_index]
2323
+ rank_nodes = sorted_nodes[sum(n_rank[:rank_index]): sum(n_rank[:rank_index + 1])]
2324
+ else:
2325
+ # shuffle the remaining nodes into a random order
2326
+ remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]):]
2327
+ random_indices = np.random.permutation(len(remaining_nodes))
2328
+ rank_nodes = [remaining_nodes[i] for i in random_indices]
2329
+
2330
+ radius = (rank_index + 1) * dist_node # Radius for this rank
2331
+
2332
+ # Arrange nodes in a circle for the current rank
2333
+ for i, (node, degree) in enumerate(rank_nodes):
2334
+ angle = (i / len(rank_nodes)) * 2 * np.pi # Distribute around circle
2335
+ pos[node] = (radius * np.cos(angle), radius * np.sin(angle))
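+ # Illustrative note: with the default n_rank=[5, 10], the 5 highest-degree nodes form the
+ # innermost ring, the next 10 the second ring, and any remaining nodes a shuffled outer ring.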
2336
+
2337
+ else:
2338
+ print(f"Unknown layout '{layout}', defaulting to 'spring',or可以用这些: {layouts}")
2339
+ pos = nx.spring_layout(G, k=k_value)
2340
+
2341
+ for node, (x, y) in pos.items():
2342
+ net.get_node(node)["x"] = x * scale
2343
+ net.get_node(node)["y"] = y * scale
2344
+
2345
+ # If ax is None, use plt.gca()
2346
+ if ax is None:
2347
+ fig, ax = plt.subplots(1,1,figsize=figsize)
2348
+
2349
+ # Draw nodes, edges, and labels with customization options
2350
+ nx.draw_networkx_nodes(
2351
+ G,
2352
+ pos,
2353
+ ax=ax,
2354
+ node_size=size,
2355
+ node_color=facecolor,
2356
+ linewidths=edgelinewidth,
2357
+ edgecolors=edgecolor,
2358
+ alpha=alpha,
2359
+ hide_ticks=node_hideticks,
2360
+ node_shape=marker
2361
+ )
2362
+
2363
+ #* linewidth
2364
+ if not isinstance(linewidth, list):
2365
+ linewidth = [linewidth] * G.number_of_edges()
2366
+ else:
2367
+ linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
2368
+ # Normalize linewidth if it is a list
2369
+ if isinstance(linewidth, list):
2370
+ min_linewidth, max_linewidth = min(linewidth), max(linewidth)
2371
+ vmin, vmax = linewidths # Use linewidths tuple for min and max values
2372
+ if max_linewidth > min_linewidth: # Avoid division by zero
2373
+ # Scale between vmin and vmax
2374
+ linewidth = [
2375
+ vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
2376
+ for lw in linewidth
2377
+ ]
2378
+ else:
2379
+ # If all values are the same, set them to a default of the midpoint
2380
+ linewidth = [(vmin + vmax) / 2] * len(linewidth)
2381
+ else:
2382
+ # If linewidth is a single value, convert it to a list of that value
2383
+ linewidth = [linewidth] * G.number_of_edges()
2384
+ #* linecolor
2385
+ if not isinstance(linecolor, str):
2386
+ weights = [G[u][v]["weight"] for u, v in G.edges()]
2387
+ norm = Normalize(vmin=min(weights), vmax=max(weights))
2388
+ colormap = cm.get_cmap(line_cmap)
2389
+ linecolor = [colormap(norm(weight)) for weight in weights]
2390
+ else:
2391
+ linecolor = [linecolor] * G.number_of_edges()
2392
+
2393
+ # * linealpha
2394
+ if isinstance(linealpha, list):
2395
+ linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
2396
+ min_alpha, max_alpha = linealphas # Use linealphas tuple for min and max values
2397
+ if len(linealpha) > 0:
2398
+ min_linealpha, max_linealpha = min(linealpha), max(linealpha)
2399
+ if max_linealpha > min_linealpha: # Avoid division by zero
2400
+ linealpha = [
2401
+ min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
2402
+ for ea in linealpha
2403
+ ]
2404
+ else:
2405
+ linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
2406
+ else:
2407
+ linealpha = [1.0] * G.number_of_edges() # fall back to fully opaque edges if the values are unusable
2408
+ else:
2409
+ linealpha = [linealpha] * G.number_of_edges() # Convert to list if single value
2410
+ nx.draw_networkx_edges(
2411
+ G,
2412
+ pos,
2413
+ ax=ax,
2414
+ edge_color=linecolor,
2415
+ width=linewidth,
2416
+ style=linestyle,
2417
+ arrowstyle=line_arrowstyle,
2418
+ alpha=linealpha
2419
+ )
2420
+
2421
+ nx.draw_networkx_labels(
2422
+ G, pos, ax=ax, font_size=fontsize, font_color=fontcolor,horizontalalignment=ha,verticalalignment=va
2423
+ )
2424
+ plot.figsets(ax=ax,**kws_figsets)
2425
+ ax.axis("off")
2426
+ if dir_save:
2427
+ if not os.path.basename(dir_save):
2428
+ dir_save="_.html"
2429
+ net.write_html(dir_save)
2430
+ nx.write_graphml(G, dir_save.replace(".html",".graphml")) # Export to GraphML
2431
+ print(f"could be edited in Cytoscape \n{dir_save.replace(".html",".graphml")}")
2432
+ ips.figsave(dir_save.replace(".html",".pdf"))
2433
+ return G,ax
2434
+
2435
+
2436
+ # * usage:
2437
+ # G, ax = bio.plot_ppi(
2438
+ # interactions,
2439
+ # player1="preferredName_A",
2440
+ # player2="preferredName_B",
2441
+ # weight="score",
2442
+ # # size="auto",
2443
+ # # size=interactions["score"].tolist(),
2444
+ # # layout="circ",
2445
+ # n_rank=[5, 10, 15],
2446
+ # dist_node=100,
2447
+ # alpha=0.6,
2448
+ # linecolor="0.8",
2449
+ # linewidth=1,
2450
+ # figsize=(8, 8.5),
2451
+ # cmap="jet",
2452
+ # edgelinewidth=0.5,
2453
+ # # facecolor="#FF5F57",
2454
+ # fontsize=10,
2455
+ # # fontcolor="b",
2456
+ # # edgecolor="r",
2457
+ # # scale=100,
2458
+ # # physics=False,
2459
+ # figsets=dict(title="ppi networks"),
2460
+ # )
2461
+ # figsave("./ppi_network.pdf")
2462
+
2463
+ def top_ppi(interactions, n_top=10):
2464
+ """
2465
+ Purpose: Identifies key proteins in interaction networks using centrality measures
2466
+ Centrality Metrics:
+ • Degree: Number of direct connections
+ • Betweenness: Bridge positions in the network
+ Usage: Prioritization of biologically important proteins
2471
+
2472
+ Analyzes protein-protein interactions (PPIs) to identify key proteins based on
2473
+ degree and betweenness centrality.
2474
+
2475
+ Parameters:
2476
+ interactions (pd.DataFrame): DataFrame containing PPI data with columns
2477
+ ['preferredName_A', 'preferredName_B', 'score'].
2478
+
2479
+ Returns:
2480
+ dict: A dictionary containing the top key proteins by degree and betweenness centrality.
2481
+ """
2482
+
2483
+ # Create a NetworkX graph from the interaction data
2484
+ G = nx.Graph()
2485
+ for _, row in interactions.iterrows():
2486
+ G.add_edge(row["preferredName_A"], row["preferredName_B"], weight=row["score"])
2487
+
2488
+ # Calculate Degree Centrality
2489
+ degree_centrality = G.degree()
2490
+ key_proteins_degree = sorted(degree_centrality, key=lambda x: x[1], reverse=True)
2491
+
2492
+ # Calculate Betweenness Centrality
2493
+ betweenness_centrality = nx.betweenness_centrality(G)
2494
+ key_proteins_betweenness = sorted(
2495
+ betweenness_centrality.items(), key=lambda x: x[1], reverse=True
2496
+ )
2497
+ print(
2498
+ {
2499
+ "Top 10 Key Proteins by Degree Centrality": key_proteins_degree[:10],
2500
+ "Top 10 Key Proteins by Betweenness Centrality": key_proteins_betweenness[
2501
+ :10
2502
+ ],
2503
+ }
2504
+ )
2505
+ # Return the top n_top key proteins
2506
+ if n_top == "all":
2507
+ return key_proteins_degree, key_proteins_betweenness
2508
+ else:
2509
+ return key_proteins_degree[:n_top], key_proteins_betweenness[:n_top]
2510
+
2511
+
2512
+ # * usage: top_ppi(interactions)
2513
+ # top_ppi(interactions, n_top="all")
2514
+ # top_ppi(interactions, n_top=10)
2515
+
2516
+
2517
+
2518
+ species_dict = {
2519
+ "Human": "Homo sapiens",
2520
+ "House mouse": "Mus musculus",
2521
+ "Zebrafish": "Danio rerio",
2522
+ "Norway rat": "Rattus norvegicus",
2523
+ "Fruit fly": "Drosophila melanogaster",
2524
+ "Baker's yeast": "Saccharomyces cerevisiae",
2525
+ "Nematode": "Caenorhabditis elegans",
2526
+ "Chicken": "Gallus gallus",
2527
+ "Cattle": "Bos taurus",
2528
+ "Rice": "Oryza sativa",
2529
+ "Thale cress": "Arabidopsis thaliana",
2530
+ "Guinea pig": "Cavia porcellus",
2531
+ "Domestic dog": "Canis lupus familiaris",
2532
+ "Domestic cat": "Felis catus",
2533
+ "Horse": "Equus caballus",
2534
+ "Domestic pig": "Sus scrofa",
2535
+ "African clawed frog": "Xenopus laevis",
2536
+ "Great white shark": "Carcharodon carcharias",
2537
+ "Common chimpanzee": "Pan troglodytes",
2538
+ "Rhesus macaque": "Macaca mulatta",
2539
+ "Water buffalo": "Bubalus bubalis",
2540
+ "Lettuce": "Lactuca sativa",
2541
+ "Tomato": "Solanum lycopersicum",
2542
+ "Maize": "Zea mays",
2543
+ "Cucumber": "Cucumis sativus",
2544
+ "Common grape vine": "Vitis vinifera",
2545
+ "Scots pine": "Pinus sylvestris",
2546
+ }
2547
+
2548
+
2549
+ def get_taxon_id(species_list):
2550
+ """
2551
+ Purpose: Converts species names to NCBI taxonomy IDs
2552
+ Principle: BioPython Entrez queries to taxonomy database
2553
+ Supported: 25+ common model organisms
2554
+ Convert species names to their corresponding taxon ID codes.
2555
+
2556
+ Parameters:
2557
+ - species_list: List of species names (strings).
2558
+
2559
+ Returns:
2560
+ - dict: A dictionary with species names as keys and their taxon IDs as values.
2561
+ """
2562
+ from Bio import Entrez
2563
+
2564
+ if not isinstance(species_list, list):
2565
+ species_list = [species_list]
2566
+ species_list = [
2567
+ species_dict[ips.strcmp(i, ips.flatten(list(species_dict.keys())))[0]]
2568
+ for i in species_list
2569
+ ]
2570
+ taxon_dict = {}
2571
+
2572
+ for species in species_list:
2573
+ try:
2574
+ search_handle = Entrez.esearch(db="taxonomy", term=species)
2575
+ search_results = Entrez.read(search_handle)
2576
+ search_handle.close()
2577
+
2578
+ # Get the taxon ID
2579
+ if search_results["IdList"]:
2580
+ taxon_id = search_results["IdList"][0]
2581
+ taxon_dict[species] = int(taxon_id)
2582
+ else:
2583
+ taxon_dict[species] = None # Not found
2584
+ except Exception as e:
2585
+ print(f"Error occurred for species '{species}': {e}")
2586
+ taxon_dict[species] = None # Error in processing
2587
+ return taxon_dict
2588
+
2589
+
2590
+ # # * usage: get_taxon_id("human")
2591
+ # species_names = ["human", "nouse", "rat"]
2592
+ # taxon_ids = get_taxon_id(species_names)
2593
+ # print(taxon_ids)
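+ # Note (local-setup assumption, not handled by this module): NCBI expects Entrez users to
+ # identify themselves, so set an email before querying, e.g.
+ # from Bio import Entrez
+ # Entrez.email = "you@example.com"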
2594
+
2595
+