py2ls 0.1.10.12__py3-none-any.whl → 0.2.7.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of py2ls might be problematic.
- py2ls/.DS_Store +0 -0
- py2ls/.git/.DS_Store +0 -0
- py2ls/.git/index +0 -0
- py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
- py2ls/.git/objects/.DS_Store +0 -0
- py2ls/.git/refs/.DS_Store +0 -0
- py2ls/ImageLoader.py +621 -0
- py2ls/__init__.py +7 -5
- py2ls/apptainer2ls.py +3940 -0
- py2ls/batman.py +164 -42
- py2ls/bio.py +2595 -0
- py2ls/cell_image_clf.py +1632 -0
- py2ls/container2ls.py +4635 -0
- py2ls/corr.py +475 -0
- py2ls/data/.DS_Store +0 -0
- py2ls/data/email/email_html_template.html +88 -0
- py2ls/data/hyper_param_autogluon_zeroshot2024.json +2383 -0
- py2ls/data/hyper_param_tabrepo_2024.py +1753 -0
- py2ls/data/mygenes_fields_241022.txt +355 -0
- py2ls/data/re_common_pattern.json +173 -0
- py2ls/data/sns_info.json +74 -0
- py2ls/data/styles/.DS_Store +0 -0
- py2ls/data/styles/example/.DS_Store +0 -0
- py2ls/data/styles/stylelib/.DS_Store +0 -0
- py2ls/data/styles/stylelib/grid.mplstyle +15 -0
- py2ls/data/styles/stylelib/high-contrast.mplstyle +6 -0
- py2ls/data/styles/stylelib/high-vis.mplstyle +4 -0
- py2ls/data/styles/stylelib/ieee.mplstyle +15 -0
- py2ls/data/styles/stylelib/light.mplstyl +6 -0
- py2ls/data/styles/stylelib/muted.mplstyle +6 -0
- py2ls/data/styles/stylelib/nature-reviews-latex.mplstyle +616 -0
- py2ls/data/styles/stylelib/nature-reviews.mplstyle +616 -0
- py2ls/data/styles/stylelib/nature.mplstyle +31 -0
- py2ls/data/styles/stylelib/no-latex.mplstyle +10 -0
- py2ls/data/styles/stylelib/notebook.mplstyle +36 -0
- py2ls/data/styles/stylelib/paper.mplstyle +290 -0
- py2ls/data/styles/stylelib/paper2.mplstyle +305 -0
- py2ls/data/styles/stylelib/retro.mplstyle +4 -0
- py2ls/data/styles/stylelib/sans.mplstyle +10 -0
- py2ls/data/styles/stylelib/scatter.mplstyle +7 -0
- py2ls/data/styles/stylelib/science.mplstyle +48 -0
- py2ls/data/styles/stylelib/std-colors.mplstyle +4 -0
- py2ls/data/styles/stylelib/vibrant.mplstyle +6 -0
- py2ls/data/tiles.csv +146 -0
- py2ls/data/usages_pd.json +1417 -0
- py2ls/data/usages_sns.json +31 -0
- py2ls/docker2ls.py +5446 -0
- py2ls/ec2ls.py +61 -0
- py2ls/fetch_update.py +145 -0
- py2ls/ich2ls.py +1955 -296
- py2ls/im2.py +8242 -0
- py2ls/image_ml2ls.py +2100 -0
- py2ls/ips.py +33909 -3418
- py2ls/ml2ls.py +7700 -0
- py2ls/mol.py +289 -0
- py2ls/mount2ls.py +1307 -0
- py2ls/netfinder.py +873 -351
- py2ls/nl2ls.py +283 -0
- py2ls/ocr.py +1581 -458
- py2ls/plot.py +10394 -314
- py2ls/rna2ls.py +311 -0
- py2ls/ssh2ls.md +456 -0
- py2ls/ssh2ls.py +5933 -0
- py2ls/ssh2ls_v01.py +2204 -0
- py2ls/stats.py +66 -172
- py2ls/temp20251124.py +509 -0
- py2ls/translator.py +2 -0
- py2ls/utils/decorators.py +3564 -0
- py2ls/utils_bio.py +3453 -0
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/METADATA +113 -224
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/RECORD +72 -16
- {py2ls-0.1.10.12.dist-info → py2ls-0.2.7.10.dist-info}/WHEEL +0 -0
py2ls/bio.py
ADDED
@@ -0,0 +1,2595 @@
# #======== 1. GEO data Processing Pipeline ======
# # Load and integrate multiple datasets
# geo_data = load_geo(datasets, dir_save)
# complete_data = get_data(geo_data, dataset)

# # Quality control and normalization
# data_type = get_data_type(complete_data)
# if data_type == "read counts":
#     normalized_data = counts2expression(complete_data, method='TMM')

# # Batch correction for multiple datasets
# corrected_data = batch_effect([data1, data2, data3], datasets)

# #======== 2. Differential Expression + Enrichment Pipeline ======
# # DESeq2 analysis
# dds, diff_results, stats, norm_counts = counts_deseq(counts, metadata)

# # Enrichment analysis on significant genes
# sig_genes = diff_results[diff_results.padj < 0.05].gene.tolist()
# enrichment_results = get_enrichr(sig_genes, 'KEGG_2021_Human')

# # Visualization
# plot_enrichr(enrichment_results, kind='dotplot')

# #======== 3. Network Analysis Pipeline ======
# # PPI network construction
# interactions = get_ppi(target_genes, species=9606, ci=0.7)

# # Network visualization and analysis
# G, ax = plot_ppi(interactions, layout='degree')
# key_proteins = top_ppi(interactions, n_top=10)

# #======== Dependencies ======
# GEOparse: GEO data access
# gseapy: Enrichment analysis
# pydeseq2: Differential expression
# rnanorm: Count normalization
# mygene: Gene identifier conversion
# networkx: Network analysis
# pyvis: Interactive network visualization

# This toolbox provides end-to-end capabilities for genomics data analysis, from raw
# data loading through advanced network biology, with particular strengths in
# multi-dataset integration and interactive visualization.

import GEOparse
import gseapy as gp
from typing import Union
import pandas as pd
import numpy as np
import os
import logging

from . import ips
from . import plot
import matplotlib.pyplot as plt

def load_geo(
    datasets: Union[list, str] = ["GSE00000", "GSE00001"],
    dir_save: str = "./datasets",
    verbose=False,
) -> dict:
    """
    Purpose: Downloads and loads GEO datasets from the NCBI database.
    Principle: Uses the GEOparse library to fetch and parse GEO SOFT files. Checks the local cache first to avoid redundant downloads.
    Key Operations:
    * Verifies whether datasets already exist in the specified directory
    * Downloads missing datasets via the GEOparse API
    * Returns a dictionary of GEO objects for further processing

    Parameters:
        datasets (list): List of GEO dataset IDs to download.
        dir_save (str): Directory where datasets will be stored.

    Returns:
        dict: A dictionary containing the GEO objects for each dataset.
    """
    use_str = """
    get_meta(geo: dict, dataset: str = "GSE25097")
    get_expression_data(geo: dict, dataset: str = "GSE25097")
    get_probe(geo: dict, dataset: str = "GSE25097", platform_id: str = "GPL10687")
    get_data(geo: dict, dataset: str = "GSE25097")
    """
    print(f"you could go further with: \n{use_str}")
    if not verbose:
        logging.getLogger("GEOparse").setLevel(logging.WARNING)
    else:
        logging.getLogger("GEOparse").setLevel(logging.DEBUG)
    # Create the directory if it doesn't exist
    if not os.path.exists(dir_save):
        os.makedirs(dir_save)
        print(f"Created directory: {dir_save}")
    if isinstance(datasets, str):
        datasets = [datasets]
    geo_data = {}
    for dataset in datasets:
        # Check if the dataset file already exists in the directory
        dataset_file = os.path.join(dir_save, f"{dataset}_family.soft.gz")

        if not os.path.isfile(dataset_file):
            print(f"\n\nDataset {dataset} not found locally. Downloading...")
            geo = GEOparse.get_GEO(geo=dataset, destdir=dir_save)
        else:
            print(f"\n\nDataset {dataset} already exists locally. Loading...")
            geo = GEOparse.get_GEO(filepath=dataset_file)

        geo_data[dataset] = geo

    return geo_data

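# Usage sketch (requires network access to NCBI GEO; "GSE25097" is the
# accession used in examples throughout this module, and the attribute
# access below follows GEOparse's GSE objects):
# geo = load_geo(datasets=["GSE25097"], dir_save="./datasets", verbose=False)
# geo["GSE25097"].metadata["title"]  # quick sanity check on the parsed object
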
def get_meta(geo: dict, dataset: str = "GSE25097", verbose=True) -> pd.DataFrame:
    """
    Purpose: Extracts comprehensive metadata from GEO datasets.
    Principle: Parses the hierarchical structure of GEO objects (study, platform, and sample metadata) and flattens it into a DataFrame.
    Key Operations:
        Combines study-level, platform-level, and sample-level metadata
        Handles list-type metadata values by concatenation
        Removes irrelevant columns (contact info, technical details)
    Output: DataFrame with samples as rows and all available metadata as columns

    df_meta = get_meta(geo, dataset="GSE25097")
    Extracts metadata from a specific GEO dataset and returns it as a DataFrame.
    The function dynamically extracts all available metadata fields from the given dataset.

    Parameters:
        geo (dict): A dictionary containing the GEO objects for different datasets.
        dataset (str): The name of the dataset to extract metadata from (default is "GSE25097").

    Returns:
        pd.DataFrame: A DataFrame containing structured metadata from the specified GEO dataset.
    """
    # Check if the dataset is available in the provided GEO dictionary
    if dataset not in geo:
        raise ValueError(f"Dataset '{dataset}' not found in the provided GEO data.")

    # List to store metadata dictionaries
    meta_list = []

    # Extract the GEO object for the specified dataset
    geo_obj = geo[dataset]

    # Overall study metadata
    study_meta = geo_obj.metadata
    study_metadata = {key: study_meta[key] for key in study_meta.keys()}

    # Platform metadata
    for platform_id, platform in geo_obj.gpls.items():
        platform_metadata = {
            key: platform.metadata[key] for key in platform.metadata.keys()
        }
        platform_metadata["platform_id"] = platform_id  # Include platform ID

        # Sample metadata
        for sample_id, sample in geo_obj.gsms.items():
            sample_metadata = {
                key: sample.metadata[key] for key in sample.metadata.keys()
            }
            sample_metadata["sample_id"] = sample_id  # Include sample ID
            # Combine all metadata into a single dictionary
            combined_meta = {
                "dataset": dataset,
                **{
                    k: (
                        v[0]
                        if isinstance(v, list) and len(v) == 1
                        else ", ".join(map(str, v))
                    )
                    for k, v in study_metadata.items()
                },  # Flatten study metadata
                **platform_metadata,  # Unpack platform metadata
                **{
                    k: (
                        v[0]
                        if isinstance(v, list) and len(v) == 1
                        else "".join(map(str, v))
                    )
                    for k, v in sample_metadata.items()
                },  # Flatten sample metadata
            }

            # Append the combined metadata to the list
            meta_list.append(combined_meta)

    # Convert the list of dictionaries to a DataFrame
    meta_df = pd.DataFrame(meta_list)
    col_rm = [
        "channel_count",
        "contact_web_link",
        "contact_address",
        "contact_city",
        "contact_country",
        "contact_department",
        "contact_email",
        "contact_institute",
        "contact_laboratory",
        "contact_name",
        "contact_phone",
        "contact_state",
        "contact_zip/postal_code",
        "contributor",
        "manufacture_protocol",
        "taxid",
        "web_link",
    ]
    # remove irrelevant columns
    meta_df = meta_df.drop(columns=[col for col in col_rm if col in meta_df.columns])
    if verbose:
        print(
            f"Meta info columns for dataset '{dataset}': \n{sorted(meta_df.columns.tolist())}"
        )
        display(meta_df[:1].T)
    return meta_df

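# The flattening rule above, illustrated with toy values: single-element lists
# become scalars; longer study-level lists are joined with ", ", while longer
# sample-level lists are concatenated directly:
#   {"title": ["tumor sample 1"]}                      -> "tumor sample 1"
#   study:  {"pubmed_id": ["123", "456"]}              -> "123, 456"
#   sample: {"characteristics_ch1": ["a: x", "b: y"]}  -> "a: xb: y"
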
def get_probe(
    geo: dict, dataset: str = "GSE25097", platform_id: str = None, verbose=True
):
    """
    Purpose: Retrieves probe annotation information from GEO platforms.
    Principle: Accesses platform annotation tables containing gene symbols, IDs, and probe information.
    Key Operations:
        Automatically detects platform IDs from metadata
        Handles multiple platforms within a single dataset
        Provides direct links to NCBI platform pages for manual verification

    df_probe = get_probe(geo, dataset="GSE25097", platform_id="GPL10687")
    """
    # try to find the platform_id from the metadata
    if platform_id is None:
        df_meta = get_meta(geo=geo, dataset=dataset, verbose=False)
        platform_id = df_meta["platform_id"].unique().tolist()
        print(f"Platform: {platform_id}")
    if isinstance(platform_id, str):
        platform_id = [platform_id]  # so that platform_id[0] below is an ID, not a character
    if len(platform_id) > 1:
        # NOTE: even when several platforms are present, only the first one is used
        df_probe = geo[dataset].gpls[platform_id[0]].table
        # df_probe = pd.DataFrame()
        # # Iterate over each platform ID and collect the probe tables
        # for platform_id_ in platform_id:
        #     if platform_id_ in geo[dataset].gpls:
        #         df_probe_ = geo[dataset].gpls[platform_id_].table
        #         if not df_probe_.empty:
        #             df_probe = pd.concat([df_probe, df_probe_])
        #         else:
        #             print(f"Warning: Probe table for platform {platform_id_} is empty.")
        #     else:
        #         print(f"Warning: Platform ID {platform_id_} not found in dataset {dataset}.")
    else:
        df_probe = geo[dataset].gpls[platform_id[0]].table

    if df_probe.empty:
        print(
            f"Warning: cannot find the probe info; check whether it is provided in a separate file"
        )
        display(f"🔗: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc={platform_id}")
        return get_meta(geo, dataset, verbose=verbose)
    if verbose:
        print(f"columns in the probe table: \n{sorted(df_probe.columns.tolist())}")
    return df_probe

def get_expression_data(geo: dict, dataset: str = "GSE25097") -> pd.DataFrame:
    """
    Purpose: Extracts the expression matrix from GEO datasets.
    Principle: Pivots sample tables to create a gene expression matrix.
    Key Operations:
        Handles both pre-pivoted data and individual sample tables
        Maintains sample IDs as columns/rows appropriately
    Output: DataFrame with expression values

    df_expression = get_expression_data(geo, dataset="GSE25097")
    Contains only the expression values; probe annotations and other metadata are not included.

    Extracts expression values from GEO datasets and returns them as a DataFrame.

    Parameters:
        geo (dict): A dictionary containing GEO objects for each dataset.

    Returns:
        pd.DataFrame: A DataFrame containing expression data from the GEO datasets.
    """
    try:
        expression_values = geo[dataset].pivot_samples("VALUE")
    except Exception:
        # fall back to the per-sample tables when pivoting fails
        for sample_id, sample in geo[dataset].gsms.items():
            if hasattr(sample, "table"):
                expression_values = (
                    sample.table.T
                )  # Transpose for easier DataFrame creation
                expression_values["dataset"] = dataset
                expression_values["sample_id"] = sample_id
    return expression_values

def get_data(geo: dict, dataset: str = "GSE25097", verbose=False):
    """
    Purpose: Comprehensive data integration - merges expression data with probe annotations and metadata.
    Principle: Performs multi-level data integration using pandas merge operations.
    Key Operations:
        • Merges probe annotations with expression data
        • Transposes the expression matrix to samples-as-rows format
        • Integrates metadata using sample IDs
        • Automatically detects and normalizes raw counts data
    Output: Complete dataset ready for analysis
    """
    print(f"\n\ndataset: {dataset}\n")
    # get probe info
    df_probe = get_probe(geo, dataset=dataset, verbose=False)
    # get expression values
    df_expression = get_expression_data(geo, dataset=dataset)
    if not df_expression.select_dtypes(include=["number"]).empty:
        # if the data are raw read counts, normalize them with TMM
        if "counts" in get_data_type(df_expression):
            try:
                df_expression = counts2expression(df_expression.T).T
                print(
                    f"{dataset}'s type is raw read counts, normalized (transformed) via 'TMM'"
                )
            except Exception as e:
                print("raw counts data")
    if any([df_probe.empty, df_expression.empty]):
        print(
            f"got empty values; check the probe info, which may be provided in a separate file"
        )
        return get_meta(geo, dataset, verbose=True)
    print(
        f"\n\tdf_expression.shape: {df_expression.shape} \n\tdf_probe.shape: {df_probe.shape}"
    )
    df_exp = pd.merge(
        df_probe,
        df_expression,
        left_on=df_probe.columns.tolist()[0],
        right_index=True,
        how="outer",
    )

    # get meta info
    df_meta = get_meta(geo, dataset=dataset, verbose=False)
    col_rm = [
        "channel_count", "contact_web_link", "contact_address", "contact_city", "contact_country", "contact_department",
        "contact_email", "contact_institute", "contact_laboratory", "contact_name", "contact_phone", "contact_state",
        "contact_zip/postal_code", "contributor", "manufacture_protocol", "taxid", "web_link",
    ]
    # remove irrelevant columns
    df_meta = df_meta.drop(columns=[col for col in col_rm if col in df_meta.columns])
    # sort columns
    df_meta = df_meta.reindex(sorted(df_meta.columns), axis=1)
    # find a proper column
    col_sample_id = ips.strcmp("sample_id", df_meta.columns.tolist())[0]
    df_meta.set_index(col_sample_id, inplace=True)  # set sample_id as index

    col_gene_symbol = ips.strcmp("GeneSymbol", df_exp.columns.tolist())[0]
    # select the 'GSM' columns
    col_gsm = df_exp.columns[df_exp.columns.str.startswith("GSM")].tolist()
    df_exp.set_index(col_gene_symbol, inplace=True)
    df_exp = df_exp[col_gsm].T  # transpose, so that meta info can be added

    df_merged = ips.df_merge(df_meta, df_exp, use_index=True)

    print(
        f"\ndataset:'{dataset}' n_sample = {df_merged.shape[0]}, n_gene={df_exp.shape[1]}"
    )
    if verbose:
        display(df_merged.sample(5))
    return df_merged

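# Usage sketch chaining the loaders above (same accession as in the other
# examples; the resulting column set depends on the platform annotation):
# geo = load_geo(["GSE25097"], dir_save="./datasets")
# df_all = get_data(geo, dataset="GSE25097", verbose=False)
# df_all.shape  # samples x (metadata columns + gene columns)
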
def get_data_type(data: pd.DataFrame) -> str:
    """
    Purpose: Automatically determines the data type (raw counts vs. normalized expression).
    Principle: Analyzes the numerical characteristics of the expression data.
    Key Operations:
        Checks data types (integers vs. floats)
        Examines value ranges and distributions
        Uses thresholds to classify as counts (max > 10,000) or normalized (max < 1,000)

    Determine the type of data: 'read counts' or 'normalized expression data'.
    usage:
        get_data_type(df_counts)
    """
    numeric_data = data.select_dtypes(include=["number"])
    if numeric_data.empty:
        raise ValueError(f"No numeric data found; please convert the data first")
    # Check if the data contains only integers
    if numeric_data.apply(lambda x: x.dtype == "int").all():
        # Check for values typically found in raw read counts (large integers)
        if numeric_data.max().max() > 10000:  # Threshold for raw counts
            return "read counts"
    # Check if all values are floats
    if numeric_data.apply(lambda x: x.dtype == "float").all():
        # If values are small, it's likely normalized data
        if numeric_data.max().max() < 1000:  # Threshold for normalized expression
            return "normalized expression data"
        else:
            print(
                f"the max value is {numeric_data.max().max()}; this could be raw read counts, but please double-check"
            )
            return "read counts"
    # If mixed data types or unexpected values
    return "mixed or unknown"

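# Worked example of the heuristic (toy values; the integer check assumes the
# platform's default int dtype, e.g. int64 on typical 64-bit systems):
# df_int = pd.DataFrame({"g1": [120000, 530], "g2": [88, 41000]})  # ints, max > 10000
# get_data_type(df_int)  # -> "read counts"
# df_flt = pd.DataFrame({"g1": [5.2, 7.8], "g2": [0.3, 9.1]})      # floats, max < 1000
# get_data_type(df_flt)  # -> "normalized expression data"
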
def split_at_lower_upper(lst):
    """
    Split a list into two lists at the first point where an all-lowercase
    string is followed by an uppercase string or NaN.
    """
    for i in range(len(lst) - 1):
        if isinstance(lst[i], str) and lst[i].islower():
            next_item = lst[i + 1]
            if isinstance(next_item, str) and next_item.isupper():
                # Found the split point: lowercase followed by uppercase
                return lst[: i + 1], lst[i + 1 :]
            elif pd.isna(next_item):
                # NaN case after a lowercase string
                return lst[: i + 1], lst[i + 1 :]
    return lst, []

def find_condition(data: pd.DataFrame, columns=["characteristics_ch1", "title"]):
    if data.shape[1] >= data.shape[0]:
        display(data.iloc[:1, :40].T)
    # inspect the categories present in each column; parts containing numbers are stripped off
    for col in columns:
        print(f"{'='*10} {col} {'='*10}")
        display(ips.flatten([ips.ssplit(i, by="numer")[0] for i in data[col]], verbose=False))

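# Behaviour sketch (toy input): metadata columns are all-lowercase while gene
# symbols are uppercase, so the list splits at that transition:
# split_at_lower_upper(["age", "sex", "GENE1", "GENE2"])
# -> (["age", "sex"], ["GENE1", "GENE2"])
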
def add_condition(
    data: pd.DataFrame,
    column: str = "characteristics_ch1",  # which column to classify on
    column_new: str = "condition",  # name of the new column
    by: str = "tissue: tumor liver",  # substring that assigns `by_name`
    by_not: str = ": tumor",  # substring marking the other condition
    by_name: str = "non-tumor",  # label for samples matching `by`
    by_not_name: str = "tumor",  # label for the remaining samples
    inplace: bool = True,  # replace the data
    verbose: bool = True,
):
    """
    Purpose: Automated sample grouping based on metadata patterns.
    Principle: String matching and pattern extraction from metadata columns.
    Usage: Rapid experimental design setup for differential analysis.

    Add a new column to the DataFrame based on the presence of a specific substring in another column.

    Parameters
    ----------
    data : pd.DataFrame
        The input DataFrame containing the data.
    column : str, optional
        The name of the column in which to search for the substring (default is 'characteristics_ch1').
    column_new : str, optional
        The name of the new column to be created (default is 'condition').
    by : str, optional
        The substring to search for in the specified column (default is 'tissue: tumor liver').

    """
    # first check the content in the column
    content = data[column].unique().tolist()
    if verbose:
        if len(content) > 10:
            display(content[:10])
        else:
            display(content)
    # `by` takes precedence over `by_not`
    if by:
        data[column_new] = data[column].apply(
            lambda x: by_name if by in x else by_not_name
        )
    elif by_not:
        data[column_new] = data[column].apply(
            lambda x: by_not_name if by_not not in x else by_name
        )
    if verbose:
        display(data.sample(5))
    if not inplace:
        return data

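# Usage sketch (hypothetical substrings; samples whose metadata contain `by`
# get `by_name`, all others get `by_not_name`):
# add_condition(df_meta, column="characteristics_ch1", column_new="condition",
#               by="tissue: healthy liver", by_name="healthy", by_not_name="tumor")
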
def add_condition_multi(
    data: pd.DataFrame,
    column: str = "characteristics_ch1",  # Column to classify
    column_new: str = "condition",  # New column name
    conditions: dict = {
        "low": "low",
        "high": "high",
        "intermediate": "intermediate",
    },  # A dictionary where keys are substrings and values are condition names
    default_name: str = "unknown",  # Default name if no condition matches
    inplace: bool = True,  # Whether to replace the data
    verbose: bool = True,
):
    """
    Add a new column to the DataFrame based on the presence of specific substrings in another column.

    Parameters
    ----------
    data : pd.DataFrame
        The input DataFrame containing the data.
    column : str, optional
        The name of the column in which to search for the substrings (default is 'characteristics_ch1').
    column_new : str, optional
        The name of the new column to be created (default is 'condition').
    conditions : dict, optional
        A dictionary where keys are substrings to search for and values are the corresponding labels.
    default_name : str, optional
        The name to assign if no condition matches (default is 'unknown').
    inplace : bool, optional
        Whether to modify the original DataFrame (default is True).
    verbose : bool, optional
        Whether to display the unique values and final DataFrame (default is True).
    """

    # Display the unique values in the column
    content = data[column].unique().tolist()
    if verbose:
        if len(content) > 10:
            display(content[:10])
        else:
            display(content)

    # Check if conditions are provided
    if conditions is None:
        raise ValueError(
            "Conditions must be provided as a dictionary with substrings and corresponding labels."
        )

    # Define a helper function to map the conditions
    def map_condition(value):
        for substring, label in conditions.items():
            if substring in value:
                return label
        return default_name  # If no condition matches, return the default name

    # Apply the mapping function to create the new column
    data[column_new] = data[column].apply(map_condition)

    # Display the updated DataFrame if verbose is True
    if verbose:
        display(data.sample(5))

    if not inplace:
        return data

def clean_dataset(
    data: pd.DataFrame, dataset: str = None, condition: str = "condition", sep="///"
):
    """
    Purpose: Standardizes and cleans integrated datasets for analysis.
    Principle: Handles multi-mapping genes and data formatting issues.
    Key Operations:
        Extends genes with multiple symbols (e.g., "///" separated)
        Removes duplicates and missing values
        Formats sample names with dataset and condition information
        Sets genes as index for downstream analysis

    #* it is also invoked by bio.batch_effect(), but default: False
    1. clean the dataset and prepare super_datasets
    2. if "///" is in the index, extend it, etc.
    3. drop duplicates and dropna()
    4. add the 'condition' and 'dataset' info to the columns
    5. set genes as index
    """
    usage_str = """clean_dataset(data: pd.DataFrame, dataset: str = None, condition: str = "condition", sep="///")
    """
    if dataset is None:
        try:
            dataset = data["dataset"][0]
        except Exception:
            print("cannot find 'dataset' name")
            print(f"example\n {usage_str}")
    #! (4.1) clean the dataset and prepare super_datasets
    # in the merged frame, the left columns are metadata and the right columns are gene symbols
    col_gene = split_at_lower_upper(data.columns.tolist())[1][0]
    idx = ips.strcmp(col_gene, data.columns.tolist())[1]
    df_gene = data.iloc[:, idx:].T  # keep the last 'condition'

    #! if "///" is in the index, extend it
    print(f"before extend shape: {df_gene.shape}")
    df = df_gene.reset_index()
    df_gene = ips.df_extend(df, column="index", sep=sep)
    # reset 'index' column as index
    # df_gene = df_gene.set_index("index")
    print(f"after extended by '{sep}' shape: {df_gene.shape}")

    # *alternative:
    # df_unique = df.reset_index().drop_duplicates(subset="index").set_index("index")
    #! 4.2 drop duplicates and dropna()
    df_gene = df_gene.drop_duplicates(subset=["index"]).dropna()
    print(f"drop duplicates and dropna: shape: {df_gene.shape}")

    #! add the 'condition' and 'dataset' info to the columns
    ds = [data["dataset"][0]] * len(df_gene.columns[1:])
    samp = df_gene.columns.tolist()[1:]
    cond = df_gene[df_gene["index"] == condition].values.tolist()[0][1:]
    df_gene.columns = ["index"] + [
        f"{ds}_{sam}_{cond}" for (ds, sam, cond) in zip(ds, samp, cond)
    ]
    df_gene.drop(df_gene[df_gene["index"] == condition].index, inplace=True)
    #! set genes as index
    df_gene.set_index("index", inplace=True)
    display(df_gene.head())
    return df_gene

def batch_effect(
    data: list = "[df_gene_1, df_gene_2, df_gene_3]",  # index (genes), columns (samples)
    datasets: list = ["GSE25097", "GSE62232", "GSE65372"],
    clean_data: bool = False,  # by default, no data cleaning
    top_genes: int = 10,  # only for plotting
    plot_=True,
    dir_save="./res/",
    kws_clean_dataset: dict = {},
    **kwargs,
):
    """
    Purpose: Corrects batch effects across multiple datasets using the ComBat algorithm.
    Principle: Empirical Bayes framework to adjust for technical variation.
    Key Operations:
        Identifies common genes across datasets
        Applies pyComBat normalization
        Provides before/after visualization
    Dependencies: combat.pycombat
    usage 1:
        bio.batch_effect(
            data=[df_gene_1, df_gene_2, df_gene_3],
            datasets=["GSE25097", "GSE62232", "GSE65372"],
            clean_data=False,
            dir_save="./res/")

    #! # or combine clean_dataset and batch_effect together
    # # data = [bio.clean_dataset(data=dt, dataset=ds) for (dt, ds) in zip(data, datasets)]
    data_common = bio.batch_effect(
        data=[df_data_1, df_data_2, df_data_3],
        datasets=["GSE25097", "GSE62232", "GSE65372"], clean_data=True
    )
    """
    # data = [df_gene_1, df_gene_2, df_gene_3]
    # datasets = ["GSE25097", "GSE62232", "GSE65372"]
    # top_genes = 10  # show top 10 genes
    # plot_ = True
    from combat.pycombat import pycombat

    if clean_data:
        data = [
            clean_dataset(data=dt, dataset=ds, **kws_clean_dataset)
            for (dt, ds) in zip(data, datasets)
        ]
    #! prepare data
    # the datasets are dataframes where:
    #     the indexes correspond to the gene names
    #     the column names correspond to the sample names
    #! merge batches
    # https://epigenelabs.github.io/pyComBat/
    # we merge all the datasets into one, keeping only the common genes
    df_expression_common_genes = pd.concat(data, join="inner", axis=1)
    #! convert to float
    ips.df_astype(df_expression_common_genes, astype="float", inplace=True)

    #! to visualise results, use mini datasets: only take the first few samples of each batch (dataset)
    if plot_:
        col2plot = []
        for ds in datasets:
            # select the first `top_genes` samples per dataset, to see the difference
            dat_tmp = df_expression_common_genes.columns[
                df_expression_common_genes.columns.str.startswith(ds)
            ][:top_genes].tolist()
            col2plot.extend(dat_tmp)
        # visualise results
        _, axs = plt.subplots(2, 1, figsize=(15, 10))
        plot.plotxy(
            ax=axs[0],
            data=df_expression_common_genes.loc[:, col2plot],
            kind_="bar",
            figsets=dict(
                title="Samples expression distribution (non-correction)",
                ylabel="Observations",
                xangle=90,
            ),
        )
    # prepare the batch list
    batch = [
        ips.ssplit(i, by="_")[0] for i in df_expression_common_genes.columns.tolist()
    ]
    # run pyComBat
    df_corrected = pycombat(df_expression_common_genes, batch, **kwargs)
    print(f"df_corrected.shape: {df_corrected.shape}")
    display(df_corrected.head())
    # visualise results again
    if plot_:

        plot.plotxy(
            ax=axs[1],
            data=df_corrected.loc[:, col2plot],
            kind_="bar",
            figsets=dict(
                title="Samples expression distribution (corrected)",
                ylabel="Observations",
                xangle=90,
            ),
        )
        if dir_save is not None:
            ips.figsave(dir_save + "batch_sample_exp_distri.pdf")
    return df_corrected

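# Note on the batch list above: batches are inferred from the sample-name
# prefix before the first "_", so clean_dataset()-style column names such as
# "GSE25097_GSM616363_tumor" are assumed. A minimal sketch (hypothetical frames):
# dfs = [df_gene_1, df_gene_2]  # genes x samples, shared gene index
# corrected = batch_effect(data=dfs, datasets=["GSE25097", "GSE62232"],
#                          clean_data=True, dir_save=None)
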
def get_common_genes(elment1, elment2):
    """
    Purpose: Identifies shared genes between datasets or gene lists.
    Principle: Set intersection operation with informative output.
    Usage: Essential for cross-dataset integration and comparison.
    """
    common_genes = ips.shared(elment1, elment2, verbose=False)
    return common_genes

def counts2expression(
    counts: pd.DataFrame,  # index (samples); columns (genes)
    method: str = "TMM",  # 'CPM', 'FPKM', 'TPM', 'UQ', 'TMM', 'CUF', 'CTF'
    length: list = None,
    uq_factors: pd.Series = None,
    verbose: bool = False,
) -> pd.DataFrame:
    """
    Purpose: Converts raw RNA-seq counts to normalized expression values.
    Principle: Implements multiple normalization methods for cross-dataset compatibility.
    Supported Methods:
        TMM: Trimmed Mean of M-values - robust against compositional biases
        TPM: Transcripts Per Million - length-normalized for cross-comparison
        CPM: Counts Per Million - simple library-size normalization
        FPKM: Fragments Per Kilobase Million - length- and library-size-normalized
        UQ: Upper Quartile - uses the 75th percentile for scaling
    Recommendations: TMM for cross-dataset work, TPM for single datasets.

    https://www.linkedin.com/pulse/snippet-corner-raw-read-count-normalization-python-mazzalab-gzzyf?trk=public_post
    Convert raw RNA-seq read counts to expression values.
    counts: pd.DataFrame
        index: samples
        columns: genes
    usage:
        df_normalized = counts2expression(df_counts, method='TMM', verbose=True)
    recommendations:
        cross-datasets:
            TMM (Trimmed Mean of M-values): very suitable for merging datasets, especially
                for cross-sample and cross-dataset comparisons; commonly used in
                differential expression analysis.
            CTF (Counts adjusted with TMM factors): suitable for merging datasets, as it is
                TMM-based normalization; typically used as input for downstream analyses
                such as differential expression.
            TPM (Transcripts Per Million): good for merging datasets. TPM is often more
                suitable for cross-dataset comparisons because it adjusts for gene length
                and ensures that expression levels sum to the same total in each sample.
            UQ (Upper Quartile): less commonly used than TPM or TMM.
            CUF (Counts adjusted with UQ factors): can be used, but UQ normalization is
                generally not as standardized as TPM or TMM for merging datasets.
        within-datasets:
            CPM (Counts Per Million): does not adjust for gene length or other
                variables that could vary across datasets.
            FPKM (Fragments Per Kilobase Million): known to be inconsistent
                across different experiments.
    Parameters:
    - counts: pd.DataFrame
        Raw read counts with samples as rows and genes as columns.
    - method: str, default='TMM'
        CPM (Counts per Million): Scales counts by total library size.
        FPKM (Fragments per Kilobase Million): Requires gene length; scales by both library size and gene length.
        TPM (Transcripts per Million): Scales by gene length and total transcript abundance.
        UQ (Upper Quartile): Normalizes based on the upper quartile of the counts.
        TMM (Trimmed Mean of M-values): Adjusts for compositional biases.
        CUF (Counts adjusted with Upper Quartile factors): Counts adjusted based on UQ factors.
        CTF (Counts adjusted with TMM factors): Counts adjusted based on TMM factors.
    - length: list or pd.Series, optional
        Gene lengths (e.g., in kilobases) for FPKM/TPM normalization. Required for FPKM/TPM.
    - uq_factors: pd.Series, optional
        Precomputed Upper Quartile factors, required for UQ and CUF normalization.
    - verbose: bool, default=False
        If True, provides detailed logging information.

    Returns:
    - normalized_counts: pd.DataFrame
        Normalized expression values.
    """
    import rnanorm

    print(f"INFO: 'counts' data should be: index (samples); columns (genes)")
    if "length" in method:  # handle aliases; the method names are easy to mix up
        method = "FPKM"
    methods = ["CPM", "FPKM", "TPM", "UQ", "TMM", "CUF", "CTF"]
    method = ips.strcmp(method, methods)[0]
    if verbose:
        print(
            f"Starting normalization using method: {method}; supported methods: {methods}"
        )
    columns_org = counts.columns.tolist()
    # Check if gene lengths are provided when necessary
    if method in ["FPKM", "TPM"]:
        if length is None:
            raise ValueError(f"Gene lengths must be provided for {method} normalization.")
        if isinstance(length, list):
            df_genelength = pd.DataFrame({"gene_length": length})
            df_genelength.index = counts.columns  # set gene_id as index
            df_genelength.index = df_genelength.index.astype(str).str.strip()
            # length = np.array(df_genelength["gene_length"]).reshape(1, -1)
            length = df_genelength["gene_length"]
            counts.index = counts.index.astype(str).str.strip()
        elif isinstance(length, pd.Series):

            length.index = length.index.astype(str).str.strip()
            counts.columns = counts.columns.astype(str).str.strip()
            shared_genes = ips.shared(length.index, counts.columns, verbose=False)
            length = length.loc[shared_genes]
            counts = counts.loc[:, shared_genes]
            columns_org = counts.columns.tolist()

    # # Ensure gene lengths are aligned with counts if provided
    # if length is not None:
    #     length = length[counts.index]

    # Start the normalization based on the chosen method
    if method == "CPM":
        normalized_counts = (
            rnanorm.CPM().set_output(transform="pandas").fit_transform(counts)
        )

    elif method == "FPKM":
        if verbose:
            print("Performing FPKM normalization using gene lengths.")
        normalized_counts = (
            rnanorm.CPM().set_output(transform="pandas").fit_transform(counts)
        )
        # convert CPM to FPKM by dividing by the gene length in kb:
        # FPKM = CPM / (gene_length / 1000) = CPM * 1000 / gene_length
        normalized_counts = normalized_counts.div(length.values, axis=1) * 1e3

    elif method == "TPM":
        if verbose:
            print("Performing TPM normalization using gene lengths.")
        normalized_counts = (
            rnanorm.TPM(gene_lengths=length)
            .set_output(transform="pandas")
            .fit_transform(counts)
        )

    elif method == "UQ":
        if verbose:
            print("Performing Upper Quartile (UQ) normalization.")
        if uq_factors is None:
            uq_factors = rnanorm.upper_quartile_factors(counts)
        normalized_counts = (
            rnanorm.UQ(factors=uq_factors)
            .set_output(transform="pandas")
            .fit_transform(counts)
        )

    elif method == "TMM":
        if verbose:
            print("Performing TMM normalization (Trimmed Mean of M-values).")
        normalized_counts = (
            rnanorm.TMM().set_output(transform="pandas").fit_transform(counts)
        )

    elif method == "CUF":
        if verbose:
            print("Performing Counts adjusted with UQ factors (CUF).")
        if uq_factors is None:
            uq_factors = rnanorm.upper_quartile_factors(counts)
        normalized_counts = (
            rnanorm.CUF(factors=uq_factors)
            .set_output(transform="pandas")
            .fit_transform(counts)
        )

    elif method == "CTF":
        if verbose:
            print("Performing Counts adjusted with TMM factors (CTF).")
        normalized_counts = (
            rnanorm.CTF().set_output(transform="pandas").fit_transform(counts)
        )

    else:
        raise ValueError(f"Unknown normalization method: {method}")
    normalized_counts.columns = columns_org
    if verbose:
        print(f"Normalization complete using method: {method}")

    return normalized_counts

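# Hand-computed CPM for intuition (toy numbers; the rnanorm call above does
# the same per-sample scaling):
#   CPM[sample, gene] = counts[sample, gene] / sum_over_genes(counts[sample, :]) * 1e6
# e.g. a gene with 500 reads in a 10,000,000-read library -> 500/1e7*1e6 = 50 CPM.
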
def counts_deseq(
    counts_sam_gene: pd.DataFrame,
    meta_sam_cond: pd.DataFrame,
    design_factors: list = None,
    kws_DeseqDataSet: dict = {},
    kws_DeseqStats: dict = {},
):
    """
    Purpose: Performs differential expression analysis using the DESeq2 methodology.
    Principle: Negative binomial distribution modeling with shrinkage estimation.
    Key Operations:
        Creates a DeseqDataSet object with a design formula
        Estimates size factors and dispersions
        Fits negative binomial models
        Performs Wald tests for significance
        Applies multiple-testing correction (Benjamini-Hochberg)
    Output Components:
        dds: Complete DESeq2 dataset object
        diff: Results dataframe with log2FC, p-values, FDR
        stat_res: Statistical results object
        df_norm: Normalized count data

    https://pydeseq2.readthedocs.io/en/latest/api/docstrings/pydeseq2.ds.DeseqStats.html
    Note: Using normalized expression data in a DeseqDataSet object is generally not recommended,
        because the DESeq2 framework is designed to work with raw count data.
    baseMean:
        - This value represents the average normalized count (or expression level) of a
          gene across all samples in the dataset.
        - For example, a baseMean of 0.287 for 4933401J01Rik indicates that this gene has
          low expression levels in the samples compared to others with higher baseMean
          values like Xkr4 (591.015).
    log2FoldChange: the magnitude and direction of change in expression between conditions.
    lfcSE (Log Fold Change Standard Error): the standard error of the log2FoldChange. It
        indicates the uncertainty in the estimate of the fold change; a lower value indicates
        more confidence in the estimate.
    padj: This value accounts for multiple-testing corrections (e.g., Benjamini-Hochberg).
    Log10 transforming: The columns -log10(pvalue) and -log10(FDR) are transformations of
        the p-values and adjusted p-values, respectively.
    """
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats
    from pydeseq2.default_inference import DefaultInference

    # data filtering
    # counts_sam_gene = counts_sam_gene.loc[:, ~(counts_sam_gene.sum(axis=0) < 10)]
    if design_factors is None:
        design_factors = meta_sam_cond.columns.tolist()

    kws_DeseqDataSet.pop("design_factors", {})
    refit_cooks = kws_DeseqDataSet.pop("refit_cooks", True)

    #! DeseqDataSet
    inference = DefaultInference(n_cpus=8)
    dds = DeseqDataSet(
        counts=counts_sam_gene,
        metadata=meta_sam_cond,
        design_factors=design_factors,
        refit_cooks=refit_cooks,
        inference=inference,
        **kws_DeseqDataSet,
    )
    dds.deseq2()
    # * results
    dds_explain = """
    res[0]:
        # X stores the count data,
        # obs stores design factors,
        # obsm stores sample-level data, such as "design_matrix" and "size_factors",
        # varm stores gene-level data, such as "dispersions" and "LFC"."""
    print(dds_explain)
    #! DeseqStats
    stat_res = DeseqStats(dds, **kws_DeseqStats)
    stat_res.summary()
    diff = stat_res.results_df.assign(padj=lambda x: x.padj.fillna(1))

    # handle the '0' issue, which would otherwise produce inf in later calculations (e.g., log10)
    diff["padj"] = diff["padj"].replace(0, 1e-10)
    diff["pvalue"] = diff["pvalue"].replace(0, 1e-10)

    diff["-log10(pvalue)"] = diff["pvalue"].apply(lambda x: -np.log10(x))
    diff["-log10(FDR)"] = diff["padj"].apply(lambda x: -np.log10(x))
    diff = diff.reset_index().rename(columns={"index": "gene"})
    # sig_diff = (
    #     diff.query("log2FoldChange.abs()>0.585 & padj<0.05")
    #     .reset_index()
    #     .rename(columns={"index": "gene"})
    # )
    df_norm = pd.DataFrame(dds.layers["normed_counts"])
    df_norm.index = counts_sam_gene.index
    df_norm.columns = counts_sam_gene.columns
    print("res[0]: dds\nres[1]: diff\nres[2]: stat_res\nres[3]: df_normalized")
    return dds, diff, stat_res, df_norm

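# Usage sketch (hypothetical frames: `counts` is samples x genes with raw
# integer counts; the metadata frame holds one "condition" column per sample):
# meta = pd.DataFrame({"condition": ["tumor", "tumor", "normal", "normal"]},
#                     index=counts.index)
# dds, diff, stat_res, df_norm = counts_deseq(counts, meta)
# diff.query("padj < 0.05 and abs(log2FoldChange) > 0.585")  # significant genes
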
def scope_genes(gene_list: list, scopes: str = None, fields: str = "symbol", species="human"):
    """
    Purpose: Converts gene identifiers using the MyGene.info service.
    Principle: Batch query to the MyGene.info API for ID conversion and annotation.
    Supported: 30+ identifier types and multiple species.

    usage:
        scope_genes(df_counts.columns.tolist()[:1000], species="mouse")
    """
    import mygene

    if scopes is None:
        # copied from: https://docs.mygene.info/en/latest/doc/query_service.html#scopes
        # NOTE: hardcoded local path from the author's machine; the same file ships
        # in the package as py2ls/data/mygenes_fields_241022.txt
        scopes = ips.fload(
            "/Users/macjianfeng/Dropbox/github/python/py2ls/py2ls/data/mygenes_fields_241022.txt",
            kind="csv",
            verbose=False,
        )
        scopes = ",".join([i.strip() for i in scopes.iloc[:, 0]])
    mg = mygene.MyGeneInfo()
    results = mg.querymany(
        gene_list,
        scopes=scopes,
        fields=fields,
        species=species,
    )
    return pd.DataFrame(results)

def get_enrichr(
    gene_symbol_list,
    gene_sets: str,
    download: bool = False,
    species="Human",
    dir_save="./",
    plot_=False,
    n_top=30,
    palette=None,
    check_shared=True,
    figsize=(5, 8),
    show_ring=False,
    xticklabels_rot=0,
    title=None,  # 'KEGG'
    cutoff=0.05,
    cmap="coolwarm",
    size=5,
    **kwargs,
):
    """
    Purpose: Performs over-representation analysis using the Enrichr database.
    Principle: Hypergeometric test for gene set enrichment.
    Key Operations:
        Interfaces with the gseapy Enrichr API
        Supports 180+ predefined gene sets
        Provides multiple visualization options (barplot, dotplot)
        Handles species-specific gene symbols
    Visualization: Ranked bar plots and dot plots showing significance and effect size.

    Note: Enrichr uses a list of Entrez gene symbols as input.

    """
    kws_figsets = {}
    for k_arg, v_arg in kwargs.items():
        if "figset" in k_arg:
            kws_figsets = v_arg
            kwargs.pop(k_arg, None)
            break
    species_org = species
    # organism (str) – one of {'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
    organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
    species = ips.strcmp(species, organisms)[0]
    if species_org.lower() != species.lower():
        print(f"species was corrected to {species}, because only {organisms} are supported")
    if os.path.isfile(gene_sets):
        gene_sets_name = os.path.basename(gene_sets)
        gene_sets = ips.fload(gene_sets)
    else:
        lib_support_names = gp.get_library_name()
        # correct the input gene_set name
        gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]

        # download it
        if download:
            gene_sets = gp.get_library(name=gene_sets_name, organism=species)
        else:
            gene_sets = gene_sets_name  # avoid re-downloading
    print(f"\ngene_sets ready: {gene_sets_name}")

    # gene symbols are uppercase
    gene_symbol_list = [str(i).upper() for i in gene_symbol_list]

    # # check how many genes are shared
    if check_shared and isinstance(gene_sets, dict):
        shared_genes = ips.shared(
            ips.flatten(gene_symbol_list, verbose=False),
            ips.flatten(gene_sets, verbose=False),
            verbose=False,
        )

    #! enrichr
    try:
        enr = gp.enrichr(
            gene_list=gene_symbol_list,
            gene_sets=gene_sets,
            organism=species,
            outdir=None,  # don't write to disk
            **kwargs,
        )
    except ValueError as e:
        print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}{e}\n{'!'*10} Error {'!'*10}")
        return None

    results_df = enr.results
    print(f"got enrichr results; shape: {results_df.shape}\n")
    results_df["-log10(Adjusted P-value)"] = -np.log10(results_df["Adjusted P-value"])
    results_df.sort_values("-log10(Adjusted P-value)", inplace=True, ascending=False)

    if plot_:
        if palette is None:
            palette = plot.get_color(n_top, cmap=cmap)[::-1]
        #! barplot
        if n_top < 5:
            height_ = 4
        elif 5 <= n_top < 10:
            height_ = 5
        elif 10 <= n_top < 15:
            height_ = 7
        elif 15 <= n_top < 20:
            height_ = 8
        elif 20 <= n_top < 30:
            height_ = 9
        else:
            height_ = int(n_top / 3)
        plt.figure(figsize=[10, height_])

        ax1 = plot.plotxy(
            data=results_df.head(n_top),
            kind_="barplot",
            x="-log10(Adjusted P-value)",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
        )
        plot.figsets(ax=ax1, **kws_figsets)
        if dir_save:
            ips.figsave(f"{dir_save}enr_barplot.pdf")
        plt.show()

        #! dotplot
        cutoff_curr = cutoff
        step = 0.05
        cutoff_stop = 0.5
        while cutoff_curr <= cutoff_stop:
            try:
                if cutoff_curr != cutoff:
                    plt.clf()
                ax2 = gp.dotplot(
                    enr.res2d,
                    column="Adjusted P-value",
                    show_ring=show_ring,
                    xticklabels_rot=xticklabels_rot,
                    title=title,
                    cmap=cmap,
                    cutoff=cutoff_curr,
                    top_term=n_top,
                    size=size,
                    figsize=[10, height_],
                )
                if len(ax2.collections) >= n_top:
                    print(f"cutoff={cutoff_curr} done! ")
                    break
                if cutoff_curr == cutoff_stop:
                    break
                cutoff_curr += step
            except Exception as e:
                cutoff_curr += step
                print(
                    f"Warning: trying cutoff={cutoff_curr}; cutoff={cutoff_curr-step} failed: {e} "
                )
        ax = plt.gca()
        plot.figsets(ax=ax, **kws_figsets)

        if dir_save:
            ips.figsave(f"{dir_save}enr_dotplot.pdf")

    return results_df

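# Usage sketch (gene set name as in the pipeline overview at the top of this
# file; requires internet access to the Enrichr service):
# res = get_enrichr(sig_genes, "KEGG_2021_Human", species="Human",
#                   plot_=True, n_top=20, dir_save="./res/")
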
def plot_enrichr(results_df,
                 kind="bar",  # 'barplot', 'dotplot'
                 cutoff=0.05,
                 show_ring=False,
                 xticklabels_rot=0,
                 title=None,  # 'KEGG'
                 cmap="coolwarm",
                 n_top=10,
                 size=5,
                 ax=None,
                 **kwargs):
    """
    Purpose: Flexible visualization of enrichment results
    Plot Types:
        Bar plots: -log10(p-value) for top terms
        Dot plots: Combined visualization of p-value and gene ratio
        Count plots: Number of overlapping genes
    Customization: Color schemes, term number, significance thresholds
    """
    kws_figsets = {}
    for k_arg, v_arg in kwargs.items():
        if "figset" in k_arg:
            kws_figsets = v_arg
            kwargs.pop(k_arg, None)
            break
    if isinstance(cmap, str):
        palette = plot.get_color(n_top, cmap=cmap)[::-1]
    elif isinstance(cmap, list):
        palette = cmap
    if n_top < 5:
        height_ = 3
    elif 5 <= n_top < 10:
        height_ = 3
    elif 10 <= n_top < 15:
        height_ = 3
    elif 15 <= n_top < 20:
        height_ = 4
    elif 20 <= n_top < 30:
        height_ = 5
    elif 30 <= n_top < 40:
        height_ = int(n_top / 6)
    else:
        height_ = int(n_top / 8)

    #! barplot
    if 'bar' in kind.lower():
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        ax = plot.plotxy(
            data=results_df.head(n_top),
            kind_="barplot",
            x="-log10(Adjusted P-value)",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, results_df

    #! dotplot
    elif 'dot' in kind.lower():
        cutoff_curr = cutoff
        step = 0.05
        cutoff_stop = 0.5
        while cutoff_curr <= cutoff_stop:
            try:
                if cutoff_curr != cutoff:
                    plt.clf()
                ax = gp.dotplot(results_df,
                                column="Adjusted P-value",
                                show_ring=show_ring,
                                xticklabels_rot=xticklabels_rot,
                                title=title,
                                cmap=cmap,
                                cutoff=cutoff_curr,
                                top_term=n_top,
                                size=size,
                                figsize=[10, height_])
                if len(ax.collections) >= n_top:
                    print(f"cutoff={cutoff_curr} done! ")
                    break
                if cutoff_curr == cutoff_stop:
                    break
                cutoff_curr += step
            except Exception as e:
                cutoff_curr += step
                print(f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} ")
        plot.figsets(ax=ax, **kws_figsets)
        return ax, results_df

    #! barplot with counts
    elif 'count' in kind.lower():
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        # extract the gene count from the "Overlap" column (e.g. "12/245" -> 12)
        results_df["Count"] = results_df["Overlap"].apply(
            lambda x: int(x.split("/")[0]) if isinstance(x, str) else x)
        df_ = results_df.sort_values(by="Count", ascending=False)

        ax = plot.plotxy(
            data=df_.head(n_top),
            kind_="barplot",
            x="Count",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax
        )

        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

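# * usage (a minimal sketch; assumes `results_df` comes from get_enrichr() above):
# ax, df_bar = plot_enrichr(results_df, kind="bar", n_top=10, cmap="coolwarm")
# ax, df_dot = plot_enrichr(results_df, kind="dot", cutoff=0.05, n_top=15)
# ax, df_cnt = plot_enrichr(results_df, kind="count", n_top=10)
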
def plot_bp_cc_mf(
    deg_gene_list,
    gene_sets=[
        "GO_Biological_Process_2023",
        "GO_Cellular_Component_2023",
        "GO_Molecular_Function_2023",
    ],
    species="human",
    download=False,
    n_top=10,
    plot_=True,
    ax=None,
    palette=plot.get_color(3, "colorblind6"),
    **kwargs,
):
    """
    Purpose: Integrated visualization of Gene Ontology (BP, CC, MF) enrichment
    Principle: Combines results from the three GO domains into a unified plot
    Usage: Comprehensive functional profiling of gene lists
    """
    def res_enrichr_2_count(res_enrichr, n_top=10):
        """Extract the gene count from the enrichment results and sort by it."""
        res_enrichr["Count"] = res_enrichr["Overlap"].apply(
            lambda x: int(x.split("/")[0]) if isinstance(x, str) else x
        )
        res_enrichr.sort_values(by="Count", ascending=False, inplace=True)

        return res_enrichr.head(n_top)  # [["Term", "Count"]]

    res_enrichr_BP = get_enrichr(
        deg_gene_list, gene_sets[0], species=species, plot_=False, download=download
    )
    res_enrichr_CC = get_enrichr(
        deg_gene_list, gene_sets[1], species=species, plot_=False, download=download
    )
    res_enrichr_MF = get_enrichr(
        deg_gene_list, gene_sets[2], species=species, plot_=False, download=download
    )

    df_BP = res_enrichr_2_count(res_enrichr_BP, n_top=n_top)
    df_BP["Ontology"] = ["Biological Process"] * n_top

    df_CC = res_enrichr_2_count(res_enrichr_CC, n_top=n_top)
    df_CC["Ontology"] = ["Cellular Component"] * n_top

    df_MF = res_enrichr_2_count(res_enrichr_MF, n_top=n_top)
    df_MF["Ontology"] = ["Molecular Function"] * n_top

    # concatenate the three ontologies
    df2plot = pd.concat([df_BP, df_CC, df_MF])
    n_top = n_top * 3
    if n_top < 5:
        height_ = 4
    elif 5 <= n_top < 10:
        height_ = 5
    elif 10 <= n_top < 15:
        height_ = 6
    elif 15 <= n_top < 20:
        height_ = 7
    elif 20 <= n_top < 30:
        height_ = 8
    elif 30 <= n_top < 40:
        height_ = int(n_top / 4)
    else:
        height_ = int(n_top / 5)
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=[10, height_])
    # plot
    display(df2plot)
    if df2plot["Term"].tolist()[0].endswith(")"):
        df2plot["Term"] = df2plot["Term"].apply(lambda x: x.split("(")[0][:-1])
    if plot_:
        ax = plot.plotxy(
            data=df2plot,
            x="Count",
            y="Term",
            hue="Ontology",
            kind_="bar",
            palette=palette,
            ax=ax,
            **kwargs
        )
    return ax, df2plot

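# * usage (a minimal sketch; `deg_gene_list` is a hypothetical DEG symbol list):
# ax, df_go = plot_bp_cc_mf(deg_gene_list,
#                           species="human",
#                           n_top=5,  # 5 terms per GO domain -> 15 bars in total
#                           figsets=dict(title="GO profiling"))
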
def get_library_name(by=None, verbose=False):
    """
    Purpose: Retrieves available gene set libraries from Enrichr
    Principle: Queries gseapy for current library availability
    Usage: Discovery of available pathway databases
    """
    lib_names = gp.get_library_name()
    if by is None:
        if verbose:
            [print(i) for i in lib_names]
        return lib_names
    else:
        return ips.flatten(ips.strcmp(by, lib_names, get_rank=True, verbose=verbose), verbose=verbose)

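# * usage (a sketch of the fuzzy lookup; `by` is matched against all Enrichr library names):
# all_libs = get_library_name()                           # full list
# kegg_like = get_library_name(by="kegg", verbose=True)   # ranked fuzzy matches
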
def get_gsva(
    data_gene_samples: pd.DataFrame,  # index (gene), columns (samples)
    gene_sets: str,
    species: str = "Human",
    dir_save: str = "./",
    plot_: bool = False,
    n_top: int = 30,
    check_shared: bool = True,
    cmap="coolwarm",
    min_size=1,
    max_size=1000,
    kcdf="Gaussian",  # 'Gaussian' for continuous data
    method='gsva',
    seed=1,
    **kwargs,
):
    """
    Purpose: Gene Set Variation Analysis - estimates pathway activity per sample
    Principle: Non-parametric unsupervised method for estimating pathway enrichment
    Key Operations:
        - Calculates enrichment scores for each sample
        - Handles continuous expression data (Gaussian kernel)
        - Supports custom and predefined gene sets
    Output: Sample-by-pathway activity matrix
    """
    kws_figsets = {}
    for k_arg, v_arg in kwargs.items():
        if "figset" in k_arg:
            kws_figsets = v_arg
            kwargs.pop(k_arg, None)
            break
    species_org = species
    # organism (str) – select one from {'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
    organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
    species = ips.strcmp(species, organisms)[0]
    if species_org.lower() != species.lower():
        print(f"species was corrected to {species}, because only {organisms} are supported")
    if os.path.isfile(gene_sets):
        gene_sets_name = os.path.basename(gene_sets)
        gene_sets = ips.fload(gene_sets)
    else:
        lib_support_names = gp.get_library_name()
        # correct the input gene_set name
        gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
        # download it
        gene_sets = gp.get_library(name=gene_sets_name, organism=species)
    print(f"gene_sets get ready: {gene_sets_name}")

    # gene symbols are uppercase
    gene_symbol_list = [str(i).upper() for i in data_gene_samples.index]
    data_gene_samples.index = gene_symbol_list
    # display(data_gene_samples.head(3))
    # check how many genes are shared
    if check_shared:
        ips.shared(
            ips.flatten(gene_symbol_list, verbose=False),
            ips.flatten(gene_sets, verbose=False),
            verbose=False
        )
    gsva_results = gp.gsva(
        data=data_gene_samples,  # matrix should have genes as rows and samples as columns
        gene_sets=gene_sets,
        outdir=None,
        kcdf=kcdf,  # 'Gaussian' for continuous data
        min_size=min_size,
        method=method,
        max_size=max_size,
        verbose=True,
        seed=seed,
        # no_plot=False,
    )
    gsva_res = gsva_results.res2d.copy()
    gsva_res["ES_abs"] = gsva_res["ES"].apply(np.abs)
    gsva_res = gsva_res.sort_values(by="ES_abs", ascending=False)
    gsva_res = (
        gsva_res.drop_duplicates(subset="Term").drop(columns="ES_abs")
        # .iloc[:80, :]
        .reset_index(drop=True)
    )
    gsva_res = gsva_res.sort_values(by="ES", ascending=False)
    if plot_:
        if gsva_res.shape[0] >= 2 * n_top:
            gsva_res_plot = pd.concat([gsva_res.head(n_top), gsva_res.tail(n_top)])
        else:
            gsva_res_plot = gsva_res
        if isinstance(cmap, str):
            palette = plot.get_color(n_top * 2, cmap=cmap)[::-1]
        elif isinstance(cmap, list):
            if len(cmap) == 2:
                palette = [cmap[0]] * n_top + [cmap[1]] * n_top
            else:
                palette = cmap
        # ! barplot
        if n_top < 5:
            height_ = 3
        elif 5 <= n_top < 10:
            height_ = 4
        elif 10 <= n_top < 15:
            height_ = 5
        elif 15 <= n_top < 20:
            height_ = 6
        elif 20 <= n_top < 30:
            height_ = 7
        elif 30 <= n_top < 40:
            height_ = int(n_top / 3.5)
        else:
            height_ = int(n_top / 3)
        plt.figure(figsize=[10, height_])
        ax2 = plot.plotxy(
            data=gsva_res_plot,
            x="ES",
            y="Term",
            hue="Term",
            palette=palette,
            kind_=["bar"],
            figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
        )
        # reposition the term labels relative to the bars
        for i, bar in enumerate(ax2.patches):
            term = gsva_res_plot.iloc[i]["Term"]
            es_value = gsva_res_plot.iloc[i]["ES"]

            # Positive ES values: align y-labels to the left of the axis
            if es_value > 0:
                ax2.annotate(
                    term,
                    xy=(0, bar.get_y() + bar.get_height() / 2),
                    xytext=(-5, 0),  # move to the left
                    textcoords="offset points",
                    ha="right",
                    va="center",
                    fontsize=10,
                    color="black",
                )
            # Negative ES values: align y-labels to the right of the axis
            else:
                ax2.annotate(
                    term,
                    xy=(0, bar.get_y() + bar.get_height() / 2),
                    xytext=(5, 0),  # move to the right
                    textcoords="offset points",
                    ha="left",
                    va="center",
                    fontsize=10,
                    color="black",
                )
        plot.figsets(ax=ax2, **kws_figsets)
        if dir_save:
            ips.figsave(dir_save + f"GSVA_{gene_sets_name}.pdf")
        plt.show()
    return gsva_res.reset_index(drop=True)

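# * usage (a minimal sketch; `df_expr` is a hypothetical genes-x-samples expression matrix):
# gsva_res = get_gsva(df_expr,
#                     gene_sets="GO_Biological_Process_2023",
#                     species="human",
#                     kcdf="Gaussian",  # continuous data; "Poisson" is the usual choice for raw counts
#                     plot_=True,
#                     n_top=20)
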
def plot_gsva(gsva_res,  # output from bio.get_gsva()
              n_top=10,
              ax=None,
              x="ES",
              y="Term",
              hue="Term",
              cmap="coolwarm",
              **kwargs
              ):
    kws_figsets = {}
    for k_arg, v_arg in kwargs.items():
        if "figset" in k_arg:
            kws_figsets = v_arg
            kwargs.pop(k_arg, None)
            break
    # ! barplot
    if n_top < 5:
        height_ = 4
    elif 5 <= n_top < 10:
        height_ = 5
    elif 10 <= n_top < 15:
        height_ = 6
    elif 15 <= n_top < 20:
        height_ = 7
    elif 20 <= n_top < 30:
        height_ = 8
    elif 30 <= n_top < 40:
        height_ = int(n_top / 3.5)
    else:
        height_ = int(n_top / 3)
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=[10, height_])
    gsva_res = gsva_res.sort_values(by=x, ascending=False)

    if gsva_res.shape[0] >= 2 * n_top:
        gsva_res_plot = pd.concat([gsva_res.head(n_top), gsva_res.tail(n_top)])
    else:
        gsva_res_plot = gsva_res
    if isinstance(cmap, str):
        palette = plot.get_color(n_top * 2, cmap=cmap)[::-1]
    elif isinstance(cmap, list):
        if len(cmap) == 2:
            palette = [cmap[0]] * n_top + [cmap[1]] * n_top
        else:
            palette = cmap

    ax = plot.plotxy(
        ax=ax,
        data=gsva_res_plot,
        x=x,
        y=y,
        hue=hue,
        palette=palette,
        kind_=["bar"],
        figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
    )
    # reposition the term labels relative to the bars
    for i, bar in enumerate(ax.patches):
        term = gsva_res_plot.iloc[i]["Term"]
        es_value = gsva_res_plot.iloc[i]["ES"]

        # Positive ES values: align y-labels to the left of the axis
        if es_value > 0:
            ax.annotate(
                term,
                xy=(0, bar.get_y() + bar.get_height() / 2),
                xytext=(-5, 0),  # move to the left
                textcoords="offset points",
                ha="right",
                va="center",
                fontsize=10,
                color="black",
            )
        # Negative ES values: align y-labels to the right of the axis
        else:
            ax.annotate(
                term,
                xy=(0, bar.get_y() + bar.get_height() / 2),
                xytext=(5, 0),  # move to the right
                textcoords="offset points",
                ha="left",
                va="center",
                fontsize=10,
                color="black",
            )
    plot.figsets(ax=ax, **kws_figsets)
    return ax

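# * usage (a minimal sketch; assumes `gsva_res` comes from get_gsva() above):
# ax = plot_gsva(gsva_res,
#                n_top=10,
#                cmap=["#74C0FC", "#FF8787"],  # one color per sign of ES
#                figsets=dict(title="GSVA"))
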
def get_prerank(
    rnk: pd.DataFrame,
    gene_sets: str,
    download: bool = False,
    species="Human",
    threads=8,  # number of CPU cores to use
    permutation_num=1000,  # number of permutations for significance
    min_size=15,  # minimum allowed number of genes shared by the gene set and the data set. Default: 15
    max_size=500,  # maximum allowed number of genes shared by the gene set and the data set. Default: 500
    weight=1.0,  # refer to algorithm.enrichment_score(). Default: 1
    ascending=False,  # sorting order of rankings; False for descending. If None, do not sort the ranking
    seed=1,  # seed for reproducibility
    verbose=True,  # verbosity
    dir_save="./",
    plot_=False,
    n_top=7,  # only for plotting
    size=5,
    figsize=(3, 4),
    cutoff=0.25,
    show_ring=False,
    cmap="coolwarm",
    check_shared=True,
    **kwargs,
):
    """
    Purpose: Pre-ranked Gene Set Enrichment Analysis (GSEA)
    Principle: Kolmogorov-Smirnov-like statistic applied to a pre-ranked gene list
    Key Operations:
        Uses precomputed rankings (e.g., from DESeq2)
        Permutation testing for significance
        Identifies enriched gene sets at the top and bottom of the ranking
    Visualization: Enrichment plots, network diagrams, dot plots

    Note: Enrichr uses a list of Entrez gene symbols as input.
    """
    kws_figsets = {}
    for k_arg, v_arg in kwargs.items():
        if "figset" in k_arg:
            kws_figsets = kwargs.pop(k_arg)
            break
    species_org = species
    # organism (str) – select one from {'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm'}
    organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
    species = ips.strcmp(species, organisms)[0]
    print(f"Please confirm sample species = '{species}'; if not, select one from {organisms}")
    if isinstance(gene_sets, str) and os.path.isfile(gene_sets):
        gene_sets_name = os.path.basename(gene_sets)
        gene_sets = ips.fload(gene_sets)
    else:
        lib_support_names = gp.get_library_name()
        # correct the input gene_set name
        gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]

        # download it
        if download:
            gene_sets = gp.get_library(name=gene_sets_name, organism=species)
        else:
            gene_sets = gene_sets_name  # avoid re-downloading
    print(f"\ngene_sets get ready: {gene_sets_name}")

    #! prerank
    try:
        # https://gseapy.readthedocs.io/en/latest/_modules/gseapy.html#prerank
        pre_res = gp.prerank(
            rnk=rnk,
            gene_sets=gene_sets,
            threads=threads,  # number of CPU cores to use
            permutation_num=permutation_num,  # number of permutations for significance
            min_size=min_size,  # minimum gene set size
            max_size=max_size,  # maximum gene set size
            weight=weight,  # refer to algorithm.enrichment_score()
            ascending=ascending,  # sorting order of rankings
            seed=seed,  # seed for reproducibility
            verbose=verbose,  # verbosity
        )
    except ValueError as e:
        print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}Check the rnk format: set 'gene name' as the index and keep only the 'score' column. This is the error message:\n{e}\n{'!'*10} Error {'!'*10}")
        return None
    df_prerank = pre_res.res2d
    if plot_:
        #! gseaplot
        # # (1) easy way
        # terms = df_prerank.Term
        # axs = pre_res.plot(terms=terms[0])
        # (2) for more control over the plot, use:
        terms = df_prerank.Term
        axs = pre_res.plot(
            terms=terms[:n_top],
            # legend_kws={"loc": (1.2, 0)},  # set the legend loc
            # show_ranking=True,  # whether to show the second y-axis
            figsize=(min(figsize), max(figsize)),
        )
        f_name_tmp = str(gene_sets)[:20] if len(str(gene_sets)) >= 20 else str(gene_sets)
        ips.figsave(dir_save + f"prerank_gseaplot_{f_name_tmp}.pdf")

        ## plot single prerank terms
        terms_ = pre_res.res2d.Term
        try:
            for i in range(n_top * 2):
                axs_ = pre_res.plot(terms=terms_[i])
                ips.figsave(os.path.join(ips.mkdir(dir_save, "fig_prerank_single"), f"Top_{str(i+1)}_{terms_[i].replace('/', '_')}.pdf"))
        except Exception as e:
            print(e)

        #! dotplot
        from gseapy import dotplot

        # to save the figure, make sure that ``ofname`` is not None
        ax = dotplot(
            df_prerank,
            column="NOM p-val",  # or "FDR q-val"
            cmap=cmap,
            size=size,
            figsize=(max(figsize), min(figsize)),
            cutoff=cutoff,
            show_ring=show_ring,
        )
        ips.figsave(dir_save + f"prerank_dotplot_{f_name_tmp}.pdf")

        #! network plot
        from gseapy import enrichment_map
        import networkx as nx

        for top_term in range(5, 50):
            try:
                # returns two dataframes
                nodes, edges = enrichment_map(
                    df=df_prerank,
                    columns="FDR q-val",
                    cutoff=0.25,  # 0.25 for "FDR q-val"; 0.05 for "NOM p-val"
                    top_term=top_term,
                )
                # build the graph
                G = nx.from_pandas_edgelist(
                    edges,
                    source="src_idx",
                    target="targ_idx",
                    edge_attr=["jaccard_coef", "overlap_coef", "overlap_genes"],
                )
                # check whether nodes.Hits_ratio matches the number of nodes
                if len(list(nodes.Hits_ratio)) == len(G.nodes):
                    node_sizes = list(nodes.Hits_ratio * 1000)
                else:
                    raise ValueError(
                        "The size of the node_size list does not match the number of nodes in the graph."
                    )

                layout = "circular"
                fig, ax = plt.subplots(figsize=(max(figsize), max(figsize)))
                if layout == "spring":
                    pos = nx.layout.spring_layout(G)
                elif layout == "circular":
                    pos = nx.layout.circular_layout(G)
                elif layout == "shell":
                    pos = nx.layout.shell_layout(G)
                elif layout == "spectral":
                    pos = nx.layout.spectral_layout(G)

                # node_size = nx.get_node_attributes()
                # draw nodes
                nx.draw_networkx_nodes(
                    G,
                    pos=pos,
                    cmap=plt.cm.RdYlBu,
                    node_color=list(nodes.NES),
                    node_size=list(nodes.Hits_ratio * 1000),
                )
                # draw node labels
                nx.draw_networkx_labels(
                    G,
                    pos=pos,
                    labels=nodes.Term.to_dict(),
                    font_size=8,
                    verticalalignment="bottom",
                )
                # draw edges
                edge_weight = nx.get_edge_attributes(G, "jaccard_coef").values()
                nx.draw_networkx_edges(
                    G,
                    pos=pos,
                    width=list(map(lambda x: x * 10, edge_weight)),
                    edge_color="#CDDBD4",
                )
                ax.set_axis_off()
                print(f"{gene_sets}(top_term={top_term})")
                plot.figsets(title=f"{gene_sets}(top_term={top_term})")
                ips.figsave(dir_save + f"prerank_network_{gene_sets}.pdf")
                break
            except Exception as e:
                print(f"top_term={top_term} did not work: {e}")
    return df_prerank, pre_res

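# * usage (a minimal sketch; `df_deg` with "gene" and "stat" columns is hypothetical —
#   any one-column ranking indexed by gene symbol works, e.g. a DESeq2 stat):
# rnk = df_deg.set_index("gene")[["stat"]].sort_values("stat", ascending=False)
# df_prerank, pre_res = get_prerank(rnk,
#                                   "KEGG_2021_Human",
#                                   species="human",
#                                   permutation_num=1000,
#                                   plot_=True,
#                                   dir_save="./")
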
def plot_prerank(
    results_df,
    kind="bar",  # 'barplot', 'dotplot'
    cutoff=0.25,
    show_ring=False,
    xticklabels_rot=0,
    title=None,  # 'KEGG'
    cmap="coolwarm",
    n_top=10,
    size=5,  # when size is None in the network plot, scaled by "NES"
    facecolor=None,  # default by "NES"
    linewidth=None,  # default by "NES"
    linecolor=None,  # default by "NES"
    linealpha=None,  # default by "NES"
    alpha=None,  # default by "NES"
    ax=None,
    **kwargs,
):
    kws_figsets = {}
    for k_arg, v_arg in kwargs.items():
        if "figset" in k_arg:
            kws_figsets = v_arg
            kwargs.pop(k_arg, None)
            break
    if isinstance(cmap, str):
        palette = plot.get_color(n_top, cmap=cmap)[::-1]
    elif isinstance(cmap, list):
        palette = cmap
    if n_top < 5:
        height_ = 4
    elif 5 <= n_top < 10:
        height_ = 5
    elif 10 <= n_top < 15:
        height_ = 6
    elif 15 <= n_top < 20:
        height_ = 7
    elif 20 <= n_top < 30:
        height_ = 8
    elif 30 <= n_top < 40:
        height_ = int(n_top / 5)
    else:
        height_ = int(n_top / 6)
    results_df["-log10(Adjusted P-value)"] = results_df["FDR q-val"].apply(lambda x: -np.log10(x))
    results_df["Count"] = results_df["Lead_genes"].apply(lambda x: len(x.split(";")))
    #! barplot
    if "bar" in kind.lower():
        df_ = results_df.sort_values(by="-log10(Adjusted P-value)", ascending=False)
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind_="barplot",
            x="-log10(Adjusted P-value)",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
        )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_

    #! dotplot
    elif "dot" in kind.lower():
        cutoff_curr = cutoff
        step = 0.05
        cutoff_stop = 0.5
        while cutoff_curr <= cutoff_stop:
            try:
                if cutoff_curr != cutoff:
                    plt.clf()
                ax = gp.dotplot(
                    results_df,
                    column="NOM p-val",
                    show_ring=show_ring,
                    xticklabels_rot=xticklabels_rot,
                    title=title,
                    cmap=cmap,
                    cutoff=cutoff_curr,
                    top_term=n_top,
                    size=size,
                    figsize=[10, height_],
                )
                if len(ax.collections) >= n_top:
                    print(f"cutoff={cutoff_curr} done! ")
                    break
                if cutoff_curr == cutoff_stop:
                    break
                cutoff_curr += step
            except Exception as e:
                cutoff_curr += step
                print(
                    f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} "
                )
        plot.figsets(ax=ax, **kws_figsets)
        return ax, results_df

    #! barplot with counts
    elif "co" in kind.lower():
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        # sort by the lead-gene count
        df_ = results_df.sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind_="barplot",
            x="Count",
            y="Term",
            hue="Term",
            palette=palette,
            legend=None,
            ax=ax,
            **kwargs,
        )

        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_
    #! scatter with counts
    elif "sca" in kind.lower():
        if isinstance(cmap, str):
            palette = plot.get_color(n_top, cmap=cmap)
        elif isinstance(cmap, list):
            palette = cmap
        if ax is None:
            _, ax = plt.subplots(1, 1, figsize=[10, height_])
        # sort by the lead-gene count
        df_ = results_df.sort_values(by="Count", ascending=False)
        ax = plot.plotxy(
            data=df_.head(n_top),
            kind_="scatter",
            x="Count",
            y="Term",
            hue="Count",
            size="Count",
            sizes=[10, 50],
            palette=palette,
            legend=None,
            ax=ax,
            **kwargs,
        )

        plot.figsets(ax=ax, **kws_figsets)
        return ax, df_
    elif "net" in kind.lower():
        #! network plot
        from gseapy import enrichment_map
        import networkx as nx
        from matplotlib import cm
        if cutoff >= 1 or cutoff is None:
            print(f"cutoff is {cutoff} => no filter applied")
            nodes, edges = enrichment_map(
                df=results_df,
                columns="NOM p-val",
                cutoff=1.1,  # 0.25 for "FDR q-val"; 0.05 for "NOM p-val"
                top_term=n_top,
            )
        else:
            cutoff_curr = cutoff
            step = 0.05
            cutoff_stop = 1.0
            while cutoff_curr <= cutoff_stop:
                try:
                    # returns two dataframes
                    nodes, edges = enrichment_map(
                        df=results_df,
                        columns="NOM p-val",
                        cutoff=cutoff_curr,  # 0.25 for "FDR q-val"; 0.05 for "NOM p-val"
                        top_term=n_top,
                    )

                    if nodes.shape[0] >= n_top:
                        print(f"cutoff={cutoff_curr} done! ")
                        break
                    if cutoff_curr == cutoff_stop:
                        break
                    cutoff_curr += step
                except Exception as e:
                    cutoff_curr += step
                    print(
                        f"{e}: trying cutoff={cutoff_curr}"
                    )

        print("size: by 'NES'") if size is None else print("")
        print("linewidth: by 'NES'") if linewidth is None else print("")
        print("linecolor: by 'NES'") if linecolor is None else print("")
        print("linealpha: by 'NES'") if linealpha is None else print("")
        print("facecolor: by 'NES'") if facecolor is None else print("")
        print("alpha: by '-log10(Adjusted P-value)'") if alpha is None else print("")
        edges.sort_values(by="jaccard_coef", ascending=False, inplace=True)
        colormap = cm.get_cmap(cmap)  # get the colormap
        G, ax = plot_ppi(
            interactions=edges,
            player1="src_name",
            player2="targ_name",
            weight="jaccard_coef",
            size=[
                node["NES"] * 300 for _, node in nodes.iterrows()
            ] if size is None else size,  # size nodes by NES
            facecolor=[colormap(node["NES"]) for _, node in nodes.iterrows()] if facecolor is None else facecolor,  # color by NES
            linewidth=[node["NES"] * 300 for _, node in nodes.iterrows()] if linewidth is None else linewidth,
            linecolor=[node["NES"] * 300 for _, node in nodes.iterrows()] if linecolor is None else linecolor,
            linealpha=[node["NES"] * 300 for _, node in nodes.iterrows()] if linealpha is None else linealpha,
            alpha=[node["NES"] * 300 for _, node in nodes.iterrows()] if alpha is None else alpha,
            **kwargs
        )
        return ax, G, nodes, edges

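# * usage (a minimal sketch; assumes `df_prerank` comes from get_prerank() above):
# ax, df_ = plot_prerank(df_prerank, kind="bar", n_top=10)
# ax, G, nodes, edges = plot_prerank(df_prerank, kind="network", cutoff=0.25, n_top=15)
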

#! https://string-db.org/help/api/

import pandas as pd
import requests
import networkx as nx
import matplotlib.pyplot as plt
from io import StringIO
from py2ls import ips

def get_ppi(
    target_genes: list,
    species: int = 9606,  # "human"
    ci: float = 0.1,  # confidence; scaled to an int in 1~1000 for the API
    max_nodes: int = 50,
    base_url: str = "https://string-db.org",
    gene_mapping_api: str = "/api/json/get_string_ids?",
    interaction_api: str = "/api/tsv/network?",
):
    """
    Purpose: Retrieves protein-protein interaction data from the STRING database
    Principle: API-based query to STRINGdb for experimentally validated and predicted interactions
    Key Operations:
        - Maps gene symbols to STRING identifiers
        - Filters by confidence score and species
        - Returns comprehensive interaction data with multiple evidence types
    Evidence Scores: Neighborhood, fusion, coexpression, experimental, database, textmining

    Generate a Protein-Protein Interaction (PPI) network using STRINGdb data.

    return:
        the STRING protein-protein interaction (PPI) data, which contains information about
        predicted and experimentally validated associations between proteins.

        stringId_A and stringId_B: Unique identifiers for the interacting proteins based on the
            STRING database.
        preferredName_A and preferredName_B: Standard gene names for the interacting proteins.
        ncbiTaxonId: The taxon ID (9606 for humans).
        score: A combined score reflecting the overall confidence of the interaction, which aggregates different sources of evidence.

        nscore, fscore, pscore, ascore, escore, dscore, tscore: These are sub-scores representing the confidence in the interaction based on various evidence types:
            - nscore: Neighborhood score, based on genes located near each other in the genome.
            - fscore: Fusion score, based on gene fusions in other genomes.
            - pscore: Phylogenetic profile score, based on co-occurrence across different species.
            - ascore: Coexpression score, reflecting the likelihood of coexpression.
            - escore: Experimental score, based on experimental evidence.
            - dscore: Database score, from curated databases.
            - tscore: Text-mining score, from literature co-occurrence.

        Higher score values (closer to 1) indicate stronger evidence for an interaction.
        - Combined score: Useful for ranking interactions based on overall confidence. A score >0.7 is typically considered high-confidence.
        - Sub-scores: Interpret the types of evidence supporting the interaction. For instance:
            - A high ascore indicates strong evidence of coexpression.
            - A high escore suggests experimental validation.
    """
    print("check api: https://string-db.org/help/api/")

    # convert a species name into its taxon_id
    if isinstance(species, str):
        print(species)
        species = list(get_taxon_id(species).values())[0]
        print(species)

    string_api_url = base_url + gene_mapping_api
    interaction_api_url = base_url + interaction_api
    # Map gene symbols to STRING IDs
    mapped_genes = {}
    for gene in target_genes:
        params = {"identifiers": gene, "species": species, "limit": 1}
        response = requests.get(string_api_url, params=params)
        if response.status_code == 200:
            try:
                json_data = response.json()
                if json_data:
                    mapped_genes[gene] = json_data[0]["stringId"]
            except ValueError:
                print(
                    f"Failed to decode JSON for gene {gene}. Response: {response.text}"
                )
        else:
            print(
                f"Failed to fetch data for gene {gene}. Status code: {response.status_code}"
            )
    if not mapped_genes:
        print("No mapped genes found in STRING database.")
        return None

    # Retrieve PPI data from the STRING API
    string_ids = "%0d".join(mapped_genes.values())
    params = {
        "identifiers": string_ids,
        "species": species,
        "required_score": int(ci * 1000),
        "limit": max_nodes,
    }
    response = requests.get(interaction_api_url, params=params)

    if response.status_code == 200:
        try:
            interactions = pd.read_csv(StringIO(response.text), sep="\t")
        except Exception as e:
            print("Error reading the interaction data:", e)
            print("Response content:", response.text)
            return None
    else:
        print(
            f"Failed to retrieve interaction data. Status code: {response.status_code}"
        )
        print("Response content:", response.text)
        return None
    display(interactions.head())
    # Filter interactions by the confidence score
    if "score" in interactions.columns:
        interactions = interactions[interactions["score"] >= ci]
        if interactions.empty:
            print("No interactions found with the specified confidence.")
            return None
    else:
        print("The 'score' column is missing from the retrieved data. Unable to filter by confidence.")
    if "fdr" in interactions.columns:
        interactions = interactions.sort_values(by="fdr", ascending=False)
    return interactions
# * usage
# interactions = get_ppi(target_genes, ci=0.0001)

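# * a short sketch of post-filtering by evidence type (gene list is hypothetical;
#   column names follow the STRING TSV output described in the docstring above):
# interactions = get_ppi(["TP53", "MDM2", "CDKN1A"], species=9606, ci=0.4)
# high_conf = interactions[interactions["score"] > 0.7]       # combined confidence
# experimental = interactions[interactions["escore"] > 0.5]   # experimentally supported
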
def plot_ppi(
    interactions,
    player1="preferredName_A",
    player2="preferredName_B",
    weight="score",
    n_layers=None,  # number of concentric layers
    n_rank=[5, 10],  # nodes in each rank for the concentric layout
    dist_node=10,  # distance between each rank of circles
    layout="degree",
    size=None,  # 700,
    sizes=(50, 500),  # min and max of size
    facecolor="skyblue",
    cmap='coolwarm',
    edgecolor="k",
    edgelinewidth=1.5,
    alpha=.5,
    alphas=(0.1, 1.0),  # min and max of alpha
    marker="o",
    node_hideticks=True,
    linecolor="gray",
    line_cmap='coolwarm',
    linewidth=1.5,
    linewidths=(0.5, 5),  # min and max of linewidth
    linealpha=1.0,
    linealphas=(0.1, 1.0),  # min and max of linealpha
    linestyle="-",
    line_arrowstyle='-',
    fontsize=10,
    fontcolor="k",
    ha: str = "center",
    va: str = "center",
    figsize=(12, 10),
    k_value=0.3,
    bgcolor="w",
    dir_save="./ppi_network.html",
    physics=True,
    notebook=False,
    scale=1,
    ax=None,
    **kwargs
):
    """
    Purpose: Network visualization of protein-protein interactions
    Principle: NetworkX and PyVis for interactive and static network visualization
    Layout Options:
        Spring: force-directed layout
        Circular: concentric circles
        Degree-based: nodes positioned by connectivity
    Customization: node size/color by centrality, edge thickness by confidence

    Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.
    """
    from pyvis.network import Network
    import networkx as nx
    from IPython.display import IFrame
    from matplotlib.colors import Normalize
    from matplotlib import cm
    # Check for required columns in the DataFrame
    for col in [player1, player2, weight]:
        if col not in interactions.columns:
            raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
    interactions.sort_values(by=[weight], inplace=True)
    # Initialize the Pyvis network
    net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
    net.force_atlas_2based(
        gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
    )
    net.toggle_physics(physics)

    kws_figsets = {}
    for k_arg, v_arg in kwargs.items():
        if "figset" in k_arg:
            kws_figsets = v_arg
            kwargs.pop(k_arg, None)
            break

    # Create a NetworkX graph from the interaction data
    G = nx.Graph()
    for _, row in interactions.iterrows():
        G.add_edge(row[player1], row[player2], weight=row[weight])
    # G = nx.from_pandas_edgelist(interactions, source=player1, target=player2, edge_attr=weight)

    # Calculate node degrees
    degrees = dict(G.degree())
    norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
    colormap = cm.get_cmap(cmap)  # get the colormap

    if not ips.isa(facecolor, 'color'):
        print("facecolor: based on degrees")
        facecolor = [colormap(norm(deg)) for deg in degrees.values()]  # use the colormap
    num_nodes = G.number_of_nodes()
    # * size
    # Set properties based on degrees
    if not isinstance(size, (int, float, list)):
        print("size: based on degrees")
        size = [deg * 50 for deg in degrees.values()]  # scale the sizes
    size = (size[:num_nodes] if len(size) > num_nodes else size) if isinstance(size, list) else [size] * num_nodes
    if isinstance(size, list) and len(ips.flatten(size, verbose=False)) != 1:
        # Normalize the sizes
        min_size, max_size = sizes  # use the sizes tuple for min and max values
        min_degree, max_degree = min(size), max(size)
        if max_degree > min_degree:  # avoid division by zero
            size = [
                min_size + (max_size - min_size) * (sz - min_degree) / (max_degree - min_degree)
                for sz in size
            ]
        else:
            # If all values are the same, default to the midpoint
            size = [(min_size + max_size) / 2] * len(size)

    # * facecolor
    facecolor = (facecolor[:num_nodes] if len(facecolor) > num_nodes else facecolor) if isinstance(facecolor, list) else [facecolor] * num_nodes
    # * facealpha
    if isinstance(alpha, list):
        alpha = (alpha[:num_nodes] if len(alpha) > num_nodes else alpha + [alpha[-1]] * (num_nodes - len(alpha)))
        min_alphas, max_alphas = alphas  # use the alphas tuple for min and max values
        if len(alpha) > 0:
            # Normalize alpha based on the specified min and max
            min_alpha, max_alpha = min(alpha), max(alpha)
            if max_alpha > min_alpha:  # avoid division by zero
                alpha = [
                    min_alphas + (max_alphas - min_alphas) * (ea - min_alpha) / (max_alpha - min_alpha)
                    for ea in alpha
                ]
            else:
                # If all alpha values are the same, set them to the average of min and max
                alpha = [(min_alphas + max_alphas) / 2] * len(alpha)
        else:
            # Default to full opacity if no values are provided
            alpha = [1.0] * num_nodes
    else:
        # If alpha is a single value, convert it to a list
        alpha = [alpha] * num_nodes

    for i, node in enumerate(G.nodes()):
        net.add_node(
            node,
            label=node,
            size=size[i],
            color=facecolor[i],
            alpha=alpha[i],
            font={"size": fontsize, "color": fontcolor},
        )
    print(f'nodes number: {i+1}')

    for edge in G.edges(data=True):
        net.add_edge(
            edge[0],
            edge[1],
            weight=edge[2]["weight"],
            color=edgecolor,
            width=edgelinewidth * edge[2]["weight"],
        )

    layouts = [
        "spring",
        "circular",
        "kamada_kawai",
        "random",
        "shell",
        "planar",
        "spiral",
        "degree"
    ]
    layout = ips.strcmp(layout, layouts)[0]
    print(f"layout: {layout}, or select one from {layouts}")

    # Choose the layout
    if layout == "spring":
        pos = nx.spring_layout(G, k=k_value)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "random":
        pos = nx.random_layout(G)
    elif layout == "shell":
        pos = nx.shell_layout(G)
    elif layout == "planar":
        if nx.check_planarity(G)[0]:
            pos = nx.planar_layout(G)
        else:
            print("Graph is not planar; switching to spring layout.")
            pos = nx.spring_layout(G, k=k_value)
    elif layout == "spiral":
        pos = nx.spiral_layout(G)
    elif layout == 'degree':
        # Calculate node degrees and sort nodes by degree
        degrees = dict(G.degree())
        sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
        norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
        colormap = cm.get_cmap(cmap)

        # Create positions for concentric circles based on n_layers and n_rank
        pos = {}
        n_layers = len(n_rank) + 1 if n_layers is None else n_layers
        for rank_index in range(n_layers):
            if rank_index < len(n_rank):
                nodes_per_rank = n_rank[rank_index]
                rank_nodes = sorted_nodes[sum(n_rank[:rank_index]): sum(n_rank[:rank_index + 1])]
            else:
                # shuffle the order of the remaining nodes
                remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]):]
                random_indices = np.random.permutation(len(remaining_nodes))
                rank_nodes = [remaining_nodes[i] for i in random_indices]

            radius = (rank_index + 1) * dist_node  # radius for this rank

            # Arrange nodes in a circle for the current rank
            for i, (node, degree) in enumerate(rank_nodes):
                angle = (i / len(rank_nodes)) * 2 * np.pi  # distribute around the circle
                pos[node] = (radius * np.cos(angle), radius * np.sin(angle))

    else:
        print(f"Unknown layout '{layout}', defaulting to 'spring'; alternatives: {layouts}")
        pos = nx.spring_layout(G, k=k_value)

    for node, (x, y) in pos.items():
        net.get_node(node)["x"] = x * scale
        net.get_node(node)["y"] = y * scale

    # If ax is None, create a new figure
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Draw nodes, edges, and labels with customization options
    nx.draw_networkx_nodes(
        G,
        pos,
        ax=ax,
        node_size=size,
        node_color=facecolor,
        linewidths=edgelinewidth,
        edgecolors=edgecolor,
        alpha=alpha,
        hide_ticks=node_hideticks,
        node_shape=marker
    )

    # * linewidth
    if not isinstance(linewidth, list):
        linewidth = [linewidth] * G.number_of_edges()
    else:
        linewidth = (linewidth[:G.number_of_edges()] if len(linewidth) > G.number_of_edges() else linewidth + [linewidth[-1]] * (G.number_of_edges() - len(linewidth)))
    # Normalize linewidth if it is a list
    if isinstance(linewidth, list):
        min_linewidth, max_linewidth = min(linewidth), max(linewidth)
        vmin, vmax = linewidths  # use the linewidths tuple for min and max values
        if max_linewidth > min_linewidth:  # avoid division by zero
            # Scale between vmin and vmax
            linewidth = [
                vmin + (vmax - vmin) * (lw - min_linewidth) / (max_linewidth - min_linewidth)
                for lw in linewidth
            ]
        else:
            # If all values are the same, default to the midpoint
            linewidth = [(vmin + vmax) / 2] * len(linewidth)
    else:
        # If linewidth is a single value, convert it to a list of that value
        linewidth = [linewidth] * G.number_of_edges()
    # * linecolor
    if not isinstance(linecolor, str):
        weights = [G[u][v]["weight"] for u, v in G.edges()]
        norm = Normalize(vmin=min(weights), vmax=max(weights))
        colormap = cm.get_cmap(line_cmap)
        linecolor = [colormap(norm(weight)) for weight in weights]
    else:
        linecolor = [linecolor] * G.number_of_edges()

    # * linealpha
    if isinstance(linealpha, list):
        linealpha = (linealpha[:G.number_of_edges()] if len(linealpha) > G.number_of_edges() else linealpha + [linealpha[-1]] * (G.number_of_edges() - len(linealpha)))
        min_alpha, max_alpha = linealphas  # use the linealphas tuple for min and max values
        if len(linealpha) > 0:
            min_linealpha, max_linealpha = min(linealpha), max(linealpha)
            if max_linealpha > min_linealpha:  # avoid division by zero
                linealpha = [
                    min_alpha + (max_alpha - min_alpha) * (ea - min_linealpha) / (max_linealpha - min_linealpha)
                    for ea in linealpha
                ]
            else:
                linealpha = [(min_alpha + max_alpha) / 2] * len(linealpha)
        else:
            linealpha = [1.0] * G.number_of_edges()  # fall back to full opacity if misconfigured
    else:
        linealpha = [linealpha] * G.number_of_edges()  # convert to a list if a single value
    nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        edge_color=linecolor,
        width=linewidth,
        style=linestyle,
        arrowstyle=line_arrowstyle,
        alpha=linealpha
    )

    nx.draw_networkx_labels(
        G, pos, ax=ax, font_size=fontsize, font_color=fontcolor, horizontalalignment=ha, verticalalignment=va
    )
    plot.figsets(ax=ax, **kws_figsets)
    ax.axis("off")
    if dir_save:
        if not os.path.basename(dir_save):
            dir_save = "_.html"
        net.write_html(dir_save)
        nx.write_graphml(G, dir_save.replace(".html", ".graphml"))  # export to GraphML
        print(f"could be edited in Cytoscape \n{dir_save.replace('.html', '.graphml')}")
        ips.figsave(dir_save.replace(".html", ".pdf"))
    return G, ax


# * usage:
# G, ax = bio.plot_ppi(
#     interactions,
#     player1="preferredName_A",
#     player2="preferredName_B",
#     weight="score",
#     # size="auto",
#     # size=interactions["score"].tolist(),
#     # layout="circ",
#     n_rank=[5, 10, 15],
#     dist_node=100,
#     alpha=0.6,
#     linecolor="0.8",
#     linewidth=1,
#     figsize=(8, 8.5),
#     cmap="jet",
#     edgelinewidth=0.5,
#     # facecolor="#FF5F57",
#     fontsize=10,
#     # fontcolor="b",
#     # edgecolor="r",
#     # scale=100,
#     # physics=False,
#     figsets=dict(title="ppi networks"),
# )
# figsave("./ppi_network.pdf")

def top_ppi(interactions, n_top=10):
    """
    Purpose: Identifies key proteins in interaction networks using centrality measures
    Centrality Metrics:
        Degree: number of direct connections
        Betweenness: bridge positions in the network
    Usage: Prioritization of biologically important proteins

    Analyzes protein-protein interactions (PPIs) to identify key proteins based on
    degree and betweenness centrality.

    Parameters:
        interactions (pd.DataFrame): DataFrame containing PPI data with columns
            ['preferredName_A', 'preferredName_B', 'score'].

    Returns:
        tuple: the top key proteins by degree centrality and by betweenness centrality.
    """

    # Create a NetworkX graph from the interaction data
    G = nx.Graph()
    for _, row in interactions.iterrows():
        G.add_edge(row["preferredName_A"], row["preferredName_B"], weight=row["score"])

    # Calculate degree centrality
    degree_centrality = G.degree()
    key_proteins_degree = sorted(degree_centrality, key=lambda x: x[1], reverse=True)

    # Calculate betweenness centrality
    betweenness_centrality = nx.betweenness_centrality(G)
    key_proteins_betweenness = sorted(
        betweenness_centrality.items(), key=lambda x: x[1], reverse=True
    )
    print(
        {
            "Top 10 Key Proteins by Degree Centrality": key_proteins_degree[:10],
            "Top 10 Key Proteins by Betweenness Centrality": key_proteins_betweenness[
                :10
            ],
        }
    )
    # Return the top n_top key proteins
    if n_top == "all":
        return key_proteins_degree, key_proteins_betweenness
    else:
        return key_proteins_degree[:n_top], key_proteins_betweenness[:n_top]


# * usage: top_ppi(interactions)
# top_ppi(interactions, n_top="all")
# top_ppi(interactions, n_top=10)


species_dict = {
    "Human": "Homo sapiens",
    "House mouse": "Mus musculus",
    "Zebrafish": "Danio rerio",
    "Norway rat": "Rattus norvegicus",
    "Fruit fly": "Drosophila melanogaster",
    "Baker's yeast": "Saccharomyces cerevisiae",
    "Nematode": "Caenorhabditis elegans",
    "Chicken": "Gallus gallus",
    "Cattle": "Bos taurus",
    "Rice": "Oryza sativa",
    "Thale cress": "Arabidopsis thaliana",
    "Guinea pig": "Cavia porcellus",
    "Domestic dog": "Canis lupus familiaris",
    "Domestic cat": "Felis catus",
    "Horse": "Equus caballus",
    "Domestic pig": "Sus scrofa",
    "African clawed frog": "Xenopus laevis",
    "Great white shark": "Carcharodon carcharias",
    "Common chimpanzee": "Pan troglodytes",
    "Rhesus macaque": "Macaca mulatta",
    "Water buffalo": "Bubalus bubalis",
    "Lettuce": "Lactuca sativa",
    "Tomato": "Solanum lycopersicum",
    "Maize": "Zea mays",
    "Cucumber": "Cucumis sativus",
    "Common grape vine": "Vitis vinifera",
    "Scots pine": "Pinus sylvestris",
}


def get_taxon_id(species_list):
    """
    Purpose: Converts species names to NCBI taxonomy IDs
    Principle: BioPython Entrez queries to the taxonomy database
    Supported: 25+ common model organisms
    Convert species names to their corresponding taxon ID codes.

    Parameters:
        - species_list: List of species names (strings).

    Returns:
        - dict: A dictionary with species names as keys and their taxon IDs as values.
    """
    from Bio import Entrez

    if not isinstance(species_list, list):
        species_list = [species_list]
    species_list = [
        species_dict[ips.strcmp(i, ips.flatten(list(species_dict.keys())))[0]]
        for i in species_list
    ]
    taxon_dict = {}

    for species in species_list:
        try:
            search_handle = Entrez.esearch(db="taxonomy", term=species)
            search_results = Entrez.read(search_handle)
            search_handle.close()

            # Get the taxon ID
            if search_results["IdList"]:
                taxon_id = search_results["IdList"][0]
                taxon_dict[species] = int(taxon_id)
            else:
                taxon_dict[species] = None  # not found
        except Exception as e:
            print(f"Error occurred for species '{species}': {e}")
            taxon_dict[species] = None  # error in processing
    return taxon_dict


# # * usage: get_taxon_id("human")
# species_names = ["human", "mouse", "rat"]
# taxon_ids = get_taxon_id(species_names)
# print(taxon_ids)
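
# * end-to-end sketch (hypothetical gene list; NCBI usage policy asks for an Entrez email):
# from Bio import Entrez
# Entrez.email = "you@example.com"          # required before Entrez queries
# taxon = get_taxon_id("human")             # e.g. {"Homo sapiens": 9606}
# ppi = get_ppi(["TP53", "MDM2", "EP300"], species=9606, ci=0.4)
# G, ax = plot_ppi(ppi, layout="degree", n_rank=[5, 10])
# hubs_deg, hubs_btw = top_ppi(ppi, n_top=10)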