gsmap3d-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/pipeline/quick_mode.py
ADDED
@@ -0,0 +1,134 @@
+
+import logging
+import time
+
+from gsMap.cauchy_combination_test import run_Cauchy_combination
+from gsMap.config import QuickModeConfig
+from gsMap.config.cauchy_config import check_cauchy_done
+from gsMap.config.find_latent_config import check_find_latent_done
+from gsMap.config.latent2gene_config import check_latent2gene_done
+from gsMap.config.quick_mode_config import check_report_done
+from gsMap.config.spatial_ldsc_config import check_spatial_ldsc_done
+from gsMap.find_latent import run_find_latent_representation
+from gsMap.latent2gene import run_latent_to_gene
+from gsMap.report import run_report
+from gsMap.spatial_ldsc.spatial_ldsc_jax import run_spatial_ldsc_jax
+
+logger = logging.getLogger("gsMap.pipeline")
+
+
+def format_duration(seconds):
+    hours = int(seconds // 3600)
+    minutes = int((seconds % 3600) // 60)
+    if hours > 0:
+        return f"{hours}h {minutes}m"
+    else:
+        return f"{minutes}m {int(seconds % 60)}s"
+
+
+def run_quick_mode(config: QuickModeConfig):
+    """
+    Run the Quick Mode pipeline.
+    """
+    logger.info("Starting Quick Mode pipeline")
+    pipeline_start_time = time.time()
+
+    steps = ["find_latent", "latent2gene", "spatial_ldsc", "cauchy", "report"]
+    try:
+        start_idx = steps.index(config.start_step)
+    except ValueError:
+        raise ValueError(f"Invalid start_step: {config.start_step}. Must be one of {steps}")
+
+    stop_idx = len(steps) - 1
+    if config.stop_step:
+        try:
+            stop_idx = steps.index(config.stop_step)
+        except ValueError:
+            raise ValueError(f"Invalid stop_step: {config.stop_step}. Must be one of {steps}")
+
+    if start_idx > stop_idx:
+        raise ValueError(f"start_step ({config.start_step}) must be before or equal to stop_step ({config.stop_step})")
+
+    # Step 1: Find Latent Representations
+    if start_idx <= 0 <= stop_idx:
+        logger.info("=== Step 1: Find Latent Representations ===")
+        start_time = time.time()
+
+        if check_find_latent_done(config):
+            logger.info(f"Find latent representations already done (verified via {config.find_latent_metadata_path}). Skipping...")
+        else:
+            run_find_latent_representation(config.find_latent_config)
+
+        logger.info(f"Step 1 completed in {format_duration(time.time() - start_time)}")
+
+    # Step 2: Latent to Gene
+    if start_idx <= 1 <= stop_idx:
+        logger.info("=== Step 2: Latent to Gene Mapping ===")
+        start_time = time.time()
+
+        if check_latent2gene_done(config):
+            logger.info("Latent to gene mapping already done. Skipping...")
+        else:
+            run_latent_to_gene(config.latent2gene_config)
+
+        logger.info(f"Step 2 completed in {format_duration(time.time() - start_time)}")
+
+    # Get the list of traits to process
+    if not config.sumstats_config_dict:
+        # Warn only if a step that needs GWAS data (steps 3-5) will run
+        if start_idx <= 4 and stop_idx >= 2:
+            logger.warning("No summary statistics provided. Steps requiring GWAS data (Spatial LDSC, Cauchy, Report) may fail or do nothing if relying on them.")
+
+    traits_to_process = config.sumstats_config_dict
+
+    # Step 3: Spatial LDSC
+    if start_idx <= 2 <= stop_idx:
+        logger.info("=== Step 3: Spatial LDSC ===")
+        start_time = time.time()
+
+
+        traits_remaining = {}
+        for trait_name, sumstats_path in traits_to_process.items():
+            if check_spatial_ldsc_done(config, trait_name):
+                logger.info(f"Spatial LDSC result already exists for {trait_name}. Skipping...")
+            else:
+                traits_remaining[trait_name] = sumstats_path
+
+        if not traits_remaining:
+            logger.info("All traits have been processed for Spatial LDSC. Skipping step...")
+        else:
+            sldsc_config = config.spatial_ldsc_config
+            # Update config to run only remaining traits
+            sldsc_config.sumstats_config_dict = traits_remaining
+            run_spatial_ldsc_jax(sldsc_config)
+
+
+        logger.info(f"Step 3 completed in {format_duration(time.time() - start_time)}")
+
+    # Step 4: Cauchy Combination
+    if start_idx <= 3 <= stop_idx:
+        logger.info("=== Step 4: Cauchy Combination ===")
+        start_time = time.time()
+
+        cauchy_check_list = [check_cauchy_done(config, trait_name) for trait_name in traits_to_process]
+        if all(cauchy_check_list):
+            logger.info("Cauchy combination already done for all traits. Skipping...")
+        else:
+            cauchy_config = config.cauchy_config
+            run_Cauchy_combination(cauchy_config)
+
+        logger.info(f"Step 4 completed in {format_duration(time.time() - start_time)}")
+
+    # Step 5: Report
+    if start_idx <= 4 <= stop_idx:
+        logger.info("=== Step 5: Generate Report ===")
+        start_time = time.time()
+
+        if check_report_done(config, verbose=True):
+            logger.info("Report already exists. Skipping...")
+        else:
+
+            run_report(config)
+
+        logger.info(f"Step 5 completed in {format_duration(time.time() - start_time)}")
+
+    logger.info(f"Pipeline completed successfully in {format_duration(time.time() - pipeline_start_time)}")
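Note on the step gating above: it reduces to slicing a fixed list of step names by the indices of `start_step` and `stop_step`. A minimal, self-contained sketch of that windowing logic (the `select_steps` helper below is illustrative only and is not part of the package):

# Standalone illustration of the start/stop windowing used in run_quick_mode.
# Only the step names are taken from the package; select_steps is hypothetical.
STEPS = ["find_latent", "latent2gene", "spatial_ldsc", "cauchy", "report"]

def select_steps(start_step: str, stop_step: str | None = None) -> list[str]:
    """Return the sub-list of pipeline steps that would actually execute."""
    start_idx = STEPS.index(start_step)  # raises ValueError for unknown names
    stop_idx = STEPS.index(stop_step) if stop_step else len(STEPS) - 1
    if start_idx > stop_idx:
        raise ValueError("start_step must be before or equal to stop_step")
    return STEPS[start_idx : stop_idx + 1]

print(select_steps("latent2gene", "cauchy"))  # ['latent2gene', 'spatial_ldsc', 'cauchy']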
gsMap/report/diagnosis.py
ADDED
@@ -0,0 +1,375 @@
+import logging
+import multiprocessing
+import os
+import warnings
+from pathlib import Path
+
+import anndata as ad
+import numpy as np
+import pandas as pd
+from scipy.stats import norm
+
+from gsMap.config import DiagnosisConfig
+from gsMap.utils.manhattan_plot import ManhattanPlot
+from gsMap.utils.regression_read import _read_chr_files
+
+from .visualize import draw_scatter, estimate_plotly_point_size, load_ldsc, load_st_coord
+
+warnings.filterwarnings("ignore", category=FutureWarning)
+logger = logging.getLogger(__name__)
+
+
+def convert_z_to_p(gwas_data):
+    """Convert Z-scores to P-values."""
+    gwas_data["P"] = norm.sf(abs(gwas_data["Z"])) * 2
+    min_p_value = 1e-300
+    gwas_data["P"] = gwas_data["P"].clip(lower=min_p_value)
+    return gwas_data
+
+
+def load_gene_diagnostic_info(config: DiagnosisConfig, adata: ad.AnnData | None = None):
+    """Load or compute gene diagnostic info."""
+    gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
+    if gene_diagnostic_info_save_path.exists():
+        logger.info(
+            f"Loading gene diagnostic information from {gene_diagnostic_info_save_path}..."
+        )
+        return pd.read_csv(gene_diagnostic_info_save_path)
+    else:
+        logger.info(
+            "Gene diagnostic information not found. Calculating gene diagnostic information..."
+        )
+        return compute_gene_diagnostic_info(config, adata=adata)
+
+
+def compute_gene_diagnostic_info(config: DiagnosisConfig, adata: ad.AnnData | None = None):
+    """Calculate gene diagnostic info and save it to adata."""
+    logger.info("Loading ST data and LDSC results...")
+
+    if adata is None:
+        adata = ad.read_h5ad(config.hdf5_with_latent_path)
+
+    mk_score = pd.read_feather(config.mkscore_feather_path)
+    mk_score.set_index("HUMAN_GENE_SYM", inplace=True)
+    mk_score = mk_score.T
+    trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
+
+    # Align marker scores with trait LDSC results
+    mk_score = mk_score.loc[trait_ldsc_result.index]
+
+    # Filter out genes with no variation
+    has_variation = (~mk_score.eq(mk_score.iloc[0], axis=1)).any()
+    mk_score = mk_score.loc[:, has_variation]
+
+    logger.info("Calculating correlation between gene marker scores and trait logp-values...")
+    corr = mk_score.corrwith(trait_ldsc_result["logp"])
+    corr.name = "PCC"
+
+    grouped_mk_score = mk_score.groupby(adata.obs[config.annotation]).median()
+    max_annotations = grouped_mk_score.idxmax()
+
+    high_GSS_Gene_annotation_pair = pd.DataFrame(
+        {
+            "Gene": max_annotations.index,
+            "Annotation": max_annotations.values,
+            "Median_GSS": grouped_mk_score.max().values,
+        }
+    )
+
+    high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair.merge(
+        corr, left_on="Gene", right_index=True
+    )
+
+    # Prepare the final gene diagnostic info dataframe
+    gene_diagnostic_info_cols = ["Gene", "Annotation", "Median_GSS", "PCC"]
+    gene_diagnostic_info = (
+        high_GSS_Gene_annotation_pair[gene_diagnostic_info_cols]
+        .drop_duplicates()
+        .dropna(subset=["Gene"])
+    )
+    gene_diagnostic_info.sort_values("PCC", ascending=False, inplace=True)
+
+    # Save gene diagnostic info to a file
+    gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
+    gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
+    logger.info(f"Gene diagnostic information saved to {gene_diagnostic_info_save_path}.")
+
+    return gene_diagnostic_info.reset_index()
+
+
+def load_gwas_data(sumstats_file):
+    """Load and process GWAS data."""
+    logger.info("Loading and processing GWAS data...")
+    gwas_data = pd.read_csv(sumstats_file, compression="gzip", sep="\t")
+    return convert_z_to_p(gwas_data)
+
+
+def load_snp_gene_pairs(config: DiagnosisConfig):
+    """Load SNP-gene pairs from multiple chromosomes."""
+    ldscore_save_dir = Path(config.ldscore_save_dir)
+    snp_gene_pair_file_prefix = ldscore_save_dir / "SNP_gene_pair/SNP_gene_pair_chr"
+    return pd.concat(
+        [
+            pd.read_feather(file)
+            for file in _read_chr_files(snp_gene_pair_file_prefix.as_posix(), suffix=".feather")
+        ]
+    )
+
+
+def filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER):
+    """Filter the SNPs based on significance levels."""
+    pass_suggestive_line_mask = gwas_data_with_gene_annotation_sort["P"] < 1e-5
+    pass_suggestive_line_number = pass_suggestive_line_mask.sum()
+
+    if pass_suggestive_line_number > SUBSAMPLE_SNP_NUMBER:
+        snps2plot = gwas_data_with_gene_annotation_sort[pass_suggestive_line_mask].SNP
+        logger.info(
+            f"To reduce the number of SNPs to plot, only {snps2plot.shape[0]} SNPs with P < 1e-5 are plotted."
+        )
+    else:
+        snps2plot = gwas_data_with_gene_annotation_sort.head(SUBSAMPLE_SNP_NUMBER).SNP
+        logger.info(
+            f"To reduce the number of SNPs to plot, only {SUBSAMPLE_SNP_NUMBER} SNPs with the smallest P-values are plotted."
+        )
+
+    return snps2plot
+
+
+def generate_manhattan_plot(config: DiagnosisConfig, adata: ad.AnnData | None = None):
+    """Generate Manhattan plot."""
+    # report_save_dir = config.get_report_dir(config.trait_name)
+    gwas_data = load_gwas_data(config.sumstats_file)
+    snp_gene_pair = load_snp_gene_pairs(config)
+    gwas_data_with_gene = snp_gene_pair.merge(gwas_data, on="SNP", how="inner").rename(
+        columns={"gene_name": "GENE"}
+    )
+    gene_diagnostic_info = load_gene_diagnostic_info(config, adata=adata)
+    gwas_data_with_gene_annotation = gwas_data_with_gene.merge(
+        gene_diagnostic_info, left_on="GENE", right_on="Gene", how="left"
+    )
+
+    gwas_data_with_gene_annotation = gwas_data_with_gene_annotation[
+        ~gwas_data_with_gene_annotation["Annotation"].isna()
+    ]
+    gwas_data_with_gene_annotation_sort = gwas_data_with_gene_annotation.sort_values("P")
+
+    snps2plot = filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER=100_000)
+    gwas_data_to_plot = gwas_data_with_gene_annotation[
+        gwas_data_with_gene_annotation["SNP"].isin(snps2plot)
+    ].reset_index(drop=True)
+    gwas_data_to_plot["Annotation_text"] = (
+        "PCC: "
+        + gwas_data_to_plot["PCC"].round(2).astype(str)
+        + "<br>"
+        + "Annotation: "
+        + gwas_data_to_plot["Annotation"].astype(str)
+    )
+
+    # Verify data integrity
+    if gwas_data_with_gene_annotation_sort.empty:
+        raise ValueError("Filtered GWAS data is empty, cannot create Manhattan plot")
+
+    if len(gwas_data_to_plot) == 0:
+        raise ValueError("No SNPs passed filtering criteria for Manhattan plot")
+
+    # Log some diagnostic information
+    logger.info(f"Creating Manhattan plot with {len(gwas_data_to_plot)} SNPs")
+    logger.info(f"Chromosome column values: {gwas_data_to_plot['CHR'].unique()}")
+
+    fig = ManhattanPlot(
+        dataframe=gwas_data_to_plot,
+        title="gsMap Diagnosis Manhattan Plot",
+        point_size=3,
+        highlight_gene_list=config.selected_genes
+        or gene_diagnostic_info.Gene.iloc[: config.top_corr_genes].tolist(),
+        suggestiveline_value=-np.log10(1e-5),
+        annotation="Annotation_text",
+    )
+
+    save_manhattan_plot_path = config.get_manhattan_html_plot_path(config.trait_name)
+    fig.write_html(save_manhattan_plot_path)
+    logger.info(f"Diagnostic Manhattan Plot saved to {save_manhattan_plot_path}.")
+
+
+def generate_GSS_distribution(config: DiagnosisConfig, adata: ad.AnnData):
+    """Generate GSS distribution plots."""
+    # mk_score = pd.read_feather(config.mkscore_feather_path).set_index("HUMAN_GENE_SYM").T
+    # We should avoid loading large files inside workers if possible, or use memmap.
+    # For now, let's load it once here.
+    mk_score = pd.read_feather(config.mkscore_feather_path).set_index("HUMAN_GENE_SYM").T
+
+    plot_genes = (
+        config.selected_genes
+        or load_gene_diagnostic_info(config, adata=adata).Gene.iloc[: config.top_corr_genes].tolist()
+    )
+    if config.selected_genes is not None:
+        logger.info(
+            f"Generating GSS & Expression distribution plot for selected genes: {plot_genes}..."
+        )
+    else:
+        logger.info(
+            f"Generating GSS & Expression distribution plot for top {config.top_corr_genes} correlated genes..."
+        )
+
+    if config.customize_fig:
+        pixel_width, pixel_height, point_size = (
+            config.fig_width,
+            config.fig_height,
+            config.point_size,
+        )
+    else:
+        (pixel_width, pixel_height), point_size = estimate_plotly_point_size(
+            adata.obsm["spatial"]
+        )
+    sub_fig_save_dir = config.get_GSS_plot_dir(config.trait_name)
+
+    # save plot gene list
+    config.get_GSS_plot_select_gene_file(config.trait_name).write_text("\n".join(plot_genes))
+
+    paralleized_params = []
+    for selected_gene in plot_genes:
+        expression_series = pd.Series(
+            adata[:, selected_gene].X.toarray().flatten(), index=adata.obs.index, name="Expression"
+        )
+        threshold = np.quantile(expression_series, 0.9999)
+        expression_series[expression_series > threshold] = threshold
+
+        paralleized_params.append(
+            (
+                adata,
+                mk_score[[selected_gene]],  # Pass only needed gene to save memory
+                expression_series,
+                selected_gene,
+                point_size,
+                pixel_width,
+                pixel_height,
+                sub_fig_save_dir,
+                config.project_name,
+                config.annotation,
+                config.plot_origin,
+            )
+        )
+
+    with multiprocessing.Pool(os.cpu_count() // 2) as pool:
+        pool.starmap(generate_and_save_plots, paralleized_params)
+        pool.close()
+        pool.join()
+
+
+def generate_and_save_plots(
+    adata,
+    mk_score,
+    expression_series,
+    selected_gene,
+    point_size,
+    pixel_width,
+    pixel_height,
+    sub_fig_save_dir,
+    sample_name,
+    annotation,
+    plot_origin: str = "upper",
+):
+    """Generate and save the plots."""
+    select_gene_expression_with_space_coord = load_st_coord(adata, expression_series, annotation)
+    sub_fig_1 = draw_scatter(
+        select_gene_expression_with_space_coord,
+        title=f"{selected_gene} (Expression)",
+        annotation="annotation",
+        color_by="Expression",
+        point_size=point_size,
+        width=pixel_width,
+        height=pixel_height,
+        plot_origin=plot_origin,
+    )
+    save_plot(sub_fig_1, sub_fig_save_dir, sample_name, selected_gene, "Expression")
+
+    select_gene_GSS_with_space_coord = load_st_coord(
+        adata, mk_score[selected_gene].rename("GSS"), annotation
+    )
+    sub_fig_2 = draw_scatter(
+        select_gene_GSS_with_space_coord,
+        title=f"{selected_gene} (GSS)",
+        annotation="annotation",
+        color_by="GSS",
+        point_size=point_size,
+        width=pixel_width,
+        height=pixel_height,
+        plot_origin=plot_origin,
+    )
+    save_plot(sub_fig_2, sub_fig_save_dir, sample_name, selected_gene, "GSS")
+
+
+def save_plot(sub_fig, sub_fig_save_dir, sample_name, selected_gene, plot_type):
+    """Save the plot to HTML and PNG."""
+    save_sub_fig_path = (
+        sub_fig_save_dir / f"{sample_name}_{selected_gene}_{plot_type}_Distribution.png"
+    )
+    # sub_fig.write_html(str(save_sub_fig_path))
+    sub_fig.update_layout(showlegend=False)
+    sub_fig.write_image(save_sub_fig_path)
+    assert save_sub_fig_path.exists(), f"Failed to save {plot_type} plot for {selected_gene}."
+
+
+def generate_gsMap_plot(config: DiagnosisConfig, adata: ad.AnnData):
+    """Generate gsMap plot."""
+    logger.info("Creating gsMap plot...")
+
+    trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
+    space_coord_concat = load_st_coord(adata, trait_ldsc_result, annotation=config.annotation)
+
+    if config.customize_fig:
+        pixel_width, pixel_height, point_size = (
+            config.fig_width,
+            config.fig_height,
+            config.point_size,
+        )
+    else:
+        (pixel_width, pixel_height), point_size = estimate_plotly_point_size(
+            adata.obsm["spatial"]
+        )
+    fig = draw_scatter(
+        space_coord_concat,
+        title=f"{config.trait_name} (gsMap)",
+        point_size=point_size,
+        width=pixel_width,
+        height=pixel_height,
+        annotation=config.annotation,
+        plot_origin=config.plot_origin,
+    )
+
+    output_dir = config.get_gsMap_plot_save_dir(config.trait_name)
+    output_file_html = config.get_gsMap_html_plot_save_path(config.trait_name)
+    output_file_png = output_file_html.with_suffix(".png")
+    output_file_csv = output_file_html.with_suffix(".csv")
+
+    fig.write_html(output_file_html)
+    fig.write_image(output_file_png)
+    space_coord_concat.to_csv(output_file_csv)
+
+    logger.info(f"gsMap plot created and saved in {output_dir}.")
+
+
+def run_Diagnosis(config: DiagnosisConfig):
+    """Main function to run the diagnostic plot generation."""
+    adata = ad.read_h5ad(config.hdf5_with_latent_path)
+    if "pcc" not in adata.var.columns:
+        # Manual normalization and log1p to avoid scanpy dependency/warnings
+        if hasattr(adata.X, 'toarray'):
+            x_dense = adata.X.toarray()
+        else:
+            x_dense = adata.X
+
+        # Normalize to target sum 1e4
+        row_sums = x_dense.sum(axis=1)
+        row_sums[row_sums == 0] = 1  # Avoid division by zero
+        x_norm = (x_dense / row_sums.reshape(-1, 1)) * 1e4
+
+        # Log transformation
+        adata.X = np.log1p(x_norm)
+
+    if config.plot_type in ["gsMap", "all"]:
+        generate_gsMap_plot(config, adata=adata)
+    if config.plot_type in ["manhattan", "all"]:
+        generate_manhattan_plot(config, adata=adata)
+    if config.plot_type in ["GSS", "all"]:
+        generate_GSS_distribution(config, adata=adata)
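Note on the diagnosis logic above: most of it rests on two small computations, the two-sided Z-to-P conversion in `convert_z_to_p` and the per-gene Pearson correlation between marker scores and spot-level LDSC `logp` values in `compute_gene_diagnostic_info`. A self-contained toy sketch of both (random data and hypothetical variable names, not the package's actual I/O):

# Toy illustration of the two core computations in the diagnosis module.
import numpy as np
import pandas as pd
from scipy.stats import norm

rng = np.random.default_rng(0)

# (1) Two-sided Z -> P conversion, clipped at 1e-300 as in convert_z_to_p
z = pd.Series(rng.normal(size=5), name="Z")
p = pd.Series(norm.sf(z.abs()) * 2, name="P").clip(lower=1e-300)

# (2) Correlate each gene's marker score with the per-spot trait logp (PCC)
mk_score = pd.DataFrame(rng.random((100, 3)), columns=["GeneA", "GeneB", "GeneC"])
logp = pd.Series(rng.random(100), name="logp")
pcc = mk_score.corrwith(logp).rename("PCC").sort_values(ascending=False)
print(p.round(3).tolist())
print(pcc)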
gsMap/report/report.py
ADDED
@@ -0,0 +1,100 @@
+import logging
+from datetime import datetime
+from pathlib import Path
+
+import gsMap
+from gsMap.config import QuickModeConfig
+
+from .report_data import ReportDataManager
+
+logger = logging.getLogger(__name__)
+
+def run_report(config: QuickModeConfig, run_parameters: dict = None):
+    """
+    Main entry point for report generation.
+    Prepares data and saves the interactive report as a standalone Alpine+Tailwind HTML folder.
+
+    Output structure:
+    project_dir/
+    ├── report_data/                 # Data files (CSV, h5ad)
+    │   ├── spot_metadata.csv
+    │   ├── cauchy_results.csv
+    │   ├── umap_data.csv
+    │   ├── gene_list.csv
+    │   ├── gss_stats/
+    │   │   └── gene_trait_correlation_{trait}.csv
+    │   ├── manhattan_data/
+    │   │   └── {trait}_manhattan.csv
+    │   └── spatial_3d/
+    │       └── spatial_3d.h5ad
+    │
+    └── gsmap_web_report/            # Web report (self-contained)
+        ├── index.html
+        ├── report_meta.json
+        ├── execution_summary.yaml
+        ├── spatial_plots/
+        │   └── ldsc_{trait}.png
+        ├── gene_diagnostic_plots/
+        ├── annotation_plots/
+        ├── spatial_3d/
+        │   └── *.html
+        ├── js_lib/
+        └── js_data/
+            ├── gss_stats/
+            ├── sample_index.js
+            ├── sample_{name}_spatial.js
+            └── ... (other JS modules)
+    """
+    logger.info("Running gsMap Report Module (Alpine.js + Tailwind based)")
+
+    # 1. Use ReportDataManager to prepare all data and JS assets
+    manager = ReportDataManager(config)
+    web_report_dir = manager.run()
+
+    # 2. Save run_parameters for future reference
+    if run_parameters:
+        import yaml
+        with open(web_report_dir / "execution_summary.yaml", "w") as f:
+            yaml.dump(run_parameters, f)
+
+    # 3. Render the Jinja2 template
+    template_path = Path(__file__).parent / "static" / "template.html"
+    if not template_path.exists():
+        logger.error(f"Template file not found at {template_path}")
+        return
+
+    try:
+        from jinja2 import Template
+        with open(template_path, encoding="utf-8") as f:
+            template = Template(f.read())
+
+        # Prepare context
+        context = {
+            "title": f"gsMap Report - {config.project_name}",
+            "project_name": config.project_name,
+            "generated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+            "gsmap_version": getattr(gsMap, "__version__", "unknown"),
+        }
+
+        rendered_html = template.render(**context)
+
+        report_file = web_report_dir / "index.html"
+        with open(report_file, "w", encoding="utf-8") as f:
+            f.write(rendered_html)
+
+        from rich import print as rprint
+        rprint("\n[bold green]Report generated successfully![/bold green]")
+        rprint(f"Web report directory: [cyan]{web_report_dir}[/cyan]")
+        rprint(f"Data files directory: [cyan]{config.report_data_dir}[/cyan]\n")
+
+        rprint("[bold]Ways to view the interactive report:[/bold]")
+        rprint("1. [bold white]Remote Server:[/bold white] Run the command below to start a temporary web server:")
+        rprint(f"   [bold cyan]gsmap report-view {web_report_dir} --port 8080 --no-browser[/bold cyan]")
+        rprint(f"\n2. [bold white]Local PC:[/bold white] Copy the [cyan]{web_report_dir.name}[/cyan] folder to your machine and open [cyan]index.html[/cyan].\n")
+
+    except ImportError:
+        logger.error("Jinja2 not found. Please install it with 'pip install jinja2'.")
+    except Exception as e:
+        logger.error(f"Failed to render report: {e}")
+        import traceback
+        traceback.print_exc()
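Note on the rendering step above: it boils down to loading static/template.html into a Jinja2 Template and calling render with four context keys. A minimal, self-contained example using the same keys (the inline template string is invented for illustration and is not the packaged template):

# Minimal Jinja2 rendering sketch mirroring the context used by run_report.
from datetime import datetime
from jinja2 import Template

template = Template(
    "<title>{{ title }}</title>"
    "<h1>{{ project_name }}</h1>"
    "<p>Generated {{ generated_at }} with gsMap {{ gsmap_version }}</p>"
)
html = template.render(
    title="gsMap Report - demo_project",  # hypothetical project name
    project_name="demo_project",
    generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    gsmap_version="0.1.0a1",
)
print(html)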