py2ls 0.2.4.3__py3-none-any.whl → 0.2.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
py2ls/bio.py CHANGED
@@ -1,8 +1,12 @@
  import GEOparse
+ import gseapy as gp
  from typing import Union
  import pandas as pd
+ import numpy as np
  import os
  import logging
+
+ from sympy import use
  from . import ips
  from . import plot
  import matplotlib.pyplot as plt
@@ -123,11 +127,32 @@ def get_meta(geo: dict, dataset: str = "GSE25097", verbose=True) -> pd.DataFrame
 
      # Convert the list of dictionaries to a DataFrame
      meta_df = pd.DataFrame(meta_list)
+     col_rm = [
+         "channel_count",
+         "contact_web_link",
+         "contact_address",
+         "contact_city",
+         "contact_country",
+         "contact_department",
+         "contact_email",
+         "contact_institute",
+         "contact_laboratory",
+         "contact_name",
+         "contact_phone",
+         "contact_state",
+         "contact_zip/postal_code",
+         "contributor",
+         "manufacture_protocol",
+         "taxid",
+         "web_link",
+     ]
+     # remove irrelevant columns
+     meta_df = meta_df.drop(columns=[col for col in col_rm if col in meta_df.columns])
      if verbose:
          print(
              f"Meta info columns for dataset '{dataset}': \n{sorted(meta_df.columns.tolist())}"
          )
-         display(meta_df[:3].T)
+         display(meta_df[:1].T)
      return meta_df
 
 
@@ -142,13 +167,14 @@ def get_probe(
      df_meta = get_meta(geo=geo, dataset=dataset, verbose=False)
      platform_id = df_meta["platform_id"].unique().tolist()
      platform_id = platform_id[0] if len(platform_id) == 1 else platform_id
-     print(platform_id)
+     print(f"Platform: {platform_id}")
      df_probe = geo[dataset].gpls[platform_id].table
      if df_probe.empty:
          print(
-             f"above is meta info, failed to find the probe info; it may be stored in a separate file"
+             f"Warning: cannot find the probe info; it may be stored in a separate file"
          )
-         return get_meta(geo, dataset, verbose=True)
+         display(f"🔗: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc={platform_id}")
+         return get_meta(geo, dataset, verbose=verbose)
      if verbose:
          print(f"columns in the probe table: \n{sorted(df_probe.columns.tolist())}")
      return df_probe
@@ -181,19 +207,25 @@ def get_expression_data(geo: dict, dataset: str = "GSE25097") -> pd.DataFrame:
      return expression_values
 
 
- def get_data(geo: dict, dataset: str = "GSE25097", verbose=True):
+ def get_data(geo: dict, dataset: str = "GSE25097", verbose=False):
+     print(f"\n\ndataset: {dataset}\n")
      # get probe info
      df_probe = get_probe(geo, dataset=dataset, verbose=False)
      # get expression values
      df_expression = get_expression_data(geo, dataset=dataset)
-     print(
-         f"df_expression.shape: {df_expression.shape} \ndf_probe.shape: {df_probe.shape}"
-     )
+     if not df_expression.select_dtypes(include=["number"]).empty:
+         # if the data are raw read counts, normalize (transform) them via 'TMM'
+         if 'counts' in get_data_type(df_expression):
+             print(f"{dataset}'s type is raw read counts, normalized (transformed) via 'TMM'")
+             df_expression=counts2expression(df_expression.T).T
      if any([df_probe.empty, df_expression.empty]):
          print(
-             f"above is meta info, failed to find the probe info; it may be stored in a separate file"
+             f"got empty values, check the probe info; it may be stored in a separate file"
          )
          return get_meta(geo, dataset, verbose=True)
+     print(
+         f"\n\tdf_expression.shape: {df_expression.shape} \n\tdf_probe.shape: {df_probe.shape}"
+     )
      df_exp = pd.merge(
          df_probe,
          df_expression,
@@ -237,14 +269,39 @@ def get_data(geo: dict, dataset: str = "GSE25097", verbose=True):
      df_exp.set_index(col_gene_symbol, inplace=True)
      df_exp = df_exp[col_gsm].T  # transpose, so that meta info can be added
 
-     df_merged = ips.df_merge(df_meta, df_exp)
+     df_merged = ips.df_merge(df_meta, df_exp,use_index=True)
+
+     print(
+         f"\ndataset:'{dataset}' n_sample = {df_merged.shape[0]}, n_gene={df_exp.shape[1]}"
+     )
      if verbose:
-         print(
-             f"\ndataset:'{dataset}' n_sample = {df_merged.shape[0]}, n_gene={df_exp.shape[1]}"
-         )
          display(df_merged.sample(5))
      return df_merged
 
+ def get_data_type(data: pd.DataFrame) -> str:
+     """
+     Determine the type of data: 'read counts' or 'normalized expression data'.
+     usage:
+         get_data_type(df_counts)
+     """
+     numeric_data = data.select_dtypes(include=["number"])
+     if numeric_data.empty:
+         raise ValueError(f"No numeric data found; please convert the data first")
+     # Check if the data contains only integers
+     if numeric_data.apply(lambda x: x.dtype == "int").all():
+         # Check for values typically found in raw read counts (large integers)
+         if numeric_data.max().max() > 10000:  # Threshold for raw counts
+             return "read counts"
+     # Check if all values are floats
+     if numeric_data.apply(lambda x: x.dtype == "float").all():
+         # If values are small, it's likely normalized data
+         if numeric_data.max().max() < 1000:  # Threshold for normalized expression
+             return "normalized expression data"
+         else:
+             print(f"the max value: {numeric_data.max().max()}; this could be raw read counts, but double-check it")
+             return "read counts"
+     # If mixed data types or unexpected values
+     return "mixed or unknown"
 
  def split_at_lower_upper(lst):
      """
@@ -261,6 +318,13 @@ def split_at_lower_upper(lst):
          return lst[: i + 1], lst[i + 1 :]
      return lst, []
 
+ def find_condition(data:pd.DataFrame, columns=["characteristics_ch1","title"]):
+     if data.shape[1]>=data.shape[0]:
+         display(data.iloc[:1,:40].T)
+     # inspect the categories in each column; entries containing numbers are removed
+     for col in columns:
+         print(f"{'='*10} {col} {'='*10}")
+         display(ips.flatten([ips.ssplit(i, by="numer")[0] for i in data[col]],verbose=False))
 
  def add_condition(
      data: pd.DataFrame,
@@ -305,7 +369,7 @@ def add_condition(
          lambda x: by_not_name if not by_not in x else by_name
      )
      if verbose:
-         display(data)
+         display(data.sample(5))
      if not inplace:
          return data
 
@@ -370,13 +434,13 @@ def add_condition_multi(
 
      # Display the updated DataFrame if verbose is True
      if verbose:
-         display(data)
+         display(data.sample(5))
 
      if not inplace:
          return data
 
  def clean_dataset(
-     data: pd.DataFrame, dataset: str = "GSE25097", condition: str = "condition",sep="///"
+     data: pd.DataFrame, dataset: str = None, condition: str = "condition",sep="///"
  ):
      """
      #* it is also used by bio.batch_effect(), but default: False
@@ -386,6 +450,14 @@ def clean_dataset(
      4. add the 'condition' and 'dataset info' to the columns
      5. set genes as index
      """
+     usage_str="""clean_dataset(data: pd.DataFrame, dataset: str = None, condition: str = "condition",sep="///")
+     """
+     if dataset is None:
+         try:
+             dataset=data["dataset"][0]
+         except:
+             print("cannot find 'dataset' name")
+             print(f"example\n {usage_str}")
      #! (4.1) clean the dataset and prepare super_datasets
      # df_data_2: meta columns on the left, gene_symbol columns on the right
      col_gene = split_at_lower_upper(data.columns.tolist())[1][0]
@@ -420,7 +492,7 @@ def clean_dataset(
      return df_gene
 
  def batch_effect(
-     data: list = "[df_gene_1, df_gene_2, df_gene_3]",
+     data: list = "[df_gene_1, df_gene_2, df_gene_3]",  # index (genes), columns (samples)
      datasets: list = ["GSE25097", "GSE62232", "GSE65372"],
      clean_data:bool=False,  # default: no data cleaning
      top_genes:int=10,  # only for plotting
@@ -509,5 +581,1298 @@ def batch_effect(
      return df_corrected
 
  def get_common_genes(elment1, elment2):
-     common_genes=ips.shared(elment1, elment2)
-     return common_genes
+     common_genes=ips.shared(elment1, elment2,verbose=False)
+     return common_genes
+
+ def counts2expression(
+     counts: pd.DataFrame,  # index(samples); columns(genes)
+     method: str = "TMM",  # 'CPM', 'FPKM', 'TPM', 'UQ', 'TMM', 'CUF', 'CTF'
+     length: list = None,
+     uq_factors: pd.Series = None,
+     verbose: bool = False,
+ ) -> pd.DataFrame:
+     """
+     https://www.linkedin.com/pulse/snippet-corner-raw-read-count-normalization-python-mazzalab-gzzyf?trk=public_post
+     Convert raw RNA-seq read counts to expression values
+     counts: pd.DataFrame
+         index: samples
+         columns: genes
+     usage:
+         df_normalized = counts2expression(df_counts, method='TMM', verbose=True)
+     recommendations:
+         cross-datasets:
+             TMM (Trimmed Mean of M-values); very suitable for merging datasets, especially
+                 for cross-sample and cross-dataset comparisons; commonly used in
+                 differential expression analysis
+             CTF (Counts adjusted with TMM factors); suitable for merging datasets, as it is
+                 TMM-based normalization. Typically used as input for downstream analyses
+                 like differential expression
+             TPM (Transcripts Per Million); good for merging datasets. TPM is often more
+                 suitable for cross-dataset comparisons because it adjusts for gene length
+                 and ensures that the expression levels sum to the same total in each sample
+             UQ (Upper Quartile); less commonly used than TPM or TMM
+             CUF (Counts adjusted with UQ factors); can be used, but UQ normalization is
+                 generally not as standardized as TPM or TMM for merging datasets.
+         within-datasets:
+             CPM (Counts Per Million); it doesn't adjust for gene length or other
+                 variables that could vary across datasets
+             FPKM (Fragments Per Kilobase Million); FPKM has been known to be inconsistent
+                 across different experiments
+     Parameters:
+     - counts: pd.DataFrame
+         Raw read counts with samples as rows and genes as columns.
+     - method: str, default='TMM'
+         CPM (Counts per Million): Scales counts by total library size.
+         FPKM (Fragments per Kilobase Million): Requires gene length; scales by both library size and gene length.
+         TPM (Transcripts per Million): Scales by gene length and total transcript abundance.
+         UQ (Upper Quartile): Normalizes based on the upper quartile of the counts.
+         TMM (Trimmed Mean of M-values): Adjusts for compositional biases.
+         CUF (Counts adjusted with Upper Quartile factors): Counts adjusted based on UQ factors.
+         CTF (Counts adjusted with TMM factors): Counts adjusted based on TMM factors.
+     - length: pd.Series or list, optional
+         Gene lengths (e.g., in kilobases) for FPKM/TPM normalization. Required for FPKM/TPM.
+     - verbose: bool, default=False
+         If True, provides detailed logging information.
+     - uq_factors: pd.Series, optional
+         Precomputed Upper Quartile factors, required for UQ and CUF normalization.
+
+     Returns:
+     - normalized_counts: pd.DataFrame
+         Normalized expression values.
+     """
+     import rnanorm
+     print(f"'counts' data should be: index(samples); columns(genes)")
+     if "length" in method:  # in case the many method names are hard to remember
+         method="FPKM"
+     methods = ["CPM", "FPKM", "TPM", "UQ", "TMM", "CUF", "CTF"]
+     method = ips.strcmp(method, methods)[0]
+     if verbose:
+         print(
+             f"Starting normalization using method: {method}, supported methods: {methods}"
+         )
+     columns_org = counts.columns.tolist()
+     # Check if gene lengths are provided when necessary
+     if method in ["FPKM", "TPM"]:
+         if length is None:
+             raise ValueError(f"Gene lengths must be provided for {method} normalization.")
+         if isinstance(length, list):
+             df_genelength = pd.DataFrame({"gene_length": length})
+             df_genelength.index = counts.columns  # set gene_id as index
+             df_genelength.index = df_genelength.index.astype(str).str.strip()
+             # length = np.array(df_genelength["gene_length"]).reshape(1,-1)
+             length = df_genelength["gene_length"]
+             counts.index = counts.index.astype(str).str.strip()
+         elif isinstance(length, pd.Series):
+
+             length.index=length.index.astype(str).str.strip()
+             counts.columns = counts.columns.astype(str).str.strip()
+             shared_genes=ips.shared(length.index, counts.columns,verbose=False)
+             length=length.loc[shared_genes]
+             counts=counts.loc[:,shared_genes]
+             columns_org = counts.columns.tolist()
+
+
+     # # Ensure gene lengths are aligned with counts if provided
+     # if length is not None:
+     #     length = length[counts.index]
+
+     # Start the normalization based on the chosen method
+     if method == "CPM":
+         normalized_counts = (
+             rnanorm.CPM().set_output(transform="pandas").fit_transform(counts)
+         )
+
+     elif method == "FPKM":
+         if verbose:
+             print("Performing FPKM normalization using gene lengths.")
+         normalized_counts = (
+             rnanorm.CPM().set_output(transform="pandas").fit_transform(counts)
+         )
+         # convert CPM to FPKM via {FPKM = CPM / gene length × 1e3}, applied with row-wise division and multiplication
+         normalized_counts=normalized_counts.div(length.values,axis=1)*1e3
+
+     elif method == "TPM":
+         if verbose:
+             print("Performing TPM normalization using gene lengths.")
+         normalized_counts = (
+             rnanorm.TPM(gene_lengths=length)
+             .set_output(transform="pandas")
+             .fit_transform(counts)
+         )
+
+     elif method == "UQ":
+         if verbose:
+             print("Performing Upper Quartile (UQ) normalization.")
+         if uq_factors is None:
+             uq_factors = rnanorm.upper_quartile_factors(counts)
+         normalized_counts = (
+             rnanorm.UQ(factors=uq_factors)()
+             .set_output(transform="pandas")
+             .fit_transform(counts)
+         )
+
+     elif method == "TMM":
+         if verbose:
+             print("Performing TMM normalization (Trimmed Mean of M-values).")
+         normalized_counts = (
+             rnanorm.TMM().set_output(transform="pandas").fit_transform(counts)
+         )
+
+     elif method == "CUF":
+         if verbose:
+             print("Performing Counts adjusted with UQ factors (CUF).")
+         if uq_factors is None:
+             uq_factors = rnanorm.upper_quartile_factors(counts)
+         normalized_counts = (
+             rnanorm.CUF(factors=uq_factors)()
+             .set_output(transform="pandas")
+             .fit_transform(counts)
+         )
+
+     elif method == "CTF":
+         if verbose:
+             print("Performing Counts adjusted with TMM factors (CTF).")
+         normalized_counts = (rnanorm.CTF().set_output(transform="pandas").fit_transform(counts))
+
+     else:
+         raise ValueError(f"Unknown normalization method: {method}")
+     normalized_counts.columns=columns_org
+     if verbose:
+         print(f"Normalization complete using method: {method}")
+
+     return normalized_counts
+
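A minimal usage sketch for counts2expression above (sample and gene names are made up; requires the rnanorm package):

    import pandas as pd

    # samples as rows, genes as columns, as the function expects
    df_counts = pd.DataFrame(
        {"geneA": [120, 3400, 560], "geneB": [80, 900, 1200]},
        index=["sample1", "sample2", "sample3"],
    )
    df_tmm = counts2expression(df_counts, method="TMM", verbose=True)
    print(df_tmm.shape)  # (3, 2), same orientation as the input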
+ def counts_deseq(counts_sam_gene: pd.DataFrame,
+                  meta_sam_cond: pd.DataFrame,
+                  design_factors:list=None,
+                  kws_DeseqDataSet:dict={},
+                  kws_DeseqStats:dict={}):
+     """
+     https://pydeseq2.readthedocs.io/en/latest/api/docstrings/pydeseq2.ds.DeseqStats.html
+     Note: Using normalized expression data in a DeseqDataSet object is generally not recommended
+     because the DESeq2 framework is designed to work with raw count data.
+     baseMean:
+         - This value represents the average normalized count (or expression level) of a
+           gene across all samples in your dataset.
+         - For example, a baseMean of 0.287 for 4933401J01Rik indicates that this gene has
+           low expression levels in the samples compared to others with higher baseMean
+           values like Xkr4 (591.015).
+     log2FoldChange: the magnitude and direction of change in expression between conditions.
+     lfcSE (Log Fold Change Standard Error): standard error of the log2FoldChange. It
+         indicates the uncertainty in the estimate of the fold change. A lower value indicates
+         more confidence in the fold change estimate.
+     padj: This value accounts for multiple testing corrections (e.g., Benjamini-Hochberg).
+     Log10 transforms: The columns -log10(pvalue) and -log10(FDR) are transformations of
+         the p-values and adjusted p-values, respectively
+     """
+     from pydeseq2.dds import DeseqDataSet
+     from pydeseq2.ds import DeseqStats
+     from pydeseq2.default_inference import DefaultInference
+
+     # data filtering
+     # counts_sam_gene = counts_sam_gene.loc[:, ~(counts_sam_gene.sum(axis=0) < 10)]
+     if design_factors is None:
+         design_factors=meta_sam_cond.columns.tolist()
+
+     kws_DeseqDataSet.pop("design_factors",{})
+     refit_cooks=kws_DeseqDataSet.pop("refit_cooks",True)
+
+     #! DeseqDataSet
+     inference = DefaultInference(n_cpus=8)
+     dds = DeseqDataSet(
+         counts=counts_sam_gene,
+         metadata=meta_sam_cond,
+         design_factors=meta_sam_cond.columns.tolist(),
+         refit_cooks=refit_cooks,
+         inference=inference,
+         **kws_DeseqDataSet
+     )
+     dds.deseq2()
+     #* results
+     dds_explain="""
+     res[0]:
+         # X stores the count data,
+         # obs stores design factors,
+         # obsm stores sample-level data, such as "design_matrix" and "size_factors",
+         # varm stores gene-level data, such as "dispersions" and "LFC"."""
+     print(dds_explain)
+     #! DeseqStats
+     stat_res = DeseqStats(dds,**kws_DeseqStats)
+     stat_res.summary()
+     diff = stat_res.results_df.assign(padj=lambda x: x.padj.fillna(1))
+
+     # handle the '0' issue, which would cause inf in later calculations (e.g., log10)
+     diff["padj"] = diff["padj"].replace(0, 1e-10)
+     diff["pvalue"] = diff["pvalue"].replace(0, 1e-10)
+
+     diff["-log10(pvalue)"] = diff["pvalue"].apply(lambda x: -np.log10(x))
+     diff["-log10(FDR)"] = diff["padj"].apply(lambda x: -np.log10(x))
+     diff=diff.reset_index().rename(columns={"index": "gene"})
+     # sig_diff = (
+     #     diff.query("log2FoldChange.abs()>0.585 & padj<0.05")
+     #     .reset_index()
+     #     .rename(columns={"index": "gene"})
+     # )
+     return dds, diff,stat_res
+
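A minimal sketch of calling counts_deseq above on synthetic Poisson counts (sample and gene names are made up; pydeseq2 must be installed, and its exact constructor keywords vary by version):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(1)
    # 6 samples x 100 genes of raw integer counts
    counts = pd.DataFrame(rng.poisson(50, size=(6, 100)),
                          index=[f"s{i}" for i in range(6)],
                          columns=[f"gene{i}" for i in range(100)])
    # one metadata row per sample, aligned with counts.index
    meta = pd.DataFrame({"condition": ["ctrl", "ctrl", "ctrl", "treat", "treat", "treat"]},
                        index=counts.index)
    dds, diff, stat_res = counts_deseq(counts, meta)
    print(diff[["gene", "log2FoldChange", "padj"]].head())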
+ def scope_genes(gene_list: list, scopes:str=None, fields: str = "symbol", species="human"):
+     """
+     usage:
+         scope_genes(df_counts.columns.tolist()[:1000], species="mouse")
+     """
+     import mygene
+
+     if scopes is None:
+         # copied from: https://docs.mygene.info/en/latest/doc/query_service.html#scopes
+         scopes = ips.fload(
+             "/Users/macjianfeng/Dropbox/github/python/py2ls/py2ls/data/mygenes_fields_241022.txt",
+             kind="csv",
+             verbose=False,
+         )
+         scopes = ",".join([i.strip() for i in scopes.iloc[:, 0]])
+     mg = mygene.MyGeneInfo()
+     results = mg.querymany(
+         gene_list,
+         scopes=scopes,
+         fields=fields,
+         species=species,
+     )
+     return pd.DataFrame(results)
+
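For reference, what scope_genes wraps: a direct mygene query; a minimal sketch (the fields and gene symbols are illustrative):

    import mygene
    import pandas as pd

    mg = mygene.MyGeneInfo()
    hits = mg.querymany(["CDK2", "TP53"], scopes="symbol",
                        fields="entrezgene,name", species="human")
    # each hit is a dict, e.g. {"query": "CDK2", "entrezgene": "1017", ...}
    print(pd.DataFrame(hits))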
+ def get_enrichr(gene_symbol_list,
+                 gene_sets:str,
+                 download:bool = False,
+                 species='Human',
+                 dir_save="./",
+                 plot_=False,
+                 n_top=30,
+                 palette=None,
+                 check_shared=True,
+                 figsize=(5,8),
+                 show_ring=False,
+                 xticklabels_rot=0,
+                 title=None,  # 'KEGG'
+                 cutoff=0.05,
+                 cmap="coolwarm",
+                 size=5,
+                 **kwargs):
+     """
+     Note: Enrichr uses a list of Entrez gene symbols as input.
+
+     """
+     kws_figsets = {}
+     for k_arg, v_arg in kwargs.items():
+         if "figset" in k_arg:
+             kws_figsets = v_arg
+             kwargs.pop(k_arg, None)
+             break
+     species_org=species
+     # organism (str) – Select one from { 'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm' }
+     organisms=['Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm']
+     species=ips.strcmp(species,organisms)[0]
+     if species_org.lower()!= species.lower():
+         print(f"species was corrected to {species}, because only {organisms} are supported")
+     if os.path.isfile(gene_sets):
+         gene_sets_name=os.path.basename(gene_sets)
+         gene_sets = ips.fload(gene_sets)
+     else:
+         lib_support_names = gp.get_library_name()
+         # correct the input gene_set name
+         gene_sets_name=ips.strcmp(gene_sets,lib_support_names)[0]
+
+         # download it
+         if download:
+             gene_sets = gp.get_library(name=gene_sets_name, organism=species)
+         else:
+             gene_sets = gene_sets_name  # avoid re-downloading
+     print(f"\ngene_sets get ready: {gene_sets_name}")
+
+     # gene symbols are uppercase
+     gene_symbol_list=[str(i).upper() for i in gene_symbol_list]
+
+     # check the shared genes
+     if check_shared and isinstance(gene_sets, dict):
+         shared_genes=ips.shared(ips.flatten(gene_symbol_list,verbose=False),
+                                 ips.flatten(gene_sets,verbose=False),
+                                 verbose=False)
+
+     #! enrichr
+     try:
+         enr = gp.enrichr(
+             gene_list=gene_symbol_list,
+             gene_sets=gene_sets,
+             organism=species,
+             outdir=None,  # don't write to disk
+             **kwargs
+         )
+     except ValueError as e:
+         print(f"\n{'!'*10} Error {'!'*10}\n{' '*4}{e}\n{'!'*10} Error {'!'*10}")
+         return None
+
+     results_df = enr.results
+     print(f"got enrichr results; shape: {results_df.shape}\n")
+     results_df["-log10(Adjusted P-value)"] = -np.log10(results_df["Adjusted P-value"])
+     results_df.sort_values("-log10(Adjusted P-value)", inplace=True, ascending=False)
+
+     if plot_:
+         if palette is None:
+             palette=plot.get_color(n_top, cmap=cmap)[::-1]
+         #! barplot
+         if n_top<5:
+             height_=4
+         elif 5<=n_top<10:
+             height_=5
+         elif 5<=n_top<10:
+             height_=6
+         elif 10<=n_top<15:
+             height_=7
+         elif 15<=n_top<20:
+             height_=8
+         elif 20<=n_top<30:
+             height_=9
+         else:
+             height_=int(n_top/3)
+         plt.figure(figsize=[10, height_])
+
+         ax1=plot.plotxy(
+             data=results_df.head(n_top),
+             kind="barplot",
+             x="-log10(Adjusted P-value)",
+             y="Term",
+             hue="Term",
+             palette=palette,
+             legend=None,
+         )
+         plot.figsets(ax=ax1, **kws_figsets)
+         if dir_save:
+             ips.figsave(f"{dir_save} enr_barplot.pdf")
+         plt.show()
+
+         #! dotplot
+         cutoff_curr = cutoff
+         step=0.05
+         cutoff_stop = 0.5
+         while cutoff_curr <= cutoff_stop:
+             try:
+                 if cutoff_curr!=cutoff:
+                     plt.clf()
+                 ax2 = gp.dotplot(enr.res2d,
+                                  column="Adjusted P-value",
+                                  show_ring=show_ring,
+                                  xticklabels_rot=xticklabels_rot,
+                                  title=title,
+                                  cmap=cmap,
+                                  cutoff=cutoff_curr,
+                                  top_term=n_top,
+                                  size=size,
+                                  figsize=[10, height_])
+                 if len(ax2.collections)>=n_top:
+                     print(f"cutoff={cutoff_curr} done! ")
+                     break
+                 if cutoff_curr==cutoff_stop:
+                     break
+                 cutoff_curr+=step
+             except Exception as e:
+                 cutoff_curr+=step
+                 print(f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} ")
+         ax = plt.gca()
+         plot.figsets(ax=ax,**kws_figsets)
+
+         if dir_save:
+             ips.figsave(f"{dir_save}enr_dotplot.pdf")
+
+     return results_df
+
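A minimal usage sketch for get_enrichr above (gene symbols illustrative; "KEGG_2021_Human" is one of the Enrichr library names listed by gp.get_library_name(), and the call needs network access):

    deg_genes = ["TP53", "EGFR", "MYC", "VEGFA", "IL6"]
    res = get_enrichr(deg_genes, gene_sets="KEGG_2021_Human",
                      species="Human", plot_=False)
    if res is not None:
        print(res[["Term", "Adjusted P-value", "Overlap"]].head())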
+ def plot_enrichr(results_df,
+                  kind="bar",  # 'barplot', 'dotplot'
+                  cutoff=0.05,
+                  show_ring=False,
+                  xticklabels_rot=0,
+                  title=None,  # 'KEGG'
+                  cmap="coolwarm",
+                  n_top=10,
+                  size=5,
+                  ax=None,
+                  **kwargs):
+     kws_figsets = {}
+     for k_arg, v_arg in kwargs.items():
+         if "figset" in k_arg:
+             kws_figsets = v_arg
+             kwargs.pop(k_arg, None)
+             break
+     if isinstance(cmap,str):
+         palette = plot.get_color(n_top, cmap=cmap)[::-1]
+     elif isinstance(cmap,list):
+         palette=cmap
+
+     if n_top<5:
+         height_=4
+     elif 5<=n_top<10:
+         height_=5
+     elif 5<=n_top<10:
+         height_=6
+     elif 10<=n_top<15:
+         height_=7
+     elif 15<=n_top<20:
+         height_=8
+     elif 20<=n_top<30:
+         height_=9
+     else:
+         height_=int(n_top/3)
+     if ax is None:
+         _,ax=plt.subplots(1,1,figsize=[10, height_])
+     #! barplot
+     if 'bar' in kind.lower():
+         ax=plot.plotxy(
+             data=results_df.head(n_top),
+             kind="barplot",
+             x="-log10(Adjusted P-value)",
+             y="Term",
+             hue="Term",
+             palette=palette,
+             legend=None,
+         )
+         plot.figsets(ax=ax, **kws_figsets)
+         return ax,results_df
+     #! dotplot
+     elif 'dot' in kind.lower():
+         cutoff_curr = cutoff
+         step=0.05
+         cutoff_stop = 0.5
+         while cutoff_curr <= cutoff_stop:
+             try:
+                 if cutoff_curr!=cutoff:
+                     plt.clf()
+                 ax = gp.dotplot(results_df,
+                                 column="Adjusted P-value",
+                                 show_ring=show_ring,
+                                 xticklabels_rot=xticklabels_rot,
+                                 title=title,
+                                 cmap=cmap,
+                                 cutoff=cutoff_curr,
+                                 top_term=n_top,
+                                 size=size,
+                                 figsize=[10, height_])
+                 if len(ax.collections)>=n_top:
+                     print(f"cutoff={cutoff_curr} done! ")
+                     break
+                 if cutoff_curr==cutoff_stop:
+                     break
+                 cutoff_curr+=step
+             except Exception as e:
+                 cutoff_curr+=step
+                 print(f"Warning: trying cutoff={cutoff_curr}, cutoff={cutoff_curr-step} failed: {e} ")
+         plot.figsets(ax=ax, **kws_figsets)
+         return ax,results_df
+     #! barplot with counts
+     elif 'count' in kind.lower():
+         # extract the count from the Overlap column
+         results_df["count"] = results_df["Overlap"].apply(
+             lambda x: int(x.split("/")[0]) if isinstance(x, str) else x)
+         df_=results_df.sort_values(by="count", ascending=False)
+         ax=plot.plotxy(
+             data=df_.head(n_top),
+             kind="barplot",
+             x="count",
+             y="Term",
+             hue="Term",
+             palette=palette,
+             legend=None,
+             ax=ax
+         )
+
+         plot.figsets(ax=ax, **kws_figsets)
+         return ax,df_
+
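Usage sketch, reusing the results DataFrame from the get_enrichr example above:

    ax, df_top = plot_enrichr(res, kind="count", n_top=10)
    # df_top carries the added "count" column parsed from "Overlap" (e.g. "12/345" -> 12)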
+ def plot_bp_cc_mf(
+     deg_gene_list,
+     gene_sets=[
+         "GO_Biological_Process_2023",
+         "GO_Cellular_Component_2023",
+         "GO_Molecular_Function_2023",
+     ],
+     species="human",
+     n_top=10,
+     plot_=True,
+     ax=None,
+     palette=plot.get_color(3),
+     **kwargs,
+ ):
+
+     def res_enrichr_2_count(res_enrichr, n_top=10):
+         """extract the counts from the enrichr results and sort them"""
+         res_enrichr["Count"] = res_enrichr["Overlap"].apply(
+             lambda x: int(x.split("/")[0]) if isinstance(x, str) else x
+         )
+         res_enrichr.sort_values(by="Count", ascending=False, inplace=True)
+
+         return res_enrichr.head(n_top)  # [["Term", "Count"]]
+
+     res_enrichr_BP = get_enrichr(
+         deg_gene_list, gene_sets[0], species=species, plot_=False
+     )
+     res_enrichr_CC = get_enrichr(
+         deg_gene_list, gene_sets[1], species=species, plot_=False
+     )
+     res_enrichr_MF = get_enrichr(
+         deg_gene_list, gene_sets[2], species=species, plot_=False
+     )
+
+     df_BP = res_enrichr_2_count(res_enrichr_BP, n_top=n_top)
+     df_BP["Ontology"] = ["BP"] * n_top
+
+     df_CC = res_enrichr_2_count(res_enrichr_CC, n_top=n_top)
+     df_CC["Ontology"] = ["CC"] * n_top
+
+     df_MF = res_enrichr_2_count(res_enrichr_MF, n_top=n_top)
+     df_MF["Ontology"] = ["MF"] * n_top
+
+     # merge
+     df2plot = pd.concat([df_BP, df_CC, df_MF])
+     n_top=n_top*3
+     if n_top < 5:
+         height_ = 4
+     elif 5 <= n_top < 10:
+         height_ = 5
+     elif 10 <= n_top < 15:
+         height_ = 6
+     elif 15 <= n_top < 20:
+         height_ = 7
+     elif 20 <= n_top < 30:
+         height_ = 8
+     elif 30 <= n_top < 40:
+         height_ = int(n_top / 4)
+     else:
+         height_ = int(n_top / 5)
+     if ax is None:
+         _,ax=plt.subplots(1,1,figsize=[10, height_])
+     # plot
+     if df2plot["Term"].tolist()[0].endswith(")"):
+         df2plot["Term"] = df2plot["Term"].apply(lambda x: x.split("(")[0][:-1])
+     if plot_:
+         ax = plot.plotxy(
+             data=df2plot,
+             x="Count",
+             y="Term",
+             hue="Ontology",
+             kind="bar",
+             palette=palette,
+             ax=ax,
+             **kwargs
+         )
+     return ax, df2plot
+
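Usage sketch (the three GO libraries are the defaults above; gene symbols illustrative):

    ax, df_go = plot_bp_cc_mf(["TP53", "EGFR", "MYC", "VEGFA", "IL6"],
                              species="human", n_top=5)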
+ def get_library_name():
+     return gp.get_library_name()
+
+ def get_gsva(
+     data_gene_samples: pd.DataFrame,  # index(gene), columns(samples)
+     gene_sets: str,
+     species:str="Human",
+     dir_save:str="./",
+     plot_:bool=False,
+     n_top:int=30,
+     check_shared:bool=True,
+     cmap="coolwarm",
+     min_size=1,
+     max_size=1000,
+     kcdf="Gaussian",  # 'Gaussian' for continuous data
+     method='gsva',
+     seed=1,
+     **kwargs,
+ ):
+     kws_figsets = {}
+     for k_arg, v_arg in kwargs.items():
+         if "figset" in k_arg:
+             kws_figsets = v_arg
+             kwargs.pop(k_arg, None)
+             break
+     species_org = species
+     # organism (str) – Select one from { 'Human', 'Mouse', 'Yeast', 'Fly', 'Fish', 'Worm' }
+     organisms = ["Human", "Mouse", "Yeast", "Fly", "Fish", "Worm"]
+     species = ips.strcmp(species, organisms)[0]
+     if species_org.lower() != species.lower():
+         print(f"species was corrected to {species}, because only {organisms} are supported")
+     if os.path.isfile(gene_sets):
+         gene_sets_name = os.path.basename(gene_sets)
+         gene_sets = ips.fload(gene_sets)
+     else:
+         lib_support_names = gp.get_library_name()
+         # correct the input gene_set name
+         gene_sets_name = ips.strcmp(gene_sets, lib_support_names)[0]
+         # download it
+         gene_sets = gp.get_library(name=gene_sets_name, organism=species)
+     print(f"gene_sets get ready: {gene_sets_name}")
+
+     # gene symbols are uppercase
+     gene_symbol_list = [str(i).upper() for i in data_gene_samples.index]
+     data_gene_samples.index=gene_symbol_list
+     # display(data_gene_samples.head(3))
+     # check the shared genes
+     if check_shared:
+         ips.shared(
+             ips.flatten(gene_symbol_list, verbose=False),
+             ips.flatten(gene_sets, verbose=False),
+             verbose=False
+         )
+     gsva_results = gp.gsva(
+         data=data_gene_samples,  # matrix should have genes as rows and samples as columns
+         gene_sets=gene_sets,
+         outdir=None,
+         kcdf=kcdf,  # 'Gaussian' for continuous data
+         min_size=min_size,
+         method=method,
+         max_size=max_size,
+         verbose=True,
+         seed=seed,
+         # no_plot=False,
+     )
+     gsva_res = gsva_results.res2d.copy()
+     gsva_res["ES_abs"] = gsva_res["ES"].apply(np.abs)
+     gsva_res = gsva_res.sort_values(by="ES_abs", ascending=False)
+     gsva_res = (
+         gsva_res.drop_duplicates(subset="Term").drop(columns="ES_abs")
+         # .iloc[:80, :]
+         .reset_index(drop=True)
+     )
+     gsva_res = gsva_res.sort_values(by="ES", ascending=False)
+     if plot_:
+         if gsva_res.shape[0]>=2*n_top:
+             gsva_res_plot=pd.concat([gsva_res.head(n_top),gsva_res.tail(n_top)])
+         else:
+             gsva_res_plot = gsva_res
+         if isinstance(cmap,str):
+             palette = plot.get_color(n_top*2, cmap=cmap)[::-1]
+         elif isinstance(cmap,list):
+             if len(cmap)==2:
+                 palette = [cmap[0]]*n_top+[cmap[1]]*n_top
+             else:
+                 palette=cmap
+         # ! barplot
+         if n_top < 5:
+             height_ = 3
+         elif 5 <= n_top < 10:
+             height_ = 4
+         elif 10 <= n_top < 15:
+             height_ = 5
+         elif 15 <= n_top < 20:
+             height_ = 6
+         elif 20 <= n_top < 30:
+             height_ = 7
+         elif 30 <= n_top < 40:
+             height_ = int(n_top / 3.5)
+         else:
+             height_ = int(n_top / 3)
+         plt.figure(figsize=[10, height_])
+         ax2 = plot.plotxy(
+             data=gsva_res_plot,
+             x="ES",
+             y="Term",
+             hue="Term",
+             palette=palette,
+             kind=["bar"],
+             figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
+         )
+         # reposition the labels
+         for i, bar in enumerate(ax2.patches):
+             term = gsva_res_plot.iloc[i]["Term"]
+             es_value = gsva_res_plot.iloc[i]["ES"]
+
+             # Positive ES values: align y-labels to the left
+             if es_value > 0:
+                 ax2.annotate(
+                     term,
+                     xy=(0, bar.get_y() + bar.get_height() / 2),
+                     xytext=(-5, 0),  # Move to the left
+                     textcoords="offset points",
+                     ha="right",
+                     va="center",  # Align labels to the right
+                     fontsize=10,
+                     color="black",
+                 )
+             # Negative ES values: align y-labels to the right
+             else:
+                 ax2.annotate(
+                     term,
+                     xy=(0, bar.get_y() + bar.get_height() / 2),
+                     xytext=(5, 0),  # Move to the right
+                     textcoords="offset points",
+                     ha="left",
+                     va="center",  # Align labels to the left
+                     fontsize=10,
+                     color="black",
+                 )
+         plot.figsets(ax=ax2, **kws_figsets)
+         if dir_save:
+             ips.figsave(dir_save + f"GSVA_{gene_sets_name}.pdf")
+         plt.show()
+     return gsva_res.reset_index(drop=True)
+
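A minimal sketch of calling get_gsva above (tiny illustrative matrix with random values; real analyses use genome-wide expression, and the index must hold real gene symbols so that they overlap the chosen library):

    import numpy as np
    import pandas as pd

    genes = ["TP53", "EGFR", "MYC", "VEGFA", "IL6", "TNF", "AKT1", "STAT3"]
    expr = pd.DataFrame(np.random.rand(len(genes), 6),  # genes as rows, samples as columns
                        index=genes,
                        columns=[f"sample{i}" for i in range(6)])
    gsva_res = get_gsva(expr, gene_sets="KEGG_2021_Human", species="Human", plot_=False)
    print(gsva_res[["Term", "ES"]].head())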
+ def plot_gsva(gsva_res,  # output from bio.get_gsva()
+               n_top=10,
+               ax=None,
+               x="ES",
+               y="Term",
+               hue="Term",
+               cmap="coolwarm",
+               **kwargs
+               ):
+     kws_figsets = {}
+     for k_arg, v_arg in kwargs.items():
+         if "figset" in k_arg:
+             kws_figsets = v_arg
+             kwargs.pop(k_arg, None)
+             break
+     # ! barplot
+     if n_top < 5:
+         height_ = 4
+     elif 5 <= n_top < 10:
+         height_ = 5
+     elif 10 <= n_top < 15:
+         height_ = 6
+     elif 15 <= n_top < 20:
+         height_ = 7
+     elif 20 <= n_top < 30:
+         height_ = 8
+     elif 30 <= n_top < 40:
+         height_ = int(n_top / 3.5)
+     else:
+         height_ = int(n_top / 3)
+     if ax is None:
+         _,ax=plt.subplots(1,1,figsize=[10, height_])
+     gsva_res = gsva_res.sort_values(by=x, ascending=False)
+
+     if gsva_res.shape[0]>=2*n_top:
+         gsva_res_plot=pd.concat([gsva_res.head(n_top),gsva_res.tail(n_top)])
+     else:
+         gsva_res_plot = gsva_res
+     if isinstance(cmap,str):
+         palette = plot.get_color(n_top*2, cmap=cmap)[::-1]
+     elif isinstance(cmap,list):
+         if len(cmap)==2:
+             palette = [cmap[0]]*n_top+[cmap[1]]*n_top
+         else:
+             palette=cmap
+
+     ax = plot.plotxy(
+         ax=ax,
+         data=gsva_res_plot,
+         x=x,
+         y=y,
+         hue=hue,
+         palette=palette,
+         kind=["bar"],
+         figsets=dict(yticklabel=[], ticksloc="b", boxloc="b", ylabel=None),
+     )
+     # reposition the labels
+     for i, bar in enumerate(ax.patches):
+         term = gsva_res_plot.iloc[i]["Term"]
+         es_value = gsva_res_plot.iloc[i]["ES"]
+
+         # Positive ES values: align y-labels to the left
+         if es_value > 0:
+             ax.annotate(
+                 term,
+                 xy=(0, bar.get_y() + bar.get_height() / 2),
+                 xytext=(-5, 0),  # Move to the left
+                 textcoords="offset points",
+                 ha="right",
+                 va="center",  # Align labels to the right
+                 fontsize=10,
+                 color="black",
+             )
+         # Negative ES values: align y-labels to the right
+         else:
+             ax.annotate(
+                 term,
+                 xy=(0, bar.get_y() + bar.get_height() / 2),
+                 xytext=(5, 0),  # Move to the right
+                 textcoords="offset points",
+                 ha="left",
+                 va="center",  # Align labels to the left
+                 fontsize=10,
+                 color="black",
+             )
+     plot.figsets(ax=ax, **kws_figsets)
+     return ax
+
+
+ #! https://string-db.org/help/api/
+
+ import pandas as pd
+ import requests
+ import networkx as nx
+ import matplotlib.pyplot as plt
+ from io import StringIO
+ from py2ls import ips
+
+
+ def get_ppi(
+     target_genes:list,
+     species:int=9606,  # "human"
+     ci:float=0.1,  # int 1~1000
+     max_nodes:int=50,
+     base_url:str="https://string-db.org",
+     gene_mapping_api:str="/api/json/get_string_ids?",
+     interaction_api:str="/api/tsv/network?",
+ ):
+     """
+     Generate a Protein-Protein Interaction (PPI) network using STRINGdb data.
+
+     return:
+         the STRING protein-protein interaction (PPI) data, which contains information about
+         predicted and experimentally validated associations between proteins.
+
+         stringId_A and stringId_B: Unique identifiers for the interacting proteins based on the
+             STRING database.
+         preferredName_A and preferredName_B: Standard gene names for the interacting proteins.
+         ncbiTaxonId: The taxon ID (9606 for humans).
+         score: A combined score reflecting the overall confidence of the interaction, which aggregates different sources of evidence.
+
+         nscore, fscore, pscore, ascore, escore, dscore, tscore: These are sub-scores representing the confidence in the interaction based on various evidence types:
+             - nscore: Neighborhood score, based on genes located near each other in the genome.
+             - fscore: Fusion score, based on gene fusions in other genomes.
+             - pscore: Phylogenetic profile score, based on co-occurrence across different species.
+             - ascore: Coexpression score, reflecting the likelihood of coexpression.
+             - escore: Experimental score, based on experimental evidence.
+             - dscore: Database score, from curated databases.
+             - tscore: Text-mining score, from literature co-occurrence.
+
+         Higher score values (closer to 1) indicate stronger evidence for an interaction.
+         - Combined score: Useful for ranking interactions based on overall confidence. A score >0.7 is typically considered high-confidence.
+         - Sub-scores: Interpret the types of evidence supporting the interaction. For instance:
+             - A high ascore indicates strong evidence of coexpression.
+             - A high escore suggests experimental validation.
+
+     """
+     print("check api: https://string-db.org/help/api/")
+
+     # convert the species name to a taxon_id
+     if isinstance(species,str):
+         print(species)
+         species=list(get_taxon_id(species).values())[0]
+         print(species)
+
+
+     string_api_url = base_url + gene_mapping_api
+     interaction_api_url = base_url + interaction_api
+     # Map gene symbols to STRING IDs
+     mapped_genes = {}
+     for gene in target_genes:
+         params = {"identifiers": gene, "species": species, "limit": 1}
+         response = requests.get(string_api_url, params=params)
+         if response.status_code == 200:
+             try:
+                 json_data = response.json()
+                 if json_data:
+                     mapped_genes[gene] = json_data[0]["stringId"]
+             except ValueError:
+                 print(
+                     f"Failed to decode JSON for gene {gene}. Response: {response.text}"
+                 )
+         else:
+             print(
+                 f"Failed to fetch data for gene {gene}. Status code: {response.status_code}"
+             )
+     if not mapped_genes:
+         print("No mapped genes found in STRING database.")
+         return None
+
+     # Retrieve PPI data from STRING API
+     string_ids = "%0d".join(mapped_genes.values())
+     params = {
+         "identifiers": string_ids,
+         "species": species,
+         "required_score": int(ci * 1000),
+         "limit": max_nodes,
+     }
+     response = requests.get(interaction_api_url, params=params)
+
+     if response.status_code == 200:
+         try:
+             interactions = pd.read_csv(StringIO(response.text), sep="\t")
+         except Exception as e:
+             print("Error reading the interaction data:", e)
+             print("Response content:", response.text)
+             return None
+     else:
+         print(
+             f"Failed to retrieve interaction data. Status code: {response.status_code}"
+         )
+         print("Response content:", response.text)
+         return None
+     display(interactions.head())
+     # Filter interactions by ci score
+     if "score" in interactions.columns:
+         interactions = interactions[interactions["score"] >= ci]
+         if interactions.empty:
+             print("No interactions found with the specified confidence.")
+             return None
+     else:
+         print("The 'score' column is missing from the retrieved data. Unable to filter by confidence interval.")
+     if "fdr" in interactions.columns:
+         interactions=interactions.sort_values(by="fdr",ascending=False)
+     return interactions
+ # * usage
+ # interactions = get_ppi(target_genes, ci=0.0001)
+
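An expanded usage sketch (gene symbols illustrative; per the docstring above, a combined score >0.7 is typically treated as high-confidence):

    interactions = get_ppi(["TP53", "MDM2", "CDKN1A"], species=9606, ci=0.4)
    if interactions is not None:
        high_conf = interactions[interactions["score"] > 0.7]
        print(high_conf[["preferredName_A", "preferredName_B", "score"]].head())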
+ def plot_ppi(
+     interactions,
+     player1="preferredName_A",
+     player2="preferredName_B",
+     weight="score",
+     n_layers=None,  # Number of concentric layers
+     n_rank=[5, 10],  # Nodes in each rank for the concentric layout
+     dist_node=10,  # Distance between each rank of circles
+     layout="degree",
+     size='auto',  # 700,
+     facecolor="skyblue",
+     cmap='coolwarm',
+     edgecolor="k",
+     edgelinewidth=1.5,
+     alpha=.5,
+     marker="o",
+     node_hideticks=True,
+     linecolor="gray",
+     linewidth=1.5,
+     linestyle="-",
+     line_arrowstyle='-',
+     fontsize=10,
+     fontcolor="k",
+     ha:str="center",
+     va:str="center",
+     figsize=(12, 10),
+     k_value=0.3,
+     bgcolor="w",
+     dir_save="./ppi_network.html",
+     physics=True,
+     notebook=False,
+     scale=1,
+     ax=None,
+     **kwargs
+ ):
+     """
+     Plot a Protein-Protein Interaction (PPI) network with adjustable appearance.
+     """
+     from pyvis.network import Network
+     import networkx as nx
+     from IPython.display import IFrame
+     from matplotlib.colors import Normalize
+     from matplotlib import cm
+     # Check for required columns in the DataFrame
+     for col in [player1, player2, weight]:
+         if col not in interactions.columns:
+             raise ValueError(f"Column '{col}' is missing from the interactions DataFrame.")
+
+     # Initialize Pyvis network
+     net = Network(height="750px", width="100%", bgcolor=bgcolor, font_color=fontcolor)
+     net.force_atlas_2based(
+         gravity=-50, central_gravity=0.01, spring_length=100, spring_strength=0.1
+     )
+     net.toggle_physics(physics)
+
+     kws_figsets = {}
+     for k_arg, v_arg in kwargs.items():
+         if "figset" in k_arg:
+             kws_figsets = v_arg
+             kwargs.pop(k_arg, None)
+             break
+
+     # Create a NetworkX graph from the interaction data
+     G = nx.Graph()
+     for _, row in interactions.iterrows():
+         G.add_edge(row[player1], row[player2], weight=row[weight])
+
+     # Calculate node degrees
+     degrees = dict(G.degree())
+     norm = Normalize(vmin=min(degrees.values()), vmax=max(degrees.values()))
+     colormap = cm.get_cmap(cmap)  # Get the chosen colormap
+
+     # Set properties based on degrees
+     if not isinstance(size, (int,float,list)):
+         size = [deg * 50 for deg in degrees.values()]  # Scale sizes
+     if not ips.isa(facecolor, 'color'):
+         facecolor = [colormap(norm(deg)) for deg in degrees.values()]  # Use colormap
+     if size is None:
+         size = [700] * G.number_of_nodes()  # Default size for all nodes
+     elif isinstance(size, (int, float)):
+         size = [size] * G.number_of_nodes()  # If a scalar, apply to all nodes
+     # else:
+     #     size = size.tolist()  # Ensure size is a list
+     if len(size)>G.number_of_nodes():
+         size=size[:G.number_of_nodes()]
+
+     for node in G.nodes():
+         net.add_node(
+             node,
+             label=node,
+             size=size[list(G.nodes()).index(node)] if isinstance(size,list) else size[0],
+             color=facecolor[list(G.nodes()).index(node)] if isinstance(facecolor,list) else facecolor,
+             font={"size": fontsize, "color": fontcolor},
+         )
+
+     for edge in G.edges(data=True):
+         net.add_edge(
+             edge[0],
+             edge[1],
+             weight=edge[2]["weight"],
+             color=edgecolor,
+             width=edgelinewidth * edge[2]["weight"],
+         )
+     layouts = [
+         "spring",
+         "circular",
+         "kamada_kawai",
+         "random",
+         "shell",
+         "planar",
+         "spiral",
+         "degree"
+     ]
+     layout = ips.strcmp(layout, layouts)[0]
+     print(layout)
+     # Choose layout
+     if layout == "spring":
+         pos = nx.spring_layout(G, k=k_value)
+     elif layout == "circular":
+         pos = nx.circular_layout(G)
+     elif layout == "kamada_kawai":
+         pos = nx.kamada_kawai_layout(G)
+     elif layout == "spectral":
+         pos = nx.spectral_layout(G)
+     elif layout == "random":
+         pos = nx.random_layout(G)
+     elif layout == "shell":
+         pos = nx.shell_layout(G)
+     elif layout == "planar":
+         if nx.check_planarity(G)[0]:
+             pos = nx.planar_layout(G)
+         else:
+             print("Graph is not planar; switching to spring layout.")
+             pos = nx.spring_layout(G, k=k_value)
+     elif layout == "spiral":
+         pos = nx.spiral_layout(G)
+     elif layout=='degree':
+         # Calculate node degrees and sort nodes by degree
+         degrees = dict(G.degree())
+         sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
+
+         # Create positions for concentric circles based on n_layers and n_rank
+         pos = {}
+         n_layers=len(n_rank)+1 if n_layers is None else n_layers
+         for rank_index in range(n_layers):
+             if rank_index < len(n_rank):
+                 nodes_per_rank = n_rank[rank_index]
+                 rank_nodes = sorted_nodes[sum(n_rank[:rank_index]): sum(n_rank[:rank_index + 1])]
+             else:
+                 # shuffle the remaining nodes randomly
+                 remaining_nodes = sorted_nodes[sum(n_rank[:rank_index]):]
+                 random_indices = np.random.permutation(len(remaining_nodes))
+                 rank_nodes = [remaining_nodes[i] for i in random_indices]
+
+             radius = (rank_index + 1) * dist_node  # Radius for this rank
+
+             # Arrange nodes in a circle for the current rank
+             for i, (node, degree) in enumerate(rank_nodes):
+                 angle = (i / len(rank_nodes)) * 2 * np.pi  # Distribute around circle
+                 pos[node] = (radius * np.cos(angle), radius * np.sin(angle))
+
+     else:
+         print(f"Unknown layout '{layout}', defaulting to 'spring'; alternatives: {layouts}")
+         pos = nx.spring_layout(G, k=k_value)
+
+     for node, (x, y) in pos.items():
+         net.get_node(node)["x"] = x * scale
+         net.get_node(node)["y"] = y * scale
+
+     # If ax is None, create a new figure
+     if ax is None:
+         fig, ax = plt.subplots(1,1,figsize=figsize)
+
+     # Draw nodes, edges, and labels with customization options
+     nx.draw_networkx_nodes(
+         G,
+         pos,
+         ax=ax,
+         node_size=size,
+         node_color=facecolor,
+         linewidths=edgelinewidth,
+         edgecolors=edgecolor,
+         alpha=alpha,
+         hide_ticks=node_hideticks,
+         node_shape=marker
+     )
+     nx.draw_networkx_edges(
+         G,
+         pos,
+         ax=ax,
+         edge_color=linecolor,
+         width=linewidth,
+         style=linestyle,
+         arrowstyle=line_arrowstyle,
+         alpha=0.7
+     )
+     nx.draw_networkx_labels(
+         G, pos, ax=ax, font_size=fontsize, font_color=fontcolor, horizontalalignment=ha, verticalalignment=va
+     )
+     plot.figsets(ax=ax,**kws_figsets)
+     ax.axis("off")
+     if dir_save:
+         if not os.path.basename(dir_save):
+             dir_save="_.html"
+         net.write_html(dir_save)
+         nx.write_graphml(G, dir_save.replace(".html",".graphml"))  # Export to GraphML
+         print(f"could be edited in Cytoscape \n{dir_save.replace('.html', '.graphml')}")
+     return G,ax
+
+
+ # * usage:
+ # G, ax = bio.plot_ppi(
+ #     interactions,
+ #     player1="preferredName_A",
+ #     player2="preferredName_B",
+ #     weight="score",
+ #     # size="auto",
+ #     # size=interactions["score"].tolist(),
+ #     # layout="circ",
+ #     n_rank=[5, 10, 15],
+ #     dist_node=100,
+ #     alpha=0.6,
+ #     linecolor="0.8",
+ #     linewidth=1,
+ #     figsize=(8, 8.5),
+ #     cmap="jet",
+ #     edgelinewidth=0.5,
+ #     # facecolor="#FF5F57",
+ #     fontsize=10,
+ #     # fontcolor="b",
+ #     # edgecolor="r",
+ #     # scale=100,
+ #     # physics=False,
+ #     figsets=dict(title="ppi networks"),
+ # )
+ # figsave("./ppi_network.pdf")
+
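For offline testing of plot_ppi without hitting the STRING API, a synthetic stand-in for the interactions table can be used (values made up; dir_save=None skips the HTML/GraphML export):

    import pandas as pd

    interactions = pd.DataFrame({
        "preferredName_A": ["TP53", "TP53", "MDM2"],
        "preferredName_B": ["MDM2", "CDKN1A", "CDKN1A"],
        "score": [0.99, 0.95, 0.80],
    })
    G, ax = plot_ppi(interactions, layout="circular", dir_save=None)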
+ def top_ppi(interactions, n_top=10):
+     """
+     Analyzes protein-protein interactions (PPIs) to identify key proteins based on
+     degree and betweenness centrality.
+
+     Parameters:
+         interactions (pd.DataFrame): DataFrame containing PPI data with columns
+             ['preferredName_A', 'preferredName_B', 'score'].
+
+     Returns:
+         tuple: the top key proteins by degree centrality and by betweenness centrality.
+     """
+
+     # Create a NetworkX graph from the interaction data
+     G = nx.Graph()
+     for _, row in interactions.iterrows():
+         G.add_edge(row["preferredName_A"], row["preferredName_B"], weight=row["score"])
+
+     # Calculate Degree Centrality
+     degree_centrality = G.degree()
+     key_proteins_degree = sorted(degree_centrality, key=lambda x: x[1], reverse=True)
+
+     # Calculate Betweenness Centrality
+     betweenness_centrality = nx.betweenness_centrality(G)
+     key_proteins_betweenness = sorted(
+         betweenness_centrality.items(), key=lambda x: x[1], reverse=True
+     )
+     print(
+         {
+             "Top 10 Key Proteins by Degree Centrality": key_proteins_degree[:10],
+             "Top 10 Key Proteins by Betweenness Centrality": key_proteins_betweenness[
+                 :10
+             ],
+         }
+     )
+     # Return the top n_top key proteins
+     if n_top == "all":
+         return key_proteins_degree, key_proteins_betweenness
+     else:
+         return key_proteins_degree[:n_top], key_proteins_betweenness[:n_top]
+
+
+ # * usage: top_ppi(interactions)
+ # top_ppi(interactions, n_top="all")
+ # top_ppi(interactions, n_top=10)
+
+
+
+ species_dict = {
+     "Human": "Homo sapiens",
+     "House mouse": "Mus musculus",
+     "Zebrafish": "Danio rerio",
+     "Norway rat": "Rattus norvegicus",
+     "Fruit fly": "Drosophila melanogaster",
+     "Baker's yeast": "Saccharomyces cerevisiae",
+     "Nematode": "Caenorhabditis elegans",
+     "Chicken": "Gallus gallus",
+     "Cattle": "Bos taurus",
+     "Rice": "Oryza sativa",
+     "Thale cress": "Arabidopsis thaliana",
+     "Guinea pig": "Cavia porcellus",
+     "Domestic dog": "Canis lupus familiaris",
+     "Domestic cat": "Felis catus",
+     "Horse": "Equus caballus",
+     "Domestic pig": "Sus scrofa",
+     "African clawed frog": "Xenopus laevis",
+     "Great white shark": "Carcharodon carcharias",
+     "Common chimpanzee": "Pan troglodytes",
+     "Rhesus macaque": "Macaca mulatta",
+     "Water buffalo": "Bubalus bubalis",
+     "Lettuce": "Lactuca sativa",
+     "Tomato": "Solanum lycopersicum",
+     "Maize": "Zea mays",
+     "Cucumber": "Cucumis sativus",
+     "Common grape vine": "Vitis vinifera",
+     "Scots pine": "Pinus sylvestris",
+ }
+
+
+ def get_taxon_id(species_list):
+     """
+     Convert species names to their corresponding taxon ID codes.
+
+     Parameters:
+     - species_list: List of species names (strings).
+
+     Returns:
+     - dict: A dictionary with species names as keys and their taxon IDs as values.
+     """
+     from Bio import Entrez
+
+     if not isinstance(species_list, list):
+         species_list = [species_list]
+     species_list = [
+         species_dict[ips.strcmp(i, ips.flatten(list(species_dict.keys())))[0]]
+         for i in species_list
+     ]
+     taxon_dict = {}
+
+     for species in species_list:
+         try:
+             search_handle = Entrez.esearch(db="taxonomy", term=species)
+             search_results = Entrez.read(search_handle)
+             search_handle.close()
+
+             # Get the taxon ID
+             if search_results["IdList"]:
+                 taxon_id = search_results["IdList"][0]
+                 taxon_dict[species] = int(taxon_id)
+             else:
+                 taxon_dict[species] = None  # Not found
+         except Exception as e:
+             print(f"Error occurred for species '{species}': {e}")
+             taxon_dict[species] = None  # Error in processing
+     return taxon_dict
+
+
+ # # * usage: get_taxon_id("human")
+ # species_names = ["human", "mouse", "rat"]
+ # taxon_ids = get_taxon_id(species_names)
+ # print(taxon_ids)
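Note: Biopython's Entrez asks callers to identify themselves before querying NCBI; a minimal sketch (the address is a placeholder):

    from Bio import Entrez

    Entrez.email = "you@example.com"  # placeholder; set your own address
    print(get_taxon_id(["human", "mouse"]))
    # expected mapping: Homo sapiens -> 9606, Mus musculus -> 10090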