gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1832 @@
"""
Report Data Preparation Module

This module prepares data for the Alpine.js + Tailwind CSS + Plotly.js report.
All data is exported as JS files with window global variables to bypass CORS
restrictions when opening index.html via the file:// protocol.
"""

import gc
import json
import logging
import shutil
import threading
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import plotly
import scanpy as sc
import scipy.sparse as sp

from gsMap.config import QuickModeConfig
from gsMap.config.latent2gene_config import DatasetType
from gsMap.find_latent.st_process import setup_data_layer
from gsMap.report.diagnosis import filter_snps, load_gwas_data
from gsMap.report.three_d_plot.three_d_plots import three_d_plot, three_d_plot_save
from gsMap.report.visualize import (
    estimate_plotly_point_size,
)
from gsMap.spatial_ldsc.io import load_marker_scores_memmap_format

logger = logging.getLogger(__name__)
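
# The report is a static page: each artifact is written twice, as a CSV/parquet
# under report_data_dir for programmatic reuse, and as a JS file under js_data
# that assigns a window global. An exported file looks like (illustrative
# trait name):
#
#   window.GSMAP_MANHATTAN_Height = {"x": [...], "y": [...], "gene": [...]};
#
# A plain <script src="js_data/manhattan_Height.js"> tag then works from
# file://, whereas fetch()ing a local JSON file would be blocked by CORS.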

class ReportDataManager:
    def __init__(self, report_config: QuickModeConfig):
        self.report_config = report_config

        # Two separate directories: data files vs web report
        self.report_data_dir = report_config.report_data_dir
        self.web_report_dir = report_config.web_report_dir
        self.js_data_dir = self.web_report_dir / "js_data"

        # Create directories
        self.report_data_dir.mkdir(parents=True, exist_ok=True)
        self.web_report_dir.mkdir(parents=True, exist_ok=True)
        self.js_data_dir.mkdir(exist_ok=True)

        if self.report_config.sample_h5ad_dict is None:
            self.report_config._resolve_h5ad_inputs()

        self.force_re_run = getattr(report_config, 'force_report_re_run', False)

        # Internal state
        self.ldsc_results = None
        self.traits = []
        self.sample_names = []
        self.coords = None
        self.is_3d = False
        self.z_axis = None
        self.common_spots = None
        self.analysis_spots = None
        self.gene_stats = None
        self.metadata = None
        self.gss_adata = None
        self.chrom_tick_positions = None
        self.umap_info = None
        self.spatial_3d_html = None

    def close(self):
        """Release resources and close file handles."""
        if self.gss_adata is not None:
            logger.info("Closing GSS AnnData and memory maps...")
            if 'memmap_manager' in self.gss_adata.uns:
                self.gss_adata.uns['memmap_manager'].close()
            self.gss_adata = None
        gc.collect()
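
    # Every prepare_* step below is resumable: a step lists the files it would
    # produce and skips itself when all of them already exist, unless a re-run
    # is forced via force_report_re_run.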
    def _is_step_complete(self, files: list[Path]) -> bool:
        if self.force_re_run:
            return False
        return all(f.exists() for f in files)

    def run(self):
        """Orchestrate the report data preparation."""
        logger.info("Starting report data preparation...")
        try:
            # 1. Base Metadata
            self.prepare_base_metadata()

            # 2. GSS Statistics (PCC)
            self.prepare_gss_stats()

            # 3. Spot Metadata (depends on GSS stats for gene_list.csv)
            self.prepare_spot_metadata()

            # 4. Manhattan Data
            self.prepare_manhattan_data()

            # 5. Static Plots
            self.render_static_plots()

            # 6. Cauchy Results
            self.collect_cauchy_results()

            # 7. UMAP Data
            self.prepare_umap_data()

            # 8. 3D Visualization
            self.prepare_3d_visualization()

            # 9. Finalize Metadata and JS Assets
            self.finalize_report()

            logger.info("Report data preparation complete.")
            logger.info(f"  Data files: {self.report_data_dir}")
            logger.info(f"  Web report: {self.web_report_dir}")
        finally:
            self.close()

        return self.web_report_dir

    def prepare_base_metadata(self):
        """Load LDSC results and coordinates."""
        # 1. Load LDSC results
        if self.ldsc_results is None:
            self.ldsc_results, self.traits, self.sample_names = _load_ldsc_results(self.report_config)

        # 2. Load coordinates
        if self.coords is None:
            self.coords, self.is_3d, self.z_axis = _load_coordinates(self.report_config)

        # Base metadata is exported per sample in finalize_report via
        # _export_per_sample_spatial_js, so nothing is written here.

    def prepare_gss_stats(self):
        """Load GSS data, calculate PCC per trait, and split results."""
        gss_dir = self.report_data_dir / "gss_stats"
        gss_dir.mkdir(exist_ok=True)

        gene_list_file = self.report_data_dir / "gene_list.csv"

        # Determine which traits need PCC calculation
        traits_to_run = []
        for trait in self.traits:
            csv_path = self.report_config.get_gene_diagnostic_info_save_path(trait)
            js_path = self.js_data_dir / "gss_stats" / f"gene_trait_correlation_{trait}.js"
            if not self._is_step_complete([csv_path, js_path]):
                traits_to_run.append(trait)

        if not traits_to_run and self._is_step_complete([gene_list_file]):
            logger.info("GSS statistics (PCC) already complete for all traits. Skipping.")
            return

        logger.info(f"Processing GSS statistics for {len(traits_to_run)} traits...")

        # Load GSS and common spots
        if self.gss_adata is None:
            self.common_spots, self.gss_adata, self.gene_stats, self.analysis_spots = _load_gss_and_calculate_stats_base(
                self.report_config, self.ldsc_results, self.coords, self.report_data_dir
            )

        # Pre-filter to high expression genes and pre-calculate centered matrix
        exp_frac = pd.read_parquet(self.report_config.mean_frac_path)
        high_expr_genes = exp_frac[exp_frac['frac'] > 0.01].index.tolist()

        # Use analysis_spots (subsampled) for PCC calculation
        gss_adata_sub = self.gss_adata[self.analysis_spots, high_expr_genes]
        gss_matrix = gss_adata_sub.X
        if hasattr(gss_matrix, 'toarray'):
            gss_matrix = gss_matrix.toarray()

        gss_mean = gss_matrix.mean(axis=0).astype(np.float32)
        gss_centered = (gss_matrix - gss_mean).astype(np.float32)
        gss_ssq = np.sum(gss_centered ** 2, axis=0)
        gene_names = gss_adata_sub.var_names.tolist()

        # Calculate PCC for each missing trait
        all_pcc = []
        for trait in self.traits:
            # Check both the CSV in gss_stats and the JS in js_data/gss_stats
            csv_path = self.report_config.get_gene_diagnostic_info_save_path(trait)
            js_path = self.js_data_dir / "gss_stats" / f"gene_trait_correlation_{trait}.js"

            if self._is_step_complete([csv_path, js_path]):
                continue

            trait_pcc = _calculate_pcc_for_single_trait_fast(
                trait, self.ldsc_results, self.analysis_spots, gss_centered, gss_ssq,
                gene_names, self.gene_stats, self.report_config, self.report_data_dir, gss_dir
            )
            if trait_pcc is not None:
                all_pcc.append(trait_pcc)
                # Export to JS immediately after CSV creation
                self._export_single_pcc_js(trait, trait_pcc)

    def prepare_spot_metadata(self):
        """Save spot metadata and coordinates."""
        metadata_file = self.report_data_dir / "spot_metadata.csv"
        if self._is_step_complete([metadata_file]):
            logger.info("Spot metadata already exists. Skipping.")
            self.metadata = pd.read_csv(metadata_file)
            if self.common_spots is None:
                self.common_spots = self.metadata['spot'].values
            return

        self.metadata = _save_metadata(self.ldsc_results, self.coords, self.report_data_dir)
        if self.common_spots is None:
            self.common_spots = self.metadata['spot'].values
    def _export_single_pcc_js(self, trait, df):
        """Export single trait PCC results to JS in the gss_stats subfolder."""
        gss_js_dir = self.js_data_dir / "gss_stats"
        gss_js_dir.mkdir(exist_ok=True)

        data_json = df.to_json(orient='records')
        # Sanitize the trait name outside the f-string: nesting the same quote
        # type inside an f-string is a syntax error before Python 3.12.
        safe_trait = "".join(c if c.isalnum() else "_" for c in trait)
        var_name = f"GSMAP_GENE_TRAIT_CORRELATION_{safe_trait}"
        js_content = f"window.{var_name} = {data_json};"
        with open(gss_js_dir / f"gene_trait_correlation_{trait}.js", "w", encoding='utf-8') as f:
            f.write(js_content)

    def prepare_manhattan_data(self):
        """Prepare Manhattan data for all traits."""
        manhattan_dir = self.report_data_dir / "manhattan_data"
        manhattan_dir.mkdir(exist_ok=True)

        # Determine which traits need Manhattan data
        traits_to_run = []
        for trait in self.traits:
            csv_path = manhattan_dir / f"{trait}_manhattan.csv"
            js_path = self.js_data_dir / f"manhattan_{trait}.js"
            if not self._is_step_complete([csv_path, js_path]):
                traits_to_run.append(trait)

        if not traits_to_run:
            logger.info("Manhattan data already complete for all traits. Skipping.")
            return

        logger.info(f"Processing Manhattan data for {len(traits_to_run)} traits...")

        # Load weights
        logger.info(f"Loading weights from {self.report_config.snp_gene_weight_adata_path}")
        weight_adata = ad.read_h5ad(self.report_config.snp_gene_weight_adata_path)

        # Load gene ref
        genes_file = self.report_data_dir / "gene_list.csv"
        gene_names_ref = pd.read_csv(genes_file)['gene'].tolist() if genes_file.exists() else []

        chrom_tick_positions = {}
        for trait in traits_to_run:
            try:
                # Load trait-specific PCC data
                trait_pcc_file = self.report_config.get_gene_diagnostic_info_save_path(trait)
                trait_pcc_df = pd.read_csv(trait_pcc_file) if trait_pcc_file.exists() else None

                chrom_ticks = _prepare_manhattan_for_trait(
                    self.report_config, trait, weight_adata, trait_pcc_df, gene_names_ref, manhattan_dir
                )
                chrom_tick_positions[trait] = chrom_ticks
                # Export to JS
                _export_single_manhattan_js(trait, manhattan_dir / f"{trait}_manhattan.csv", self.js_data_dir)
            except Exception as e:
                logger.warning(f"Failed to prepare Manhattan data for {trait}: {e}")

        self.chrom_tick_positions = chrom_tick_positions

    def render_static_plots(self):
        """Render static plots (LDSC, Annotation, Gene Diagnostics)."""
        dataset_type = getattr(self.report_config, 'dataset_type', DatasetType.SPATIAL_2D)
        is_spatial = dataset_type in (DatasetType.SPATIAL_2D, DatasetType.SPATIAL_3D)

        if not is_spatial:
            logger.info("Skipping static plot rendering for non-spatial dataset type")
            return

        from .visualize import VisualizeRunner
        visualizer = VisualizeRunner(self.report_config)

        n_samples = len(self.sample_names)
        n_cols = min(4, n_samples)
        n_rows = (n_samples + n_cols - 1) // n_cols

        obs_data = self.metadata.copy()
        if 'sample' not in obs_data.columns:
            obs_data['sample'] = obs_data['sample_name']

        if self.report_config.generate_multi_sample_plots:
            # Render LDSC plots
            spatial_plot_dir = self.web_report_dir / "spatial_plots"
            spatial_plot_dir.mkdir(exist_ok=True)
            for trait in self.traits:
                trait_plot_path = spatial_plot_dir / f"ldsc_{trait}.png"
                if not self._is_step_complete([trait_plot_path]):
                    logger.info(f"Rendering LDSC plot for {trait}...")
                    visualizer._create_single_trait_multi_sample_matplotlib_plot(
                        obs_ldsc_merged=obs_data,
                        trait_abbreviation=trait,
                        sample_name_list=self.sample_names,
                        output_png_path=trait_plot_path,
                        n_rows=n_rows, n_cols=n_cols,
                        subplot_width_inches=5.0
                    )

            # Render annotation plots
            anno_dir = self.web_report_dir / "annotation_plots"
            anno_dir.mkdir(exist_ok=True)
            for anno in self.report_config.annotation_list:
                anno_plot_path = anno_dir / f"anno_{anno}.png"
                if not self._is_step_complete([anno_plot_path]):
                    logger.info(f"Rendering plot for annotation: {anno}...")
                    fig = visualizer._create_multi_sample_annotation_plot(
                        obs_ldsc_merged=obs_data,
                        annotation=anno,
                        sample_names_list=self.sample_names,
                        output_dir=None,
                        n_rows=n_rows, n_cols=n_cols,
                        fig_width=5 * n_cols, fig_height=5 * n_rows
                    )
                    import matplotlib.pyplot as plt
                    fig.savefig(anno_plot_path, dpi=300, bbox_inches='tight')
                    plt.close(fig)

        # Gene diagnostic plots
        # Ensure base metadata exists (needed for common_spots and ldsc_results)
        if self.ldsc_results is None or self.coords is None:
            self.prepare_base_metadata()

        if self.metadata is None:
            self.prepare_spot_metadata()

        # Ensure GSS data is loaded if we need to plot
        if self.gss_adata is None:
            self.common_spots, self.gss_adata, self.gene_stats, self.analysis_spots = _load_gss_and_calculate_stats_base(
                self.report_config, self.ldsc_results, self.coords, self.report_data_dir
            )

        _render_gene_diagnostic_plots_refactored(
            self.report_config, self.metadata, self.common_spots, self.gss_adata,
            n_rows, n_cols, self.web_report_dir, self.report_data_dir, self._is_step_complete
        )

    def collect_cauchy_results(self):
        """Collect and save Cauchy combination results."""
        cauchy_file = self.report_data_dir / "cauchy_results.csv"
        cauchy_js = self.js_data_dir / "cauchy_results.js"
        # Since this combines all traits/annotations, we check for the final csv
        if self._is_step_complete([cauchy_file, cauchy_js]):
            logger.info("Cauchy results already collected. Skipping.")
            return

        _collect_cauchy_results(self.report_config, self.traits, self.report_data_dir)
        _export_cauchy_js(self.report_data_dir, self.js_data_dir)

    def prepare_umap_data(self):
        """Prepare UMAP data from embeddings."""
        umap_file = self.report_data_dir / "umap_data.csv"
        umap_js = self.js_data_dir / "umap_data.js"
        if self._is_step_complete([umap_file, umap_js]):
            logger.info("UMAP data already exists. Skipping.")
            return

        self.umap_info = _prepare_umap_data(self.report_config, self.metadata, self.report_data_dir)
        _export_umap_js(self.report_data_dir, self.js_data_dir, {'umap_info': self.umap_info})

    def prepare_3d_visualization(self):
        """Prepare 3D visualization if applicable."""
        if not self.is_3d:
            self.spatial_3d_html = None
            return

        self.spatial_3d_html = _prepare_3d_visualization(
            self.report_config, self.metadata, self.traits, self.report_data_dir, self.web_report_dir
        )

    def finalize_report(self):
        """Save report metadata and final JS assets."""
        # 1. Save standard report metadata
        _save_report_metadata(
            self.report_config, self.traits, self.sample_names, self.web_report_dir,
            getattr(self, 'chrom_tick_positions', None),
            getattr(self, 'umap_info', None),
            is_3d=self.is_3d,
            spatial_3d_path=getattr(self, 'spatial_3d_html', None)
        )

        # 2. Copy JS libraries
        _copy_js_assets(self.web_report_dir)

        # 3. Export per-sample spatial JS (highly important for performance)
        with open(self.web_report_dir / "report_meta.json") as f:
            meta = json.load(f)
        _export_per_sample_spatial_js(self.report_data_dir, self.js_data_dir, meta)
        _export_report_meta_js(self.web_report_dir, self.js_data_dir, meta)


def _render_single_sample_gene_plot_task(task_data: dict):
    """
    Parallel task to render a single-sample Expression or GSS plot.
    Creates a single plot for one gene in one sample.
    """
    try:
        import gc
        from pathlib import Path

        import matplotlib
        import matplotlib.colors as mcolors
        import matplotlib.pyplot as plt
        import numpy as np

        matplotlib.use('Agg')

        gene = task_data['gene']
        sample_name = task_data['sample_name']
        plot_type = task_data['plot_type']
        coords = task_data['coords']
        values = task_data['values']
        output_path = Path(task_data['output_path'])
        fig_width = task_data.get('fig_width', 6.0)
        dpi = task_data.get('dpi', 150)

        if coords is None or values is None or len(coords) == 0:
            return f"No data for {gene} in {sample_name}"

        # Custom colormap
        custom_colors = [
            '#313695', '#4575b4', '#74add1', '#abd9e9', '#e0f3f8',
            '#fee090', '#fdae61', '#f46d43', '#d73027'
        ]
        custom_cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', custom_colors)

        fig, ax = plt.subplots(figsize=(fig_width, fig_width))

        # Calculate point size based on data density
        from gsMap.report.visualize import estimate_matplotlib_scatter_marker_size
        point_size = estimate_matplotlib_scatter_marker_size(ax, coords)
        point_size = min(max(point_size, 1), 200)

        # Color scale
        vmin = 0
        with np.errstate(all='ignore'):
            non_nan_values = values[np.isfinite(values)] if len(values) > 0 else np.array([])
            vmax = np.percentile(non_nan_values, 99.5) if len(non_nan_values) > 0 else 1.0

        if not np.isfinite(vmax) or vmax <= vmin:
            vmax = vmin + 1.0

        scatter = ax.scatter(
            coords[:, 0], coords[:, 1],
            c=values, cmap=custom_cmap,
            s=point_size, vmin=vmin, vmax=vmax,
            marker='o', edgecolors='none', rasterized=True
        )

        if task_data.get('plot_origin', 'upper') == 'upper':
            ax.invert_yaxis()

        ax.axis('off')
        ax.set_aspect('equal')

        # Add colorbar
        cbar = fig.colorbar(scatter, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Expression' if plot_type == 'exp' else 'GSS', fontsize=10)

        # Title
        title_text = f"{gene} - {'Expression' if plot_type == 'exp' else 'GSS'}"
        ax.set_title(title_text, fontsize=12, fontweight='bold')

        output_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_path, dpi=dpi, bbox_inches='tight', facecolor='white')
        plt.close(fig)
        gc.collect()

        return True
    except Exception as e:
        import traceback
        return f"{str(e)}\n{traceback.format_exc()}"


# =============================================================================
# UMAP Calculation Functions
# =============================================================================

def _detect_z_axis(coords_3d: np.ndarray, sample_names: pd.Series) -> int:
    """
    Detect which dimension is the Z axis (stacking dimension) in 3D coordinates.

    The Z axis is identified as the dimension with the least within-sample
    variance, since samples are stacked along Z and should have minimal
    variation in that dimension.

    Args:
        coords_3d: 3D coordinates array (n_spots, 3)
        sample_names: Series of sample names for each spot

    Returns:
        Index of the Z axis (0, 1, or 2)
    """
    # Calculate within-sample variance for each dimension
    within_sample_vars = []

    for dim in range(3):
        dim_values = coords_3d[:, dim]

        # Calculate variance within each sample
        sample_vars = []
        for sample in sample_names.unique():
            mask = sample_names == sample
            sample_values = dim_values[mask]
            if len(sample_values) > 1:
                sample_vars.append(np.var(sample_values))

        # Average within-sample variance for this dimension
        avg_within_var = np.mean(sample_vars) if sample_vars else 0
        within_sample_vars.append(avg_within_var)

    # Z axis has the least within-sample variance
    z_axis = int(np.argmin(within_sample_vars))
    logger.info(f"Detected Z axis: dimension {z_axis} (within-sample variances: {within_sample_vars})")

    return z_axis


def _reorder_coords_for_3d(coords_3d: np.ndarray, z_axis: int) -> tuple[np.ndarray, list[str]]:
    """
    Reorder 3D coordinates so that the Z axis is the last dimension.

    Returns:
        Tuple of (reordered coords, column names)
    """
    # Create axis order: put z_axis last
    axis_order = [i for i in range(3) if i != z_axis] + [z_axis]
    reordered = coords_3d[:, axis_order]

    # Column names reflect the reordering
    col_names = ['3d_x', '3d_y', '3d_z']

    return reordered, col_names


def _stratified_subsample(
    spot_names: np.ndarray,
    sample_names: pd.Series,
    n_samples: int,
    random_state: int = 42
) -> np.ndarray:
    """
    Stratified subsampling that ensures representation from all samples.

    Args:
        spot_names: Array of spot identifiers
        sample_names: Series mapping spot names to sample names
        n_samples: Target number of samples
        random_state: Random seed for reproducibility

    Returns:
        Array of selected spot names
    """
    np.random.seed(random_state)

    # Get sample counts
    sample_counts = sample_names.value_counts()
    n_samples_total = len(spot_names)

    if n_samples >= n_samples_total:
        return spot_names

    # Calculate proportional samples per group
    selected_spots = []
    for sample_name, count in sample_counts.items():
        sample_spots = spot_names[sample_names == sample_name]
        # Proportional allocation
        n_select = max(1, int(np.ceil(n_samples * count / n_samples_total)))
        n_select = min(n_select, len(sample_spots))

        selected = np.random.choice(sample_spots, n_select, replace=False)
        selected_spots.extend(selected)

    # If we have too many, randomly remove some
    selected_spots = np.array(selected_spots)
    if len(selected_spots) > n_samples:
        selected_spots = np.random.choice(selected_spots, n_samples, replace=False)

    return selected_spots


def _calculate_umap_from_embeddings(
    adata: ad.AnnData,
    embedding_key: str,
) -> np.ndarray:
    """Compute a UMAP embedding from the representation stored in adata.obsm."""
    logger.info(f"Calculating UMAP for {embedding_key} with {adata.n_obs} spots using scanpy...")

    sc.pp.neighbors(adata, use_rep=embedding_key)
    sc.tl.umap(adata)

    return adata.obsm['X_umap']


def _prepare_umap_data(
    report_config: QuickModeConfig,
    metadata: pd.DataFrame,
    report_dir: Path
) -> dict | None:
    """
    Prepare UMAP data from cell and niche embeddings.

    Returns dict with umap_cell, umap_niche (optional), and metadata for visualization.
    """
    logger.info("Preparing UMAP data from embeddings...")

    # Load concatenated adata
    adata_path = report_config.concatenated_latent_adata_path
    if not adata_path.exists():
        logger.warning(f"Concatenated adata not found at {adata_path}")
        return None

    adata = ad.read_h5ad(adata_path)

    # Get embedding keys from config
    cell_emb_key = getattr(report_config, 'latent_representation_cell', 'emb_cell')
    niche_emb_key = getattr(report_config, 'latent_representation_niche', None)

    # Check if embeddings exist
    if cell_emb_key not in adata.obsm:
        logger.warning(f"Cell embedding '{cell_emb_key}' not found in adata.obsm")
        return None

    has_niche = niche_emb_key is not None and niche_emb_key in adata.obsm

    # Stratified subsampling
    n_subsample = getattr(report_config, 'downsampling_n_spots', 20000)
    spot_names = adata.obs_names.values
    sample_names = adata.obs['sample_name']

    if len(spot_names) > n_subsample:
        logger.info(f"Stratified subsampling from {len(spot_names)} to {n_subsample} spots...")
        selected_spots = _stratified_subsample(spot_names, sample_names, n_subsample)
        adata_sub = adata[selected_spots].copy()
    else:
        adata_sub = adata.copy()
        selected_spots = spot_names

    logger.info(f"Using {len(selected_spots)} spots for UMAP calculation")

    # Calculate UMAPs
    umap_cell = _calculate_umap_from_embeddings(adata_sub, cell_emb_key)

    # Estimate point size for UMAP cell
    _, point_size_cell = estimate_plotly_point_size(umap_cell, DEFAULT_PIXEL_WIDTH=600)

    umap_niche = None
    point_size_niche = None
    if has_niche:
        umap_niche = _calculate_umap_from_embeddings(adata_sub, niche_emb_key)
        _, point_size_niche = estimate_plotly_point_size(umap_niche, DEFAULT_PIXEL_WIDTH=600)

    # Prepare metadata for the subsampled spots
    umap_metadata = pd.DataFrame({
        'spot': adata_sub.obs_names,
        'sample_name': adata_sub.obs['sample_name'].values,
        'umap_cell_x': umap_cell[:, 0],
        'umap_cell_y': umap_cell[:, 1],
    })

    if umap_niche is not None:
        umap_metadata['umap_niche_x'] = umap_niche[:, 0]
        umap_metadata['umap_niche_y'] = umap_niche[:, 1]

    # Add all annotation columns
    for anno in report_config.annotation_list:
        if anno in adata_sub.obs.columns:
            # Fill NaN with 'NaN' and convert to string to avoid sorting errors
            umap_metadata[anno] = adata_sub.obs[anno].astype(str).fillna('NaN').values

    # Add trait -log10(p) values from metadata if available (vectorized join)
    traits = report_config.trait_name_list
    available_traits = [t for t in traits if t in metadata.columns]
    if available_traits:
        # Prepare trait data: ensure 'spot' is a column and is string type
        trait_data = metadata[available_traits].copy()
        trait_data['spot'] = metadata['spot'].astype(str)
        umap_metadata = umap_metadata.merge(trait_data, on='spot', how='left')

        # Keep 1 decimal precision for trait values and handle non-numeric data
        for trait in available_traits:
            if trait in umap_metadata.columns:
                # Convert to numeric in case values were strings, then round
                umap_metadata[trait] = pd.to_numeric(umap_metadata[trait], errors='coerce').round(1)
                logger.info(f"Added trait {trait} to UMAP data with 1 decimal precision")

    # Save to CSV
    umap_metadata.to_csv(report_dir / "umap_data.csv", index=False)
    logger.info(f"UMAP data saved with {len(umap_metadata)} points")

    return {
        'has_niche': has_niche,
        'cell_emb_key': cell_emb_key,
        'niche_emb_key': niche_emb_key,
        'n_points': len(umap_metadata),
        'point_size_cell': float(point_size_cell),
        'point_size_niche': float(point_size_niche) if point_size_niche is not None else None,
        'default_opacity': 0.7
    }


def _prepare_3d_visualization(
    report_config: QuickModeConfig,
    metadata: pd.DataFrame,
    traits: list[str],
    report_data_dir: Path,
    web_report_dir: Path
) -> str | None:
    """
    Create a 3D visualization widget using pyvista and save it as HTML.

    Args:
        report_config: QuickModeConfig with annotation_list and other settings
        metadata: DataFrame with 3d_x, 3d_y, 3d_z coordinates and annotations
        traits: List of trait names (for continuous values)
        report_data_dir: Directory for data files (h5ad)
        web_report_dir: Directory for web report files (HTML)

    Returns:
        Path to saved HTML file, or None if failed
    """
    try:
        import pyvista
    except ImportError:
        logger.warning("pyvista not installed. Skipping 3D visualization.")
        return None

    # Check if we have 3D coordinates
    if '3d_x' not in metadata.columns or '3d_y' not in metadata.columns or '3d_z' not in metadata.columns:
        logger.warning("3D coordinates not found in metadata. Skipping 3D visualization.")
        return None

    logger.info("Creating 3D visualization...")

    # 1. Create adata_vis using ALL spots
    n_spots_all = len(metadata)
    adata_vis_full = ad.AnnData(
        X=sp.csr_matrix((n_spots_all, 1), dtype=np.float32),
        obs=metadata.set_index('spot')
    )
    # Set coordinates
    adata_vis_full.obsm['spatial_3d'] = metadata[['3d_x', '3d_y', '3d_z']].values
    adata_vis_full.obsm['spatial'] = adata_vis_full.obsm['spatial_3d']
    if 'sx' in metadata.columns and 'sy' in metadata.columns:
        adata_vis_full.obsm['spatial_2d'] = metadata[['sx', 'sy']].values

    # Ensure trait columns are float32
    adata_vis_full.obs = adata_vis_full.obs.astype({
        trait: np.float32 for trait in traits if trait in adata_vis_full.obs.columns
    })

    # Create 3D visualization directories
    three_d_data_dir = report_data_dir / "spatial_3d"
    three_d_data_dir.mkdir(exist_ok=True)
    three_d_web_dir = web_report_dir / "spatial_3d"
    three_d_web_dir.mkdir(exist_ok=True)

    # 2. Save the full adata_vis to the data directory
    h5ad_path = three_d_data_dir / "spatial_3d.h5ad"
    adata_vis_full.write_h5ad(h5ad_path)
    logger.info(f"Full 3D visualization data saved to {h5ad_path}")

    # 3. Stratified subsampling for HTML visualization (limit to reasonable size)
    n_max_points = getattr(report_config, 'downsampling_n_spots_3d', 1000000)
    if len(metadata) > n_max_points:
        sample_names = metadata['sample_name']
        selected_idx = _stratified_subsample(
            metadata.index.values, sample_names, n_max_points
        )
        adata_vis = adata_vis_full[selected_idx].copy()
        logger.info(f"Subsampled to {len(adata_vis)} spots for HTML 3D visualization")
    else:
        adata_vis = adata_vis_full.copy()

    # Add traits as continuous values
    continuous_cols = traits

    # Add annotations (categorical)
    annotation_cols = []
    for anno in report_config.annotation_list:
        if anno in adata_vis.obs.columns:
            annotation_cols.append(anno)

    logger.info(f"3D visualization: annotations={annotation_cols}, traits={continuous_cols}")

    # Create 3D plots
    try:
        from gsMap.report.visualize import _create_color_map

        # P-value color scale (Red-Blue)
        P_COLOR = ['#313695', '#4575b4', '#74add1', '#fee090', '#fdae61', '#f46d43', '#d73027', '#a50026']

        # Calculate appropriate point size based on coverage ratio (20-30%)
        # Formula: S = sqrt(W * H * k / N)
        win_w, win_h = 1200, 1008
        k_coverage = 0.25
        n_points = len(adata_vis)
        point_size = np.sqrt((win_w * win_h * k_coverage) / n_points)
        point_size = max(0.5, min(5.0, point_size))  # Clamp between 0.5 and 5.0
        logger.info(f"Estimated 3D point size: {point_size:.2f} for {n_points} spots")

        # Shared plotting settings
        text_kwargs = dict(text_font_size=15, text_loc="upper_edge")

        # 1. Save all Annotation 3D plots
        categorical_legend_kwargs = dict(categorical_legend_loc="center right")
        for anno in annotation_cols:
            logger.info(f"Generating 3D plot for annotation: {anno}")

            # Use same colormap logic as distribution plots
            if pd.api.types.is_numeric_dtype(adata_vis.obs[anno]):
                color_map = P_COLOR
            else:
                unique_vals = adata_vis.obs[anno].unique()
                color_map = _create_color_map(unique_vals, hex=True, rng=42)

            safe_anno = "".join(c if c.isalnum() else "_" for c in anno)
            anno_plot_name = three_d_web_dir / f"spatial_3d_anno_{safe_anno}"

            plotter_anno = three_d_plot(
                adata=adata_vis,
                spatial_key='spatial',
                keys=[anno],
                cmaps=[color_map],
                point_size=point_size,
                texts=[anno],
                off_screen=True,
                jupyter=False,
                background='white',
                legend_kwargs=categorical_legend_kwargs,
                text_kwargs=text_kwargs,
                window_size=(win_w, win_h)
            )
            three_d_plot_save(plotter_anno, filename=str(anno_plot_name))

        # 2. Save each Trait 3D plot
        legend_kwargs = dict(scalar_bar_title_size=30, scalar_bar_label_size=30, fmt="%.1e")

        for trait in continuous_cols:
            logger.info(f"Generating 3D plot for trait: {trait}")

            # Calculate opacity based on LogP (exponential scaling)
            trait_values = adata_vis.obs[trait].fillna(0).values
            bins = np.linspace(trait_values.min(), trait_values.max(), 5)
            # Handle case where min == max to avoid bins error
            if bins[0] == bins[-1]:
                opacity_show = 1.0
            else:
                alpha = np.exp(np.linspace(0.1, 1.0, num=(len(bins) - 1))) - 1
                alpha = alpha / max(alpha)
                opacity_show = pd.cut(trait_values, bins=bins, labels=alpha, include_lowest=True).tolist()
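                # With 5 bin edges this yields 4 opacity levels:
                # exp(linspace(0.1, 1.0, 4)) - 1 = [0.105, 0.492, 1.014, 1.718],
                # normalized to [0.061, 0.286, 0.590, 1.0], so spots with low
                # -log10(p) fade out and significant regions dominate the view.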

            # Set the clim (median of top 20 spots)
            sorted_vals = np.sort(trait_values)[::-1]
            max_v = np.round(np.median(sorted_vals[0:20])) if len(sorted_vals) >= 20 else sorted_vals[0]
            if max_v <= 0:
                max_v = 1.0

            safe_trait = "".join(c if c.isalnum() else "_" for c in trait)
            trait_plot_name = three_d_web_dir / f"spatial_3d_trait_{safe_trait}"

            plotter_trait = three_d_plot(
                adata=adata_vis,
                spatial_key='spatial',
                keys=[trait],
                cmaps=[P_COLOR],
                point_size=point_size,
                opacity=opacity_show,
                clim=[0, max_v],
                scalar_bar_titles=["-log10(p)"],
                texts=[trait],
                off_screen=True,
                jupyter=False,
                background='white',
                legend_kwargs=legend_kwargs,
                text_kwargs=text_kwargs,
                window_size=(win_w, win_h)
            )
            three_d_plot_save(plotter_trait, filename=str(trait_plot_name))

        # Return the relative path of the first available plot
        if annotation_cols:
            safe_first = "".join(c if c.isalnum() else "_" for c in annotation_cols[0])
            return f"spatial_3d/spatial_3d_anno_{safe_first}.html"
        elif continuous_cols:
            safe_first_trait = "".join(c if c.isalnum() else "_" for c in continuous_cols[0])
            return f"spatial_3d/spatial_3d_trait_{safe_first_trait}.html"

        return None

    except Exception as e:
        logger.warning(f"Failed to create 3D visualization: {e}")
        import traceback
        traceback.print_exc()
        return None


# =============================================================================
# Data Loading Functions
# =============================================================================

def _load_ldsc_results(report_config: QuickModeConfig) -> tuple[pd.DataFrame, list[str], list[str]]:
    """Load combined LDSC results and extract traits/samples."""
    logger.info(f"Loading combined LDSC results from {report_config.ldsc_combined_parquet_path}")

    if not report_config.ldsc_combined_parquet_path.exists():
        raise FileNotFoundError(
            f"Combined LDSC parquet not found at {report_config.ldsc_combined_parquet_path}. "
            "Please run Cauchy combination first."
        )

    ldsc_results = pd.read_parquet(report_config.ldsc_combined_parquet_path)
    if 'spot' in ldsc_results.columns:
        ldsc_results.set_index('spot', inplace=True)

    assert 'sample_name' in ldsc_results.columns, "LDSC combined results must have 'sample_name' column."

    traits = report_config.trait_name_list

    # Explicitly use sample order from report_config.sample_h5ad_dict
    sample_names = list(report_config.sample_h5ad_dict.keys())
    actual_samples_in_data = set(ldsc_results['sample_name'].unique())

    # Assert all samples in config are actually in the data
    missing = [s for s in sample_names if s not in actual_samples_in_data]
    assert not missing, f"Samples {missing} in config are not present in LDSC results."

    return ldsc_results, traits, sample_names


def _load_coordinates(report_config: QuickModeConfig) -> tuple[pd.DataFrame, bool, int | None]:
    """
    Load spatial coordinates from concatenated adata.

    Returns:
        Tuple of (coords DataFrame, is_3d flag, z_axis index if 3D)
    """
    from gsMap.config import DatasetType

    logger.info(f"Loading coordinates from {report_config.concatenated_latent_adata_path}")

    adata_concat = ad.read_h5ad(report_config.concatenated_latent_adata_path, backed='r')

    assert report_config.spatial_key in adata_concat.obsm
    coords_data = adata_concat.obsm[report_config.spatial_key]
    sample_info = adata_concat.obs[['sample_name']]

    # Check if dataset type is 3D
    is_3d_type = (
        hasattr(report_config, 'dataset_type') and
        report_config.dataset_type == DatasetType.SPATIAL_3D
    )
    has_3d_coords = coords_data.shape[1] >= 3

    z_axis = None
    if is_3d_type and has_3d_coords:
        # True 3D coordinates
        logger.info("Detected 3D spatial coordinates")
        coords_3d = coords_data[:, :3]

        # Detect Z axis
        z_axis = _detect_z_axis(coords_3d, adata_concat.obs['sample_name'])

        # Reorder coordinates so Z is last
        reordered_coords, col_names = _reorder_coords_for_3d(coords_3d, z_axis)
        coords = pd.DataFrame(reordered_coords, columns=col_names, index=adata_concat.obs_names)

        # Also keep 2D coords for compatibility (use x and y from reordered)
        coords['sx'] = coords['3d_x']
        coords['sy'] = coords['3d_y']

    elif is_3d_type and not has_3d_coords:
        # 3D dataset type but only 2D coordinates - create pseudo Z axis
        logger.info("3D dataset type with 2D coordinates - creating pseudo Z axis based on sample_name")
        coords_2d = coords_data[:, :2]
        coords = pd.DataFrame(coords_2d, columns=['sx', 'sy'], index=adata_concat.obs_names)

        # Calculate the span for pseudo Z axis
        sx_span = coords['sx'].max() - coords['sx'].min()
        sy_span = coords['sy'].max() - coords['sy'].min()
        z_span = max(sx_span, sy_span)

        # Assign Z values based on sample order from report_config
        sample_names_all = adata_concat.obs['sample_name']
        actual_samples_in_data = set(sample_names_all.unique())

        # Explicitly use report_config order
        ordered_samples = list(report_config.sample_h5ad_dict.keys())
        missing = [s for s in ordered_samples if s not in actual_samples_in_data]
        assert not missing, f"Samples {missing} in config are not present in coordinates data."

        n_samples = len(ordered_samples)

        # Create evenly spaced Z values for each sample in the specific order
        if n_samples > 1:
            z_values_per_sample = {
                sample: z_span * i / (n_samples - 1)
                for i, sample in enumerate(ordered_samples)
            }
        else:
            z_values_per_sample = {ordered_samples[0]: 0.0}

        # Assign Z values to each spot
        pseudo_z = sample_names_all.map(z_values_per_sample).values

        coords['3d_x'] = coords['sx']
        coords['3d_y'] = coords['sy']
        coords['3d_z'] = pseudo_z
        z_axis = 2  # Pseudo Z is always the last dimension

        logger.info(f"Created pseudo Z axis with span {z_span:.2f} for {n_samples} samples")

    else:
        # 2D coordinates (non-3D dataset type)
        coords_2d = coords_data[:, :2]
        coords = pd.DataFrame(coords_2d, columns=['sx', 'sy'], index=adata_concat.obs_names)

    coords = pd.concat([coords, sample_info], axis=1)

    # Close backed file
    if adata_concat.isbacked:
        adata_concat.file.close()

    # Return is_3d as True if dataset type is 3D (regardless of coord dimensions)
    return coords, is_3d_type, z_axis


def _load_gss_and_calculate_stats_base(
    report_config: QuickModeConfig,
    ldsc_results: pd.DataFrame,
    coords: pd.DataFrame,
    report_dir: Path
) -> tuple[np.ndarray, ad.AnnData, pd.DataFrame, np.ndarray]:
    """Load GSS data, calculate general stats, and return analysis results."""
    logger.info("Loading GSS data...")

    gss_adata = load_marker_scores_memmap_format(report_config)
    common_spots = np.intersect1d(gss_adata.obs_names, ldsc_results.index)
    logger.info(f"Common spots (gss & ldsc): {len(common_spots)}")
    assert len(common_spots) > 0, "No common spots found between GSS and LDSC results."

    # Stratified subsample if requested
    analysis_spots = common_spots
    if report_config.downsampling_n_spots_pcc and len(common_spots) > report_config.downsampling_n_spots_pcc:
        sample_names = ldsc_results.loc[common_spots, 'sample_name']
        analysis_spots = _stratified_subsample(
            common_spots, sample_names, report_config.downsampling_n_spots_pcc
        )
        logger.info(f"Stratified subsampled to {len(analysis_spots)} spots for PCC calculation.")

    # Filter to high expression genes
    exp_frac = pd.read_parquet(report_config.mean_frac_path)
    high_expr_genes = exp_frac[exp_frac['frac'] > 0.01].index.tolist()
    logger.info(f"Using {len(high_expr_genes)} high expression genes for PCC calculation.")

    gss_adata_sub = gss_adata[analysis_spots, high_expr_genes]
    gss_matrix = gss_adata_sub.X
    gene_names = gss_adata_sub.var_names.tolist()
    pd.DataFrame({'gene': gene_names}).to_csv(report_dir / "gene_list.csv", index=False)

    # Calculate gene annotation stats
    adata_concat = ad.read_h5ad(report_config.concatenated_latent_adata_path, backed='r')
    anno_col = report_config.annotation_list[0]
    annotations = adata_concat.obs.loc[analysis_spots, anno_col]

    if hasattr(gss_matrix, 'toarray'):
        gss_matrix = gss_matrix.toarray()

    gss_df_temp = pd.DataFrame(gss_matrix, index=analysis_spots, columns=gene_names)
    grouped_gss = gss_df_temp.groupby(annotations).median()

    gene_stats = pd.DataFrame({
        'gene': grouped_gss.idxmax().index,
        'Annotation': grouped_gss.idxmax().values,
        'Median_GSS': grouped_gss.max().values
    })
    gene_stats.dropna(subset=['Median_GSS'], inplace=True)

    return common_spots, gss_adata, gene_stats, analysis_spots


def _calculate_pcc_for_single_trait_fast(
    trait: str,
    ldsc_results: pd.DataFrame,
    analysis_spots: np.ndarray,
    gss_centered: np.ndarray,
    gss_ssq: np.ndarray,
    gene_names: list[str],
    gene_stats: pd.DataFrame,
    report_config: QuickModeConfig,
    report_dir: Path,
    gss_dir: Path
) -> pd.DataFrame | None:
    """Calculate PCC for a single trait using the pre-calculated centered GSS matrix."""
    if trait not in ldsc_results.columns:
        logger.warning(f"Trait {trait} not found in LDSC combined results. Skipping PCC calculation.")
        return None
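
    # fast_corr is a vectorized Pearson correlation: one matrix-vector product
    # scores every gene column against the trait's -log10(p) vector,
    #   r_g = sum((y - ybar) * (x_g - xbar_g)) / sqrt(sum((y - ybar)^2) * ssq_g),
    # where gss_centered and gss_ssq were computed once for all traits, so each
    # additional trait costs O(spots * genes); 1e-12 guards zero-variance genes.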
    def fast_corr(centered_matrix, ssq_matrix, vector):
        v_centered = vector - vector.mean()
        numerator = np.dot(v_centered, centered_matrix)
        denominator = np.sqrt(np.sum(v_centered ** 2) * ssq_matrix)
        return numerator / (denominator + 1e-12)

    logger.info(f"Processing PCC for trait: {trait}")

    logp_vec = ldsc_results.loc[analysis_spots, trait].values.astype(np.float32)
    assert not np.any(np.isnan(logp_vec)), f"NaN values found in LDSC results for trait {trait}."
    pccs = fast_corr(gss_centered, gss_ssq, logp_vec)

    trait_pcc = pd.DataFrame({
        'gene': gene_names,
        'PCC': pccs,
        'trait': trait
    })

    if gene_stats is not None:
        trait_pcc = trait_pcc.merge(gene_stats, on='gene', how='left')

    trait_pcc_sorted = trait_pcc.sort_values('PCC', ascending=False)

    # Save the result CSV to the gss_stats subfolder
    diag_info_path = report_config.get_gene_diagnostic_info_save_path(trait)
    trait_pcc_sorted.to_csv(diag_info_path, index=False)

    return trait_pcc_sorted


def _export_single_manhattan_js(trait: str, csv_path: Path, js_data_dir: Path):
    """Export single trait Manhattan data to JS."""
    try:
        df = pd.read_csv(csv_path)
        if 'P' in df.columns and 'logp' not in df.columns:
            df['logp'] = -np.log10(df['P'] + 1e-300)

        data_struct = {
            'x': df['BP_cum'].tolist(),
            'y': df['logp'].tolist(),
            'gene': df['GENE'].fillna("").tolist(),
            'chr': df['CHR'].astype(int).tolist(),
            'snp': df['SNP'].tolist() if 'SNP' in df.columns else [],
            'is_top': df['is_top_pcc'].astype(int).tolist() if 'is_top_pcc' in df.columns else [],
            'bp': df['BP'].astype(int).tolist() if 'BP' in df.columns else []
        }
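
        # Column-oriented arrays keep the JS payload compact and can be handed
        # to Plotly traces directly, without building per-point objects.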

        json_str = json.dumps(data_struct, separators=(',', ':'))
        safe_trait = "".join(c if c.isalnum() else "_" for c in trait)
        js_content = f"window.GSMAP_MANHATTAN_{safe_trait} = {json_str};"

        with open(js_data_dir / f"manhattan_{trait}.js", "w", encoding='utf-8') as f:
            f.write(js_content)
    except Exception as e:
        logger.warning(f"Failed to export Manhattan JS for {trait}: {e}")


# =============================================================================
# Data Preparation Functions
# =============================================================================

def _save_metadata(
    ldsc_results: pd.DataFrame,
    coords: pd.DataFrame,
    report_dir: Path
) -> pd.DataFrame:
    """Save metadata and coordinates to CSV."""
    logger.info("Saving metadata and coordinates...")

    common_indices = ldsc_results.index.intersection(coords.index)
    ldsc_subset = ldsc_results.loc[common_indices]
    cols_to_use = ldsc_subset.columns.difference(coords.columns)
    metadata = pd.concat([coords.loc[common_indices], ldsc_subset[cols_to_use]], axis=1)

    metadata.index.name = 'spot'
    metadata = metadata.reset_index()
    metadata = metadata.loc[:, ~metadata.columns.duplicated()]
    metadata.to_csv(report_dir / "spot_metadata.csv", index=False)

    return metadata


def _prepare_manhattan_for_trait(
    report_config: QuickModeConfig,
    trait: str,
    weight_adata: ad.AnnData,
    trait_pcc_df: pd.DataFrame | None,
    gene_names_ref: list[str],
    manhattan_dir: Path
) -> dict:
    """Prepare Manhattan data for a single trait."""
    from gsMap.utils.manhattan_plot import _ManhattanPlot

    sumstats_file = report_config.sumstats_config_dict.get(trait)
    if not sumstats_file or not Path(sumstats_file).exists():
        return {}

    logger.info(f"Processing Manhattan for {trait}...")
    gwas_data = load_gwas_data(sumstats_file)

    common_snps = weight_adata.obs_names[weight_adata.obs_names.isin(gwas_data["SNP"])]
    gwas_subset = gwas_data.set_index("SNP").loc[common_snps].reset_index()

    gwas_subset = gwas_subset.drop(columns=[c for c in ["CHR", "BP"] if c in gwas_subset.columns])
    gwas_subset = gwas_subset.set_index("SNP").join(weight_adata.obs[["CHR", "BP"]]).reset_index()

    snps2plot_ids = filter_snps(gwas_subset.sort_values("P"), SUBSAMPLE_SNP_NUMBER=50000)
    gwas_plot_data = gwas_subset[gwas_subset["SNP"].isin(snps2plot_ids)].copy()

    gwas_plot_data["CHR"] = pd.to_numeric(gwas_plot_data["CHR"], errors='coerce')
    gwas_plot_data["BP"] = pd.to_numeric(gwas_plot_data["BP"], errors='coerce')
    gwas_plot_data = gwas_plot_data.dropna(subset=["CHR", "BP"])

    # Gene assignment
    target_genes = [g for g in weight_adata.var_names if g in gene_names_ref and g != "unmapped"]
    if target_genes:
        sub_weight = weight_adata[gwas_plot_data["SNP"], target_genes].to_memory()
        weights_matrix = sub_weight.X

        if sp.issparse(weights_matrix):
            max_idx = np.array(weights_matrix.argmax(axis=1)).ravel()
            max_val = np.array(weights_matrix.max(axis=1).toarray()).ravel()
        else:
            max_idx = np.argmax(weights_matrix, axis=1)
            max_val = np.max(weights_matrix, axis=1)

        gene_map = np.where(max_val > 1, np.array(target_genes)[max_idx], "None")
        gwas_plot_data["GENE"] = gene_map
|
|
1230
|
+
|
|
1231
|
+
if trait_pcc_df is not None:
|
|
1232
|
+
top_n = getattr(report_config, 'top_corr_genes', 50)
|
|
1233
|
+
trait_top_genes = trait_pcc_df.sort_values('PCC', ascending=False).head(top_n)['gene'].tolist()
|
|
1234
|
+
gwas_plot_data["is_top_pcc"] = gwas_plot_data["GENE"].isin(trait_top_genes)
|
|
1235
|
+
else:
|
|
1236
|
+
gwas_plot_data["is_top_pcc"] = False
|
|
1237
|
+
|
|
1238
|
+
if "GENE" not in gwas_plot_data.columns:
|
|
1239
|
+
gwas_plot_data["GENE"] = "None"
|
|
1240
|
+
if "is_top_pcc" not in gwas_plot_data.columns:
|
|
1241
|
+
gwas_plot_data["is_top_pcc"] = False
|
|
1242
|
+
|
|
1243
|
+
# Calculate cumulative positions
|
|
1244
|
+
chrom_ticks = {}
|
|
1245
|
+
try:
|
|
1246
|
+
mp_helper = _ManhattanPlot(gwas_plot_data)
|
|
1247
|
+
gwas_plot_data["BP_cum"] = mp_helper.data["POSITION"].values
|
|
1248
|
+
gwas_plot_data["CHR_INDEX"] = mp_helper.data["INDEX"].values
|
|
1249
|
+
|
|
1250
|
+
chrom_groups = gwas_plot_data.groupby("CHR")["BP_cum"]
|
|
1251
|
+
chrom_ticks = {int(chrom): float(positions.median()) for chrom, positions in chrom_groups}
|
|
1252
|
+
except Exception as e:
|
|
1253
|
+
logger.warning(f"Failed to calculate Manhattan coordinates: {e}")
|
|
1254
|
+
gwas_plot_data["BP_cum"] = np.arange(len(gwas_plot_data))
|
|
1255
|
+
gwas_plot_data["CHR_INDEX"] = gwas_plot_data["CHR"] % 2
|
|
1256
|
+
|
|
1257
|
+
gwas_plot_data.to_csv(manhattan_dir / f"{trait}_manhattan.csv", index=False)
|
|
1258
|
+
return chrom_ticks
|
|
1259
|
+
|
|
1260
|
+
|
|
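# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): the SNP-to-gene assignment
# above is a row-wise argmax over a SNP-by-gene weight matrix. A minimal
# self-contained version with a scipy CSR matrix (names hypothetical):
def _demo_snp_gene_argmax():
    import numpy as np
    import scipy.sparse as sp

    genes = np.array(["GeneA", "GeneB", "GeneC"])
    weights = sp.csr_matrix(np.array([
        [0.0, 2.5, 0.3],   # SNP 1 -> GeneB (max weight 2.5 > 1)
        [0.2, 0.0, 0.9],   # SNP 2 -> unassigned (max weight <= 1)
    ]))
    max_idx = np.asarray(weights.argmax(axis=1)).ravel()
    max_val = np.asarray(weights.max(axis=1).toarray()).ravel()
    return np.where(max_val > 1, genes[max_idx], "None")  # ['GeneB', 'None']
# ---------------------------------------------------------------------------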

# =============================================================================
# Plot Rendering Functions
# =============================================================================


def _render_gene_diagnostic_plots_refactored(
    report_config: QuickModeConfig,
    metadata: pd.DataFrame,
    common_spots: np.ndarray,
    gss_adata: ad.AnnData,
    n_rows: int,
    n_cols: int,
    web_report_dir: Path,
    report_data_dir: Path,
    is_step_complete_func
):
    """Render gene expression and GSS diagnostic plots, skipping plots that already exist."""
    gene_plot_dir = web_report_dir / "gene_diagnostic_plots"
    gene_plot_dir.mkdir(exist_ok=True)

    if not report_config.sample_h5ad_dict:
        logger.warning("Skipping gene diagnostic plots: missing h5ad dict")
        return

    top_n = report_config.top_corr_genes

    # 1. Identify the genes to plot from the per-trait PCC files.
    gss_stats_dir = report_data_dir / "gss_stats"
    trait_top_genes = {}
    all_top_genes_set = set()

    # Derive the trait list from the correlation files on disk.
    traits = [p.name.replace("gene_trait_correlation_", "").replace(".csv", "")
              for p in gss_stats_dir.glob("gene_trait_correlation_*.csv")]

    for trait in traits:
        pcc_path = report_config.get_gene_diagnostic_info_save_path(trait)
        if pcc_path.exists():
            group = pd.read_csv(pcc_path)
            genes = group.sort_values('PCC', ascending=False).head(top_n)['gene'].tolist()
            trait_top_genes[trait] = genes
            all_top_genes_set.update(genes)

    sample_names_sorted = list(report_config.sample_h5ad_dict.keys())

    # Pre-filter metadata for common spots once.
    metadata_common = metadata[metadata['spot'].isin(common_spots)]

    logger.info("Checking completeness for diagnostic plots...")

    # Skip (sample, trait, gene) combinations whose plots already exist.
    tasks_to_run = []

    for sample_name in sample_names_sorted:
        safe_sample = "".join(c if c.isalnum() else "_" for c in sample_name)
        for trait in traits:
            if trait not in trait_top_genes:
                continue
            for gene in trait_top_genes[trait]:
                exp_path = gene_plot_dir / f"gene_{trait}_{gene}_{safe_sample}_exp.png"
                gss_path = gene_plot_dir / f"gene_{trait}_{gene}_{safe_sample}_gss.png"
                if not is_step_complete_func([exp_path, gss_path]):
                    tasks_to_run.append((sample_name, trait, gene))

    if not tasks_to_run:
        logger.info("All gene diagnostic plots already exist. Skipping.")
        return

    logger.info(f"Rendering {len(tasks_to_run)} diagnostic plots...")

    # Process sample by sample to save memory.
    all_futures = []
    futures_lock = threading.Lock()
    max_workers = 20

    # Use a semaphore to bound the number of pending tasks in the executor queue.
    max_pending_tasks = max_workers * 4  # e.g. 80 pending tasks
    task_semaphore = threading.Semaphore(max_pending_tasks)

    max_loading_threads = min(4, len(sample_names_sorted))

    # Group tasks by sample for efficient processing.
    tasks_by_sample = {}
    for sample, trait, gene in tasks_to_run:
        tasks_by_sample.setdefault(sample, []).append((trait, gene))

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        def process_sample(sample_name):
            if sample_name not in tasks_by_sample:
                return

            h5ad_path = report_config.sample_h5ad_dict[sample_name]
            sample_metadata = metadata_common[metadata_common['sample_name'] == sample_name]
            if sample_metadata.empty:
                return

            sample_spots = sample_metadata['spot'].values
            coords = sample_metadata[['sx', 'sy']].values

            try:
                adata_rep = ad.read_h5ad(h5ad_path)
                suffix = f"|{sample_name}"
                if not str(adata_rep.obs_names[0]).endswith(suffix):
                    adata_rep.obs_names = adata_rep.obs_names.astype(str) + suffix

                # Already log1p-transformed data is used as-is; otherwise
                # detect whether the configured layer holds raw counts.
                if 'log1p' in adata_rep.uns and adata_rep.X is not None:
                    is_count = False
                else:
                    is_count, _ = setup_data_layer(adata_rep, report_config.data_layer, verbose=False)

                if is_count:
                    sc.pp.normalize_total(adata_rep, target_sum=1e4)
                    sc.pp.log1p(adata_rep)

                sample_genes = [g for _trait, g in tasks_by_sample[sample_name]]
                unique_sample_genes = list(set(sample_genes))
                available_genes = [g for g in unique_sample_genes if g in adata_rep.var_names]

                if not available_genes:
                    del adata_rep
                    return

                # Extract the needed slices to memory, then free the large
                # AnnData object to save RAM.
                adata_sample_exp = adata_rep[sample_spots, available_genes].copy()
                adata_sample_gss = gss_adata[sample_spots, available_genes].to_memory()

                del adata_rep
                gc.collect()

            except Exception as e:
                logger.error(f"Failed to load data for {sample_name}: {e}")
                return

            safe_sample = "".join(c if c.isalnum() else "_" for c in sample_name)

            for trait, gene in tasks_by_sample[sample_name]:
                if gene not in available_genes:
                    continue

                exp_x = adata_sample_exp[:, gene].X
                gss_x = adata_sample_gss[:, gene].X
                exp_vals = np.ravel(exp_x.toarray() if sp.issparse(exp_x) else exp_x).astype(np.float32)
                gss_vals = np.ravel(gss_x.toarray() if sp.issparse(gss_x) else gss_x).astype(np.float32)

                # Submit the expression plot task with semaphore throttling.
                task_semaphore.acquire()
                try:
                    f_exp = executor.submit(_render_single_sample_gene_plot_task, {
                        'gene': gene, 'sample_name': sample_name, 'trait': trait,
                        'plot_type': 'exp', 'coords': coords.copy(), 'values': exp_vals,
                        'output_path': gene_plot_dir / f"gene_{trait}_{gene}_{safe_sample}_exp.png",
                        'plot_origin': report_config.plot_origin, 'fig_width': 6.0, 'dpi': 150,
                    })
                    f_exp.add_done_callback(lambda _: task_semaphore.release())
                    with futures_lock:
                        all_futures.append(f_exp)
                except Exception:
                    task_semaphore.release()
                    raise

                # Submit the GSS plot task with semaphore throttling.
                task_semaphore.acquire()
                try:
                    f_gss = executor.submit(_render_single_sample_gene_plot_task, {
                        'gene': gene, 'sample_name': sample_name, 'trait': trait,
                        'plot_type': 'gss', 'coords': coords.copy(), 'values': gss_vals,
                        'output_path': gene_plot_dir / f"gene_{trait}_{gene}_{safe_sample}_gss.png",
                        'plot_origin': report_config.plot_origin, 'fig_width': 6.0, 'dpi': 150,
                    })
                    f_gss.add_done_callback(lambda _: task_semaphore.release())
                    with futures_lock:
                        all_futures.append(f_gss)
                except Exception:
                    task_semaphore.release()
                    raise

            # Clean up the per-sample slices.
            del adata_sample_exp, adata_sample_gss
            gc.collect()

        with ThreadPoolExecutor(max_workers=max_loading_threads) as loader:
            # Consume the iterator so exceptions raised inside process_sample
            # propagate instead of being silently discarded.
            list(loader.map(process_sample, tasks_by_sample.keys()))

        # Wait for all plot tasks.
        logger.info(f"Waiting for {len(all_futures)} diagnostic plot tasks to complete...")
        for future in as_completed(all_futures):
            future.result()

    logger.info("Successfully finished rendering diagnostic plots.")

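# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): the semaphore pattern above
# bounds how many tasks may sit in the executor queue, so a fast producer
# cannot pin large pickled plot payloads in memory. Minimal standalone form
# (the worker must be picklable, i.e. a module-level function):
def _demo_bounded_submit(jobs, worker, max_workers=4):
    import threading
    from concurrent.futures import ProcessPoolExecutor, as_completed

    gate = threading.Semaphore(max_workers * 4)  # bound on queued tasks
    futures = []
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        for job in jobs:
            gate.acquire()                       # block while the queue is full
            fut = pool.submit(worker, job)
            fut.add_done_callback(lambda _f: gate.release())
            futures.append(fut)
        for fut in as_completed(futures):
            fut.result()                         # re-raise worker exceptions
# ---------------------------------------------------------------------------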

# =============================================================================
# Results Collection Functions
# =============================================================================

def _collect_cauchy_results(
    report_config: QuickModeConfig,
    traits: list[str],
    report_dir: Path
):
    """Collect aggregated and per-sample Cauchy combination results into one CSV."""
    logger.info("Collecting Cauchy combination results...")
    all_cauchy = []

    for trait in traits:
        for annotation in report_config.annotation_list:
            # Load the aggregated (all-samples) result, then the per-sample one;
            # both share the same post-processing.
            for all_samples, result_type in ((True, 'aggregated'), (False, 'sample')):
                cauchy_file = report_config.get_cauchy_result_file(
                    trait, annotation=annotation, all_samples=all_samples
                )
                if not cauchy_file.exists():
                    continue
                try:
                    df = pd.read_csv(cauchy_file)
                    df['trait'] = trait
                    df['annotation_name'] = annotation
                    df['type'] = result_type
                    if all_samples and 'sample_name' not in df.columns:
                        df['sample_name'] = 'All Samples'

                    # Convert to -log10 scale for the report; clipping avoids
                    # -log10(0) = inf for underflowed p-values.
                    df['mlog10_p_cauchy'] = -np.log10(df['p_cauchy'].clip(lower=1e-300))
                    df['mlog10_p_median'] = -np.log10(df['p_median'].clip(lower=1e-300))

                    all_cauchy.append(df)
                except Exception as e:
                    logger.warning(f"Failed to load {result_type} Cauchy result {cauchy_file}: {e}")

    if all_cauchy:
        combined_cauchy = pd.concat(all_cauchy, ignore_index=True)
        if 'sample' in combined_cauchy.columns and 'sample_name' not in combined_cauchy.columns:
            combined_cauchy = combined_cauchy.rename(columns={'sample': 'sample_name'})

        cauchy_save_path = report_dir / "cauchy_results.csv"
        combined_cauchy.to_csv(cauchy_save_path, index=False)
        logger.info(f"Saved {len(combined_cauchy)} Cauchy results to {cauchy_save_path}")
    else:
        pd.DataFrame(
            columns=['trait', 'annotation_name', 'mlog10_p_cauchy', 'mlog10_p_median',
                     'top_95_quantile', 'type', 'sample_name']
        ).to_csv(report_dir / "cauchy_results.csv", index=False)
        logger.warning("No Cauchy results found to save.")

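# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): why the p-values are clipped
# before the -log10 transform above. Without the clip, a p-value of exactly
# zero maps to +inf, a non-finite value that would leak into the report's
# JSON payloads and Plotly axes.
def _demo_mlog10_clip():
    import numpy as np
    import pandas as pd

    p = pd.Series([0.05, 1e-320, 0.0])
    return -np.log10(p.clip(lower=1e-300))  # approx [1.30, 300.0, 300.0], all finite
# ---------------------------------------------------------------------------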

def _save_report_metadata(
    report_config: QuickModeConfig,
    traits: list[str],
    sample_names: list[str],
    report_dir: Path,
    chrom_tick_positions: dict | None = None,
    umap_info: dict | None = None,
    is_3d: bool = False,
    spatial_3d_path: str | None = None
):
    """Save report configuration metadata as JSON."""
    logger.info("Saving report configuration metadata...")
    report_meta = report_config.to_dict_with_paths_as_strings()

    report_meta['traits'] = traits
    report_meta['samples'] = sample_names
    report_meta['annotations'] = report_config.annotation_list
    report_meta['top_corr_genes'] = report_config.top_corr_genes
    report_meta['plot_origin'] = report_config.plot_origin
    report_meta['legend_marker_size'] = report_config.legend_marker_size
    report_meta['downsampling_n_spots_2d'] = getattr(report_config, 'downsampling_n_spots_2d', 250000)

    # Chromosome tick positions for the Manhattan plot.
    if chrom_tick_positions:
        report_meta['chrom_tick_positions'] = chrom_tick_positions

    # UMAP info.
    if umap_info:
        report_meta['umap_info'] = umap_info

    # 3D visualization info.
    report_meta['is_3d'] = is_3d
    report_meta['has_3d_widget'] = spatial_3d_path is not None
    report_meta['spatial_3d_widget_path'] = spatial_3d_path

    # Dataset type info for conditional section rendering; unwrap Enum values.
    dataset_type_value = report_config.dataset_type
    if hasattr(dataset_type_value, 'value'):
        dataset_type_value = dataset_type_value.value
    report_meta['dataset_type'] = dataset_type_value
    report_meta['dataset_type_label'] = {
        'spatial3D': 'Spatial 3D',
        'spatial2D': 'Spatial 2D',
        'scRNA': 'Single Cell RNA-seq'
    }.get(dataset_type_value, dataset_type_value)

    with open(report_dir / "report_meta.json", "w") as f:
        json.dump(report_meta, f)


def _copy_js_assets(report_dir: Path):
    """Copy bundled JS assets so the report works locally without network access."""
    logger.info("Copying JS assets for local usage...")
    js_lib_dir = report_dir / "js_lib"
    js_lib_dir.mkdir(exist_ok=True)

    # Copy the bundled Alpine.js and Tailwind assets from the static folder.
    static_js_dir = Path(__file__).parent / "static" / "js_lib"
    if static_js_dir.exists():
        for js_file in static_js_dir.glob("*.js"):
            dest = js_lib_dir / js_file.name
            if not dest.exists():
                shutil.copy2(js_file, dest)
                logger.info(f"Copied {js_file.name}")

    # Copy plotly.min.js from the installed plotly Python package.
    plotly_js_src = Path(plotly.__file__).parent / "package_data" / "plotly.min.js"
    plotly_dest = js_lib_dir / "plotly.min.js"
    if plotly_js_src.exists() and not plotly_dest.exists():
        shutil.copy2(plotly_js_src, plotly_dest)
        logger.info("Copied plotly.min.js from plotly package")


# =============================================================================
# JS Export Functions
# =============================================================================

def export_data_as_js_modules(data_dir: Path):
    """
    Convert the CSV data in the report directory into JavaScript modules
    (.js files) that assign the data to `window` global variables.
    This lets the report load data through <script> tags when opened
    locally, avoiding the CORS restrictions that block fetch() on file:// pages.
    """
    logger.info("Exporting data as JS modules...")
    js_data_dir = data_dir / "js_data"
    js_data_dir.mkdir(exist_ok=True)

    meta = {}
    meta_file = data_dir / "report_meta.json"
    if meta_file.exists():
        try:
            with open(meta_file) as f:
                meta = json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load report_meta.json: {e}")

    # Use per-sample export for efficient on-demand loading
    # (instead of a monolithic _export_metadata_js).
    _export_per_sample_spatial_js(data_dir, js_data_dir, meta)
    _export_cauchy_js(data_dir, js_data_dir)
    _export_manhattan_js(data_dir, js_data_dir)
    _export_umap_js(data_dir, js_data_dir, meta)
    _export_report_meta_js(data_dir, js_data_dir, meta)

    logger.info(f"JS modules exported to {js_data_dir}")

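# ---------------------------------------------------------------------------
# Expected layout (derived from the exporters below; shown for orientation):
#   js_data/
#     report_meta.js            -> window.GSMAP_REPORT_META
#     sample_index.js           -> window.GSMAP_SAMPLE_INDEX
#     sample_<name>_spatial.js  -> window.GSMAP_SAMPLE_<name>
#     cauchy_results.js         -> window.GSMAP_CAUCHY
#     manhattan_<trait>.js      -> window.GSMAP_MANHATTAN_<trait>
#     umap_data.js              -> window.GSMAP_UMAP
# Each file defines exactly one window global, loaded via <script> tags.
# ---------------------------------------------------------------------------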

def _export_per_sample_spatial_js(data_dir: Path, js_data_dir: Path, meta: dict):
    """
    Export spatial metadata as per-sample JS files for efficient on-demand loading.

    This creates:
    - sample_index.js: lightweight index mapping sample names to file info
    - sample_{name}_spatial.js: per-sample data with coordinates, traits, annotations

    This approach is far more efficient than loading all samples at once,
    especially for datasets with millions of spots.
    """
    metadata_file = data_dir / "spot_metadata.csv"
    if not metadata_file.exists():
        logger.warning("spot_metadata.csv not found, skipping per-sample spatial export")
        return

    logger.info("Exporting per-sample spatial data as JS modules...")
    df = pd.read_csv(metadata_file)

    samples = meta.get('samples', [])
    traits = meta.get('traits', [])
    annotations = meta.get('annotations', [])

    if not samples:
        # Fall back to the unique sample names found in the data.
        samples = df['sample_name'].unique().tolist() if 'sample_name' in df.columns else []

    sample_index = {}
    max_spots_2d = meta.get('downsampling_n_spots_2d', 250000)

    for sample_name in samples:
        sample_df = df[df['sample_name'] == sample_name]

        if len(sample_df) == 0:
            logger.warning(f"No data found for sample: {sample_name}")
            continue

        # Track the original spot count before downsampling.
        n_spots_original = len(sample_df)

        # Downsample if needed for 2D visualization performance.
        if len(sample_df) > max_spots_2d:
            logger.info(
                f"Downsampling {sample_name} from {len(sample_df):,} "
                f"to {max_spots_2d:,} spots for 2D visualization"
            )
            sample_df = sample_df.sample(n=max_spots_2d, random_state=42)

        # Estimate the Plotly marker size for this sample.
        _, point_size = estimate_plotly_point_size(sample_df[['sx', 'sy']], DEFAULT_PIXEL_WIDTH=560)

        # Build a columnar data structure for efficient ScatterGL rendering.
        data_struct = {
            'spot': sample_df['spot'].tolist(),
            'sx': sample_df['sx'].tolist(),
            'sy': sample_df['sy'].tolist(),
            'point_size': round(point_size, 2),
        }

        # Add 3D coordinates if present.
        for coord in ['3d_x', '3d_y', '3d_z']:
            if coord in sample_df.columns:
                data_struct[coord] = sample_df[coord].tolist()

        # Add all annotation columns.
        for anno in annotations:
            if anno in sample_df.columns:
                data_struct[anno] = sample_df[anno].tolist()

        # Add all trait columns, rounded to 1 decimal to reduce file size.
        for trait in traits:
            if trait in sample_df.columns:
                data_struct[trait] = [
                    round(v, 1) if pd.notna(v) else None
                    for v in sample_df[trait]
                ]

        # Create a filesystem- and identifier-safe name.
        safe_name = "".join(c if c.isalnum() else "_" for c in sample_name)
        var_name = f"GSMAP_SAMPLE_{safe_name}"
        file_name = f"sample_{safe_name}_spatial.js"

        # Write the per-sample JS file.
        js_content = f"window.{var_name} = {json.dumps(data_struct, separators=(',', ':'))};"
        with open(js_data_dir / file_name, "w", encoding='utf-8') as f:
            f.write(js_content)

        # Record the sample in the index.
        sample_index[sample_name] = {
            'var_name': var_name,
            'file': file_name,
            'n_spots': len(sample_df),
            'n_spots_original': n_spots_original,
            'point_size': round(point_size, 2)
        }

        if n_spots_original > len(sample_df):
            logger.info(
                f"  Exported {sample_name}: {len(sample_df):,} spots "
                f"(downsampled from {n_spots_original:,}) -> {file_name}"
            )
        else:
            logger.info(f"  Exported {sample_name}: {len(sample_df):,} spots -> {file_name}")

    # Export the lightweight sample index (loaded upfront by the report).
    js_content = f"window.GSMAP_SAMPLE_INDEX = {json.dumps(sample_index, separators=(',', ':'))};"
    with open(js_data_dir / "sample_index.js", "w", encoding='utf-8') as f:
        f.write(js_content)

    logger.info(f"Exported {len(sample_index)} per-sample spatial JS files + sample_index.js")

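# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): the fixed random_state=42
# above makes the downsampling reproducible, so regenerating the report
# yields identical sample_*.js files for unchanged inputs:
def _demo_deterministic_downsample():
    import pandas as pd

    df = pd.DataFrame({'spot': range(10)})
    a = df.sample(n=3, random_state=42)
    b = df.sample(n=3, random_state=42)
    return a.equals(b)  # True
# ---------------------------------------------------------------------------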

def _export_cauchy_js(data_dir: Path, js_data_dir: Path):
    """Export Cauchy results as a JS module."""
    cauchy_file = data_dir / "cauchy_results.csv"
    if cauchy_file.exists():
        df = pd.read_csv(cauchy_file)
        data_json = df.to_json(orient='records')
        js_content = f"window.GSMAP_CAUCHY = {data_json};"
        with open(js_data_dir / "cauchy_results.js", "w", encoding='utf-8') as f:
            f.write(js_content)


def _export_manhattan_js(data_dir: Path, js_data_dir: Path):
    """Export Manhattan data as JS modules (one per trait)."""
    manhattan_dir = data_dir / "manhattan_data"
    if not manhattan_dir.exists():
        return

    # The per-trait payload construction is identical to
    # _export_single_manhattan_js, so delegate rather than duplicating it.
    for csv_file in manhattan_dir.glob("*_manhattan.csv"):
        trait = csv_file.name.replace("_manhattan.csv", "")
        _export_single_manhattan_js(trait, csv_file, js_data_dir)


def _export_report_meta_js(data_dir: Path, js_data_dir: Path, meta: dict):
    """Export report metadata as a JS module."""
    if meta:
        js_content = f"window.GSMAP_REPORT_META = {json.dumps(meta, separators=(',', ':'))};"
        with open(js_data_dir / "report_meta.js", "w", encoding='utf-8') as f:
            f.write(js_content)


def _export_umap_js(data_dir: Path, js_data_dir: Path, meta: dict):
    """Export UMAP data as a JS module."""
    umap_file = data_dir / "umap_data.csv"
    if umap_file.exists():
        df = pd.read_csv(umap_file)

        # Build a columnar structure for efficient ScatterGL rendering.
        data_struct = {
            'spot': df['spot'].tolist(),
            'sample_name': df['sample_name'].tolist(),
            'umap_cell_x': df['umap_cell_x'].tolist(),
            'umap_cell_y': df['umap_cell_y'].tolist(),
            'point_size_cell': meta.get('umap_info', {}).get('point_size_cell', 4),
            'default_opacity': meta.get('umap_info', {}).get('default_opacity', 0.8)
        }

        # Add the niche-level UMAP if available.
        if 'umap_niche_x' in df.columns:
            data_struct['umap_niche_x'] = df['umap_niche_x'].tolist()
            data_struct['umap_niche_y'] = df['umap_niche_y'].tolist()
            data_struct['point_size_niche'] = meta.get('umap_info', {}).get('point_size_niche', 4)
            data_struct['has_niche'] = True
        else:
            data_struct['has_niche'] = False

        # Treat every remaining column as an annotation.
        known_cols = {'spot', 'sample_name', 'umap_cell_x', 'umap_cell_y', 'umap_niche_x', 'umap_niche_y'}
        annotation_cols = [c for c in df.columns if c not in known_cols]
        data_struct['annotation_columns'] = annotation_cols

        from gsMap.report.visualize import _create_color_map
        P_COLOR = ['#313695', '#4575b4', '#74add1', '#fee090', '#fdae61', '#f46d43', '#d73027', '#a50026']
        color_maps = {}

        for col in annotation_cols:
            data_struct[col] = df[col].tolist()
            # Numeric columns share the continuous palette; categorical
            # columns get a deterministic per-category map.
            if pd.api.types.is_numeric_dtype(df[col]):
                color_maps[col] = P_COLOR
            else:
                # Cast to string before sorting to avoid a TypeError on
                # mixed float/str values (e.g. NaN among labels).
                unique_vals = sorted(df[col].astype(str).unique().tolist())
                color_maps[col] = _create_color_map(unique_vals, hex=True, rng=42)

        data_struct['color_maps'] = color_maps

        js_content = f"window.GSMAP_UMAP = {json.dumps(data_struct, separators=(',', ':'))};"
        with open(js_data_dir / "umap_data.js", "w", encoding='utf-8') as f:
            f.write(js_content)
        logger.info(f"Exported UMAP data with {len(df)} points")
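
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): `_create_color_map` lives in
# gsMap.report.visualize, so its internals are not shown here. A hedged
# stand-in with the same apparent contract (a seeded, deterministic
# category -> hex color mapping) could look like:
def _demo_color_map(categories, rng=42):
    import colorsys
    import random

    shuffled = list(categories)
    random.Random(rng).shuffle(shuffled)  # seeded => deterministic palette
    n = max(len(shuffled), 1)
    colors = {}
    for i, cat in enumerate(shuffled):
        r, g, b = colorsys.hsv_to_rgb(i / n, 0.65, 0.9)  # evenly spaced hues
        colors[cat] = f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}"
    return colors
# ---------------------------------------------------------------------------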