gsmap3d-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1832 @@
1
+ """
2
+ Report Data Preparation Module
3
+
4
+ This module prepares data for the Alpine.js + Tailwind CSS + Plotly.js report.
5
+ All data is exported as JS files with window global variables to bypass CORS restrictions
6
+ when opening index.html via the file:// protocol.
7
+ """
8
+
9
+ import gc
10
+ import json
11
+ import logging
12
+ import shutil
13
+ import threading
14
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
15
+ from pathlib import Path
16
+
17
+ import anndata as ad
18
+ import numpy as np
19
+ import pandas as pd
20
+ import plotly
21
+ import scanpy as sc
22
+ import scipy.sparse as sp
23
+
24
+ from gsMap.config import QuickModeConfig
25
+ from gsMap.config.latent2gene_config import DatasetType
26
+ from gsMap.find_latent.st_process import setup_data_layer
27
+ from gsMap.report.diagnosis import filter_snps, load_gwas_data
28
+ from gsMap.report.three_d_plot.three_d_plots import three_d_plot, three_d_plot_save
29
+ from gsMap.report.visualize import (
30
+ estimate_plotly_point_size,
31
+ )
32
+ from gsMap.spatial_ldsc.io import load_marker_scores_memmap_format
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ class ReportDataManager:
38
+ def __init__(self, report_config: QuickModeConfig):
39
+ self.report_config = report_config
40
+
41
+ # Two separate directories: data files vs web report
42
+ self.report_data_dir = report_config.report_data_dir
43
+ self.web_report_dir = report_config.web_report_dir
44
+ self.js_data_dir = self.web_report_dir / "js_data"
45
+
46
+ # Create directories
47
+ self.report_data_dir.mkdir(parents=True, exist_ok=True)
48
+ self.web_report_dir.mkdir(parents=True, exist_ok=True)
49
+ self.js_data_dir.mkdir(exist_ok=True)
50
+
51
+ if self.report_config.sample_h5ad_dict is None:
52
+ self.report_config._resolve_h5ad_inputs()
53
+
54
+ self.force_re_run = getattr(report_config, 'force_report_re_run', False)
55
+
56
+ # Internal state
57
+ self.ldsc_results = None
58
+ self.traits = []
59
+ self.sample_names = []
60
+ self.coords = None
61
+ self.is_3d = False
62
+ self.z_axis = None
63
+ self.common_spots = None
64
+ self.analysis_spots = None
65
+ self.gene_stats = None
66
+ self.metadata = None
67
+ self.gss_adata = None
68
+ self.chrom_tick_positions = None
69
+ self.umap_info = None
70
+ self.spatial_3d_html = None
71
+
72
+ def close(self):
73
+ """Release resources and close file handles."""
74
+ if self.gss_adata is not None:
75
+ logger.info("Closing GSS AnnData and memory maps...")
76
+ if 'memmap_manager' in self.gss_adata.uns:
77
+ self.gss_adata.uns['memmap_manager'].close()
78
+ self.gss_adata = None
79
+ gc.collect()
80
+
81
+ def _is_step_complete(self, files: list[Path]) -> bool:
82
+ if self.force_re_run:
83
+ return False
84
+ return all(f.exists() for f in files)
85
+
86
+ def run(self):
87
+ """Orchestrate the report data preparation."""
88
+ logger.info("Starting report data preparation...")
89
+ try:
90
+ # 1. Base Metadata
91
+ self.prepare_base_metadata()
92
+
93
+ # 2. GSS Statistics (PCC)
94
+ self.prepare_gss_stats()
95
+
96
+ # 3. Spot Metadata (depends on GSS stats for gene_list.csv)
97
+ self.prepare_spot_metadata()
98
+
99
+ # 4. Manhattan Data
100
+ self.prepare_manhattan_data()
101
+
102
+ # 5. Static Plots
103
+ self.render_static_plots()
104
+
105
+ # 6. Cauchy Results
106
+ self.collect_cauchy_results()
107
+
108
+ # 7. UMAP Data
109
+ self.prepare_umap_data()
110
+
111
+ # 8. 3D Visualization
112
+ self.prepare_3d_visualization()
113
+
114
+ # 9. Finalize Metadata and JS Assets
115
+ self.finalize_report()
116
+
117
+ logger.info("Report data preparation complete.")
118
+ logger.info(f" Data files: {self.report_data_dir}")
119
+ logger.info(f" Web report: {self.web_report_dir}")
120
+ finally:
121
+ self.close()
122
+
123
+ return self.web_report_dir
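
A hypothetical driver for the class above, assuming a fully populated QuickModeConfig and that the report step has already placed the HTML template next to the exported data:

from gsMap.config import QuickModeConfig
from gsMap.report.report_data import ReportDataManager

def build_report(config: QuickModeConfig):
    manager = ReportDataManager(config)
    web_dir = manager.run()  # writes js_data/, plots, report_meta.json; closes memmaps in finally
    print(f"Open {web_dir}/index.html in a browser; file:// is enough, no server needed")
    return web_dir
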
124
+
125
+
126
+ def prepare_base_metadata(self):
127
+ """Load LDSC results and coordinates."""
128
+ # 1. Load LDSC results
129
+ if self.ldsc_results is None:
130
+ self.ldsc_results, self.traits, self.sample_names = _load_ldsc_results(self.report_config)
131
+
132
+ # 2. Load coordinates
133
+ if self.coords is None:
134
+ self.coords, self.is_3d, self.z_axis = _load_coordinates(self.report_config)
135
+
136
+ # Base metadata is not exported to JS here; the per-sample export is
137
+ # handled by _export_per_sample_spatial_js during finalize_report.
138
+
139
+ def prepare_gss_stats(self):
140
+ """Load GSS data, calculate PCC per trait, and split results."""
141
+ gss_dir = self.report_data_dir / "gss_stats"
142
+ gss_dir.mkdir(exist_ok=True)
143
+
144
+ gene_list_file = self.report_data_dir / "gene_list.csv"
145
+
146
+ # Determine which traits need PCC calculation
147
+ traits_to_run = []
148
+ for trait in self.traits:
149
+ csv_path = self.report_config.get_gene_diagnostic_info_save_path(trait)
150
+ js_path = self.js_data_dir / "gss_stats" / f"gene_trait_correlation_{trait}.js"
151
+ if not self._is_step_complete([csv_path, js_path]):
152
+ traits_to_run.append(trait)
153
+
154
+ if not traits_to_run and self._is_step_complete([gene_list_file]):
155
+ logger.info("GSS statistics (PCC) already complete for all traits. Skipping.")
156
+ return
157
+
158
+ logger.info(f"Processing GSS statistics for {len(traits_to_run)} traits...")
159
+
160
+ # Load GSS and common spots
161
+ if self.gss_adata is None:
162
+ self.common_spots, self.gss_adata, self.gene_stats, self.analysis_spots = _load_gss_and_calculate_stats_base(
163
+ self.report_config, self.ldsc_results, self.coords, self.report_data_dir
164
+ )
165
+
166
+ # Pre-filter to high expression genes and pre-calculate centered matrix
167
+ exp_frac = pd.read_parquet(self.report_config.mean_frac_path)
168
+ high_expr_genes = exp_frac[exp_frac['frac'] > 0.01].index.tolist()
169
+
170
+ # Use analysis_spots (subsampled) for PCC calculation
171
+ gss_adata_sub = self.gss_adata[self.analysis_spots, high_expr_genes]
172
+ gss_matrix = gss_adata_sub.X
173
+ if hasattr(gss_matrix, 'toarray'):
174
+ gss_matrix = gss_matrix.toarray()
175
+
176
+ gss_mean = gss_matrix.mean(axis=0).astype(np.float32)
177
+ gss_centered = (gss_matrix - gss_mean).astype(np.float32)
178
+ gss_ssq = np.sum(gss_centered ** 2, axis=0)
179
+ gene_names = gss_adata_sub.var_names.tolist()
180
+
181
+ # Calculate PCC for each missing trait; each result is written straight
182
+ # to its CSV and JS file, so no aggregate list is kept
183
+ for trait in self.traits:
184
+ # We check both the CSV in gss_stats and the JS in js_data/gss_stats
185
+ csv_path = self.report_config.get_gene_diagnostic_info_save_path(trait)
186
+ js_path = self.js_data_dir / "gss_stats" / f"gene_trait_correlation_{trait}.js"
187
+
188
+ if self._is_step_complete([csv_path, js_path]):
189
+ continue
190
+
191
+ trait_pcc = _calculate_pcc_for_single_trait_fast(
192
+ trait, self.ldsc_results, self.analysis_spots, gss_centered, gss_ssq, gene_names, self.gene_stats, self.report_config, self.report_data_dir, gss_dir
193
+ )
194
+ if trait_pcc is not None:
195
+ # Export to JS immediately after CSV creation
196
+ self._export_single_pcc_js(trait, trait_pcc)
198
+
199
+ def prepare_spot_metadata(self):
200
+ """Save spot metadata and coordinates."""
201
+ metadata_file = self.report_data_dir / "spot_metadata.csv"
202
+ if self._is_step_complete([metadata_file]):
203
+ logger.info("Spot metadata already exists. Skipping.")
204
+ self.metadata = pd.read_csv(metadata_file)
205
+ if self.common_spots is None:
206
+ self.common_spots = self.metadata['spot'].values
207
+ return
208
+
209
+ self.metadata = _save_metadata(self.ldsc_results, self.coords, self.report_data_dir)
210
+ if self.common_spots is None:
211
+ self.common_spots = self.metadata['spot'].values
212
+
213
+ def _export_single_pcc_js(self, trait, df):
214
+ """Export single trait PCC results to JS in gss_stats subfolder."""
215
+ gss_js_dir = self.js_data_dir / "gss_stats"
216
+ gss_js_dir.mkdir(exist_ok=True)
217
+
218
+ data_json = df.to_json(orient='records')
219
+ var_name = f"GSMAP_GENE_TRAIT_CORRELATION_{"".join(c if c.isalnum() else "_" for c in trait)}"
220
+ js_content = f"window.{var_name} = {data_json};"
221
+ with open(gss_js_dir / f"gene_trait_correlation_{trait}.js", "w", encoding='utf-8') as f:
222
+ f.write(js_content)
+
223
+ def prepare_manhattan_data(self):
224
+ """Prepare Manhattan data for all traits."""
225
+ manhattan_dir = self.report_data_dir / "manhattan_data"
226
+ manhattan_dir.mkdir(exist_ok=True)
227
+
228
+ # Determine which traits need Manhattan data
229
+ traits_to_run = []
230
+ for trait in self.traits:
231
+ csv_path = manhattan_dir / f"{trait}_manhattan.csv"
232
+ js_path = self.js_data_dir / f"manhattan_{trait}.js"
233
+ if not self._is_step_complete([csv_path, js_path]):
234
+ traits_to_run.append(trait)
235
+
236
+ if not traits_to_run:
237
+ logger.info("Manhattan data already complete for all traits. Skipping.")
238
+ return
239
+
240
+ logger.info(f"Processing Manhattan data for {len(traits_to_run)} traits...")
241
+
242
+ # Load weights
243
+ logger.info(f"Loading weights from {self.report_config.snp_gene_weight_adata_path}")
244
+ weight_adata = ad.read_h5ad(self.report_config.snp_gene_weight_adata_path)
245
+
246
+ # Load gene ref
247
+ genes_file = self.report_data_dir / "gene_list.csv"
248
+ gene_names_ref = pd.read_csv(genes_file)['gene'].tolist() if genes_file.exists() else []
249
+
250
+ chrom_tick_positions = {}
251
+ for trait in traits_to_run:
252
+ try:
253
+ # Load trait-specific PCC data
254
+ trait_pcc_file = self.report_config.get_gene_diagnostic_info_save_path(trait)
255
+ trait_pcc_df = pd.read_csv(trait_pcc_file) if trait_pcc_file.exists() else None
256
+
257
+ chrom_ticks = _prepare_manhattan_for_trait(
258
+ self.report_config, trait, weight_adata, trait_pcc_df, gene_names_ref, manhattan_dir
259
+ )
260
+ chrom_tick_positions[trait] = chrom_ticks
261
+ # Export to JS
262
+ _export_single_manhattan_js(trait, manhattan_dir / f"{trait}_manhattan.csv", self.js_data_dir)
263
+ except Exception as e:
264
+ logger.warning(f"Failed to prepare Manhattan data for {trait}: {e}")
265
+
266
+ self.chrom_tick_positions = chrom_tick_positions
267
+
268
+ def render_static_plots(self):
269
+ """Render static plots (LDSC, Annotation, Gene Diagnostics)."""
270
+ dataset_type = getattr(self.report_config, 'dataset_type', DatasetType.SPATIAL_2D)
271
+ is_spatial = dataset_type in (DatasetType.SPATIAL_2D, DatasetType.SPATIAL_3D)
272
+
273
+ if not is_spatial:
274
+ logger.info("Skipping static plot rendering for non-spatial dataset type")
275
+ return
276
+
277
+ from .visualize import VisualizeRunner
278
+ visualizer = VisualizeRunner(self.report_config)
279
+
280
+ # Ensure prerequisites when this step is invoked outside run()
+ if self.ldsc_results is None or self.coords is None:
+ self.prepare_base_metadata()
+ if self.metadata is None:
+ self.prepare_spot_metadata()
+ n_samples = len(self.sample_names)
281
+ n_cols = max(1, min(4, n_samples))
282
+ n_rows = (n_samples + n_cols - 1) // n_cols
283
+
284
+ obs_data = self.metadata.copy()
285
+ if 'sample' not in obs_data.columns:
286
+ obs_data['sample'] = obs_data['sample_name']
287
+
288
+ if self.report_config.generate_multi_sample_plots:
289
+ # Render LDSC plots
290
+ spatial_plot_dir = self.web_report_dir / "spatial_plots"
291
+ spatial_plot_dir.mkdir(exist_ok=True)
292
+ for trait in self.traits:
293
+ trait_plot_path = spatial_plot_dir / f"ldsc_{trait}.png"
294
+ if not self._is_step_complete([trait_plot_path]):
295
+ logger.info(f"Rendering LDSC plot for {trait}...")
296
+ visualizer._create_single_trait_multi_sample_matplotlib_plot(
297
+ obs_ldsc_merged=obs_data,
298
+ trait_abbreviation=trait,
299
+ sample_name_list=self.sample_names,
300
+ output_png_path=trait_plot_path,
301
+ n_rows=n_rows, n_cols=n_cols,
302
+ subplot_width_inches=5.0
303
+ )
304
+
305
+ # Render annotation plots
306
+ anno_dir = self.web_report_dir / "annotation_plots"
307
+ anno_dir.mkdir(exist_ok=True)
308
+ for anno in self.report_config.annotation_list:
309
+ anno_plot_path = anno_dir / f"anno_{anno}.png"
310
+ if not self._is_step_complete([anno_plot_path]):
311
+ logger.info(f"Rendering plot for annotation: {anno}...")
312
+ fig = visualizer._create_multi_sample_annotation_plot(
313
+ obs_ldsc_merged=obs_data,
314
+ annotation=anno,
315
+ sample_names_list=self.sample_names,
316
+ output_dir=None,
317
+ n_rows=n_rows, n_cols=n_cols,
318
+ fig_width=5 * n_cols, fig_height=5 * n_rows
319
+ )
320
+ import matplotlib.pyplot as plt
321
+ fig.savefig(anno_plot_path, dpi=300, bbox_inches='tight')
322
+ plt.close(fig)
323
+
324
+ # Gene diagnostic plots
325
+ # Ensure base metadata exists (needed for common_spots and ldsc_results)
326
+ if self.ldsc_results is None or self.coords is None:
327
+ self.prepare_base_metadata()
328
+
329
+ if self.metadata is None:
330
+ self.prepare_spot_metadata()
331
+
332
+ # Ensure GSS data is loaded if we need to plot
333
+ if self.gss_adata is None:
334
+ self.common_spots, self.gss_adata, self.gene_stats, self.analysis_spots = _load_gss_and_calculate_stats_base(
335
+ self.report_config, self.ldsc_results, self.coords, self.report_data_dir
336
+ )
337
+
338
+ _render_gene_diagnostic_plots_refactored(self.report_config, self.metadata, self.common_spots, self.gss_adata, n_rows, n_cols, self.web_report_dir, self.report_data_dir, self._is_step_complete)
339
+
340
+ def collect_cauchy_results(self):
341
+ """Collect and save Cauchy combination results."""
342
+ cauchy_file = self.report_data_dir / "cauchy_results.csv"
343
+ cauchy_js = self.js_data_dir / "cauchy_results.js"
344
+ # Since this combines all traits/annotations, we check for the final csv
345
+ if self._is_step_complete([cauchy_file, cauchy_js]):
346
+ logger.info("Cauchy results already collected. Skipping.")
347
+ return
348
+
349
+ _collect_cauchy_results(self.report_config, self.traits, self.report_data_dir)
350
+ _export_cauchy_js(self.report_data_dir, self.js_data_dir)
351
+
352
+ def prepare_umap_data(self):
353
+ """Prepare UMAP data from embeddings."""
354
+ umap_file = self.report_data_dir / "umap_data.csv"
355
+ umap_js = self.js_data_dir / "umap_data.js"
356
+ if self._is_step_complete([umap_file, umap_js]):
357
+ logger.info("UMAP data already exists. Skipping.")
358
+ return
359
+
360
+ self.umap_info = _prepare_umap_data(self.report_config, self.metadata, self.report_data_dir)
361
+ _export_umap_js(self.report_data_dir, self.js_data_dir, {'umap_info': self.umap_info})
362
+
363
+ def prepare_3d_visualization(self):
364
+ """Prepare 3D visualization if applicable."""
365
+ if not self.is_3d:
366
+ self.spatial_3d_html = None
367
+ return
368
+
369
+ self.spatial_3d_html = _prepare_3d_visualization(self.report_config, self.metadata, self.traits, self.report_data_dir, self.web_report_dir)
370
+
371
+ def finalize_report(self):
372
+ """Save report metadata and final JS assets."""
373
+ # 1. Save standard report metadata
374
+ _save_report_metadata(
375
+ self.report_config, self.traits, self.sample_names, self.web_report_dir,
376
+ getattr(self, 'chrom_tick_positions', None),
377
+ getattr(self, 'umap_info', None),
378
+ is_3d=self.is_3d,
379
+ spatial_3d_path=getattr(self, 'spatial_3d_html', None)
380
+ )
381
+
382
+ # 2. Copy JS libraries
383
+ _copy_js_assets(self.web_report_dir)
384
+
385
+ # 3. Export per-sample spatial JS (highly important for performance)
386
+ with open(self.web_report_dir / "report_meta.json") as f:
387
+ meta = json.load(f)
388
+ _export_per_sample_spatial_js(self.report_data_dir, self.js_data_dir, meta)
389
+ _export_report_meta_js(self.web_report_dir, self.js_data_dir, meta)
390
+
391
+
392
+
393
+ def _render_single_sample_gene_plot_task(task_data: dict):
394
+ """
395
+ Parallel task to render single-sample Expression or GSS plot.
396
+ Creates a single plot for one gene in one sample.
397
+ """
398
+ try:
399
+ import gc
400
+ from pathlib import Path
401
+
402
+ import matplotlib
403
+ import matplotlib.colors as mcolors
404
+ import matplotlib.pyplot as plt
405
+ import numpy as np
406
+
407
+ matplotlib.use('Agg')
408
+
409
+ gene = task_data['gene']
410
+ sample_name = task_data['sample_name']
411
+ plot_type = task_data['plot_type']
412
+ coords = task_data['coords']
413
+ values = task_data['values']
414
+ output_path = Path(task_data['output_path'])
415
+ fig_width = task_data.get('fig_width', 6.0)
416
+ dpi = task_data.get('dpi', 150)
417
+
418
+ if coords is None or values is None or len(coords) == 0:
419
+ return f"No data for {gene} in {sample_name}"
420
+
421
+ # Custom colormap
422
+ custom_colors = [
423
+ '#313695', '#4575b4', '#74add1', '#abd9e9', '#e0f3f8',
424
+ '#fee090', '#fdae61', '#f46d43', '#d73027'
425
+ ]
426
+ custom_cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap', custom_colors)
427
+
428
+ fig, ax = plt.subplots(figsize=(fig_width, fig_width))
429
+
430
+ # Calculate point size based on data density
431
+ from gsMap.report.visualize import estimate_matplotlib_scatter_marker_size
432
+ point_size = estimate_matplotlib_scatter_marker_size(ax, coords)
433
+ point_size = min(max(point_size, 1), 200)
434
+
435
+ # Color scale
436
+ vmin = 0
437
+ with np.errstate(all='ignore'):
438
+ non_nan_values = values[np.isfinite(values)] if len(values) > 0 else np.array([])
439
+ vmax = np.percentile(non_nan_values, 99.5) if len(non_nan_values) > 0 else 1.0
440
+
441
+ if not np.isfinite(vmax) or vmax <= vmin:
442
+ vmax = vmin + 1.0
443
+
444
+ scatter = ax.scatter(
445
+ coords[:, 0], coords[:, 1],
446
+ c=values, cmap=custom_cmap,
447
+ s=point_size, vmin=vmin, vmax=vmax,
448
+ marker='o', edgecolors='none', rasterized=True
449
+ )
450
+
451
+ if task_data.get('plot_origin', 'upper') == 'upper':
452
+ ax.invert_yaxis()
453
+
454
+ ax.axis('off')
455
+ ax.set_aspect('equal')
456
+
457
+ # Add colorbar
458
+ cbar = fig.colorbar(scatter, ax=ax, fraction=0.046, pad=0.04)
459
+ cbar.set_label('Expression' if plot_type == 'exp' else 'GSS', fontsize=10)
460
+
461
+ # Title
462
+ title_text = f"{gene} - {'Expression' if plot_type == 'exp' else 'GSS'}"
463
+ ax.set_title(title_text, fontsize=12, fontweight='bold')
464
+
465
+ output_path.parent.mkdir(parents=True, exist_ok=True)
466
+ plt.savefig(output_path, dpi=dpi, bbox_inches='tight', facecolor='white')
467
+ plt.close(fig)
468
+ gc.collect()
469
+
470
+ return True
471
+ except Exception as e:
472
+ import traceback
473
+ return f"{str(e)}\n{traceback.format_exc()}"
474
+
475
+
476
+ # =============================================================================
477
+ # UMAP Calculation Functions
478
+ # =============================================================================
479
+
480
+ def _detect_z_axis(coords_3d: np.ndarray, sample_names: pd.Series) -> int:
481
+ """
482
+ Detect which dimension is the Z axis (stacking dimension) in 3D coordinates.
483
+
484
+ The Z axis is identified as the dimension with the least within-sample variance,
485
+ since samples are stacked along Z and should have minimal variation in that dimension.
486
+
487
+ Args:
488
+ coords_3d: 3D coordinates array (n_spots, 3)
489
+ sample_names: Series of sample names for each spot
490
+
491
+ Returns:
492
+ Index of the Z axis (0, 1, or 2)
493
+ """
494
+ # Calculate within-sample variance for each dimension
495
+ within_sample_vars = []
496
+
497
+ for dim in range(3):
498
+ dim_values = coords_3d[:, dim]
499
+
500
+ # Calculate variance within each sample
501
+ sample_vars = []
502
+ for sample in sample_names.unique():
503
+ mask = sample_names == sample
504
+ sample_values = dim_values[mask]
505
+ if len(sample_values) > 1:
506
+ sample_vars.append(np.var(sample_values))
507
+
508
+ # Average within-sample variance for this dimension
509
+ avg_within_var = np.mean(sample_vars) if sample_vars else 0
510
+ within_sample_vars.append(avg_within_var)
511
+
512
+ # Z axis has the least within-sample variance
513
+ z_axis = int(np.argmin(within_sample_vars))
514
+ logger.info(f"Detected Z axis: dimension {z_axis} (within-sample variances: {within_sample_vars})")
515
+
516
+ return z_axis
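
A toy check of the variance heuristic, assuming the function above is importable (coordinates invented): two slices that are flat along the third dimension have zero within-sample variance there, so that dimension is chosen as Z:

import numpy as np
import pandas as pd

coords = np.array([
    [0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [2.0, 1.0, 0.0],  # sample s1 at z = 0
    [0.0, 0.0, 5.0], [1.0, 2.0, 5.0], [2.0, 1.0, 5.0],  # sample s2 at z = 5
])
samples = pd.Series(["s1", "s1", "s1", "s2", "s2", "s2"])
assert _detect_z_axis(coords, samples) == 2  # dims 0 and 1 vary within each slice
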
517
+
518
+
519
+ def _reorder_coords_for_3d(coords_3d: np.ndarray, z_axis: int) -> tuple[np.ndarray, list[str]]:
520
+ """
521
+ Reorder 3D coordinates so that Z axis is the last dimension.
522
+
523
+ Returns:
524
+ Tuple of (reordered coords, column names)
525
+ """
526
+ # Create axis order: put z_axis last
527
+ axis_order = [i for i in range(3) if i != z_axis] + [z_axis]
528
+ reordered = coords_3d[:, axis_order]
529
+
530
+ # Labels are fixed: after reordering, the stacking axis is always '3d_z'
531
+ col_names = ['3d_x', '3d_y', '3d_z']
532
+
533
+ return reordered, col_names
534
+
535
+ def _stratified_subsample(
536
+ spot_names: np.ndarray,
537
+ sample_names: pd.Series,
538
+ n_samples: int,
539
+ random_state: int = 42
540
+ ) -> np.ndarray:
541
+ """
542
+ Stratified subsampling that ensures representation from all samples.
543
+
544
+ Args:
545
+ spot_names: Array of spot identifiers
546
+ sample_names: Series mapping spot names to sample names
547
+ n_samples: Target number of spots to keep after subsampling
548
+ random_state: Random seed for reproducibility
549
+
550
+ Returns:
551
+ Array of selected spot names
552
+ """
553
+ np.random.seed(random_state)
554
+
555
+ # Get sample counts
556
+ sample_counts = sample_names.value_counts()
557
+ n_samples_total = len(spot_names)
558
+
559
+ if n_samples >= n_samples_total:
560
+ return spot_names
561
+
562
+ # Calculate proportional samples per group
563
+ selected_spots = []
564
+ for sample_name, count in sample_counts.items():
565
+ sample_spots = spot_names[sample_names == sample_name]
566
+ # Proportional allocation
567
+ n_select = max(1, int(np.ceil(n_samples * count / n_samples_total)))
568
+ n_select = min(n_select, len(sample_spots))
569
+
570
+ selected = np.random.choice(sample_spots, n_select, replace=False)
571
+ selected_spots.extend(selected)
572
+
573
+ # If we have too many, randomly remove some
574
+ selected_spots = np.array(selected_spots)
575
+ if len(selected_spots) > n_samples:
576
+ selected_spots = np.random.choice(selected_spots, n_samples, replace=False)
577
+
578
+ return selected_spots
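
A small usage sketch with invented spot names, assuming the function above is importable: with a 90/10 split and a target of 10, proportional allocation keeps about 9 spots from the large sample and at least 1 from the small one:

import numpy as np
import pandas as pd

spots = np.array([f"spot{i}" for i in range(100)])
samples = pd.Series(["A"] * 90 + ["B"] * 10, index=spots)

picked = _stratified_subsample(spots, samples, n_samples=10)
assert len(picked) == 10
assert (samples.loc[picked] == "B").sum() >= 1  # minority sample stays represented
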
579
+
580
+
581
+ def _calculate_umap_from_embeddings(
582
+ adata: ad.AnnData,
583
+ embedding_key: str,
584
+ ) -> np.ndarray:
585
+ logger.info(f"Calculating UMAP for {embedding_key} with {adata.n_obs} spots using scanpy...")
586
+
587
+ sc.pp.neighbors(adata, use_rep=embedding_key)
588
+ sc.tl.umap(adata)
589
+
590
+ return adata.obsm['X_umap']
591
+
592
+
593
+ def _prepare_umap_data(
594
+ report_config: QuickModeConfig,
595
+ metadata: pd.DataFrame,
596
+ report_dir: Path
597
+ ) -> dict | None:
598
+ """
599
+ Prepare UMAP data from cell and niche embeddings.
600
+
601
+ Returns dict with umap_cell, umap_niche (optional), and metadata for visualization.
602
+ """
603
+ logger.info("Preparing UMAP data from embeddings...")
604
+
605
+ # Load concatenated adata
606
+ adata_path = report_config.concatenated_latent_adata_path
607
+ if not adata_path.exists():
608
+ logger.warning(f"Concatenated adata not found at {adata_path}")
609
+ return None
610
+
611
+ adata = ad.read_h5ad(adata_path)
612
+
613
+ # Get embedding keys from config
614
+ cell_emb_key = getattr(report_config, 'latent_representation_cell', 'emb_cell')
615
+ niche_emb_key = getattr(report_config, 'latent_representation_niche', None)
616
+
617
+ # Check if embeddings exist
618
+ if cell_emb_key not in adata.obsm:
619
+ logger.warning(f"Cell embedding '{cell_emb_key}' not found in adata.obsm")
620
+ return None
621
+
622
+ has_niche = niche_emb_key is not None and niche_emb_key in adata.obsm
623
+
624
+ # Stratified subsampling
625
+ n_subsample = getattr(report_config, 'downsampling_n_spots', 20000)
626
+ spot_names = adata.obs_names.values
627
+ sample_names = adata.obs['sample_name']
628
+
629
+ if len(spot_names) > n_subsample:
630
+ logger.info(f"Stratified subsampling from {len(spot_names)} to {n_subsample} spots...")
631
+ selected_spots = _stratified_subsample(spot_names, sample_names, n_subsample)
632
+ adata_sub = adata[selected_spots].copy()
633
+ else:
634
+ adata_sub = adata.copy()
635
+ selected_spots = spot_names
636
+
637
+ logger.info(f"Using {len(selected_spots)} spots for UMAP calculation")
638
+
639
+ # Calculate UMAPs
640
+ umap_cell = _calculate_umap_from_embeddings(adata_sub, cell_emb_key)
641
+
642
+ # Estimate point size for UMAP cell
643
+ _, point_size_cell = estimate_plotly_point_size(umap_cell, DEFAULT_PIXEL_WIDTH=600)
644
+
645
+ umap_niche = None
646
+ point_size_niche = None
647
+ if has_niche:
648
+ umap_niche = _calculate_umap_from_embeddings(adata_sub, niche_emb_key)
649
+ _, point_size_niche = estimate_plotly_point_size(umap_niche, DEFAULT_PIXEL_WIDTH=600)
650
+
651
+ # Prepare metadata for the subsampled spots
652
+ umap_metadata = pd.DataFrame({
653
+ 'spot': adata_sub.obs_names,
654
+ 'sample_name': adata_sub.obs['sample_name'].values,
655
+ 'umap_cell_x': umap_cell[:, 0],
656
+ 'umap_cell_y': umap_cell[:, 1],
657
+ })
658
+
659
+ if umap_niche is not None:
660
+ umap_metadata['umap_niche_x'] = umap_niche[:, 0]
661
+ umap_metadata['umap_niche_y'] = umap_niche[:, 1]
662
+
663
+ # Add all annotation columns
664
+ for anno in report_config.annotation_list:
665
+ if anno in adata_sub.obs.columns:
666
+ # Fill NaN with 'NaN' and convert to string to avoid sorting errors
667
+ umap_metadata[anno] = adata_sub.obs[anno].astype(object).fillna('NaN').astype(str).values
668
+
669
+ # Add trait -log10(p) values from metadata if available (vectorized join)
670
+ traits = report_config.trait_name_list
671
+ available_traits = [t for t in traits if t in metadata.columns]
672
+ if available_traits:
674
+ # Prepare trait data: ensure 'spot' is a column and is string type
675
+ trait_data = metadata[available_traits].copy()
676
+ trait_data['spot'] = metadata['spot'].astype(str)
677
+ umap_metadata = umap_metadata.merge(trait_data, on='spot', how='left')
678
+
679
+ # Keep 1 decimal precision for trait values and handle non-numeric data
680
+ for trait in available_traits:
681
+ if trait in umap_metadata.columns:
682
+ # Convert to numeric in case values were strings, then round
683
+ umap_metadata[trait] = pd.to_numeric(umap_metadata[trait], errors='coerce').round(1)
684
+ logger.info(f"Added trait {trait} to UMAP data with 1 decimal precision")
685
+
686
+ # Save to CSV
687
+ umap_metadata.to_csv(report_dir / "umap_data.csv", index=False)
688
+ logger.info(f"UMAP data saved with {len(umap_metadata)} points")
689
+
690
+ return {
691
+ 'has_niche': has_niche,
692
+ 'cell_emb_key': cell_emb_key,
693
+ 'niche_emb_key': niche_emb_key,
694
+ 'n_points': len(umap_metadata),
695
+ 'point_size_cell': float(point_size_cell),
696
+ 'point_size_niche': float(point_size_niche) if point_size_niche is not None else None,
697
+ 'default_opacity': 0.7
698
+ }
699
+
700
+
701
+
702
+ def _prepare_3d_visualization(
703
+ report_config: QuickModeConfig,
704
+ metadata: pd.DataFrame,
705
+ traits: list[str],
706
+ report_data_dir: Path,
707
+ web_report_dir: Path
708
+ ) -> str | None:
709
+ """
710
+ Create 3D visualization widgets using PyVista-based plotting and save them as HTML.
711
+
712
+ Args:
713
+ report_config: QuickModeConfig with annotation_list and other settings
714
+ metadata: DataFrame with 3d_x, 3d_y, 3d_z coordinates and annotations
715
+ traits: List of trait names (for continuous values)
716
+ report_data_dir: Directory for data files (h5ad)
717
+ web_report_dir: Directory for web report files (HTML)
718
+
719
+ Returns:
720
+ Path to saved HTML file, or None if failed
721
+ """
722
+ try:
723
+ import pyvista
724
+ except ImportError:
725
+ logger.warning("pyvista not installed. Skipping 3D visualization.")
726
+ return None
727
+
728
+ # Check if we have 3D coordinates
729
+ if '3d_x' not in metadata.columns or '3d_y' not in metadata.columns or '3d_z' not in metadata.columns:
730
+ logger.warning("3D coordinates not found in metadata. Skipping 3D visualization.")
731
+ return None
732
+
733
+ logger.info("Creating 3D visualization...")
734
+
735
+ # 1. Create adata_vis using ALL spots
736
+ n_spots_all = len(metadata)
737
+ adata_vis_full = ad.AnnData(
738
+ X=sp.csr_matrix((n_spots_all, 1), dtype=np.float32),
739
+ obs=metadata.set_index('spot')
740
+ )
741
+ # Set coordinates
742
+ adata_vis_full.obsm['spatial_3d'] = metadata[['3d_x', '3d_y', '3d_z']].values
743
+ adata_vis_full.obsm['spatial'] = adata_vis_full.obsm['spatial_3d']
744
+ if 'sx' in metadata.columns and 'sy' in metadata.columns:
745
+ adata_vis_full.obsm['spatial_2d'] = metadata[['sx', 'sy']].values
746
+
747
+ # Ensure trait columns are float32
748
+ adata_vis_full.obs = adata_vis_full.obs.astype({
749
+ trait: np.float32 for trait in traits if trait in adata_vis_full.obs.columns
750
+ })
751
+
752
+ # Create 3D visualization directories
753
+ three_d_data_dir = report_data_dir / "spatial_3d"
754
+ three_d_data_dir.mkdir(exist_ok=True)
755
+ three_d_web_dir = web_report_dir / "spatial_3d"
756
+ three_d_web_dir.mkdir(exist_ok=True)
757
+
758
+ # 2. Save the full adata_vis to data directory
759
+ h5ad_path = three_d_data_dir / "spatial_3d.h5ad"
760
+ adata_vis_full.write_h5ad(h5ad_path)
761
+ logger.info(f"Full 3D visualization data saved to {h5ad_path}")
762
+
763
+ # 3. Stratified subsampling for HTML visualization (limit to reasonable size)
764
+ n_max_points = getattr(report_config, 'downsampling_n_spots_3d', 1000000)
765
+ if len(metadata) > n_max_points:
766
+ sample_names = metadata['sample_name']
767
+ selected_idx = _stratified_subsample(
768
+ metadata.index.values, sample_names, n_max_points
769
+ )
770
+ adata_vis = adata_vis_full[selected_idx].copy()
771
+ logger.info(f"Subsampled to {len(adata_vis)} spots for HTML 3D visualization")
772
+ else:
773
+ adata_vis = adata_vis_full.copy()
774
+
775
+ # Add traits as continuous values
776
+ continuous_cols = traits
777
+
778
+ # Add annotations (categorical)
779
+ annotation_cols = []
780
+ for anno in report_config.annotation_list:
781
+ if anno in adata_vis.obs.columns:
782
+ annotation_cols.append(anno)
783
+
784
+ logger.info(f"3D visualization: annotations={annotation_cols}, traits={continuous_cols}")
785
+
786
+ # Create 3D plots
787
+ try:
788
+ from gsMap.report.visualize import _create_color_map
789
+
790
+ # P-value color scale (Red-Blue)
791
+ P_COLOR = ['#313695', '#4575b4', '#74add1', '#fee090', '#fdae61', '#f46d43', '#d73027', '#a50026']
792
+
793
+ # Calculate appropriate point size based on coverage ratio (20-30%)
794
+ # Formula: S = sqrt(W * H * k / N)
795
+ win_w, win_h = 1200, 1008
796
+ k_coverage = 0.25
797
+ n_points = len(adata_vis)
798
+ point_size = np.sqrt((win_w * win_h * k_coverage) / n_points)
799
+ point_size = max(0.5, min(5.0, point_size)) # Clamp between 0.5 and 5.0
800
+ logger.info(f"Estimated 3D point size: {point_size:.2f} for {n_points} spots")
801
+
802
+ # Shared plotting settings
803
+ text_kwargs = dict(text_font_size=15, text_loc="upper_edge")
804
+
805
+ # 1. Save all Annotation 3D plots
806
+ categorical_legend_kwargs = dict(categorical_legend_loc="center right")
807
+ for anno in annotation_cols:
808
+ logger.info(f"Generating 3D plot for annotation: {anno}")
809
+
810
+ # Use same colormap logic as distribution plots
811
+ if pd.api.types.is_numeric_dtype(adata_vis.obs[anno]):
812
+ color_map = P_COLOR
813
+ else:
814
+ unique_vals = adata_vis.obs[anno].unique()
815
+ color_map = _create_color_map(unique_vals, hex=True, rng=42)
816
+
817
+ safe_anno = "".join(c if c.isalnum() else "_" for c in anno)
818
+ anno_plot_name = three_d_web_dir / f"spatial_3d_anno_{safe_anno}"
819
+
820
+ plotter_anno = three_d_plot(
821
+ adata=adata_vis,
822
+ spatial_key='spatial',
823
+ keys=[anno],
824
+ cmaps=[color_map],
825
+ point_size=point_size,
826
+ texts=[anno],
827
+ off_screen=True,
828
+ jupyter=False,
829
+ background='white',
830
+ legend_kwargs=categorical_legend_kwargs,
831
+ text_kwargs=text_kwargs,
832
+ window_size=(win_w, win_h)
833
+ )
834
+ three_d_plot_save(plotter_anno, filename=str(anno_plot_name))
835
+
836
+ # 2. Save each Trait 3D plot
837
+ legend_kwargs = dict(scalar_bar_title_size=30, scalar_bar_label_size=30, fmt="%.1e")
838
+
839
+ for trait in continuous_cols:
840
+ logger.info(f"Generating 3D plot for trait: {trait}")
841
+
842
+ # Calculate opacity based on LogP (exponential scaling)
843
+ trait_values = adata_vis.obs[trait].fillna(0).values
844
+ bins = np.linspace(trait_values.min(), trait_values.max(), 5)
845
+ # Handle case where min == max to avoid bins error
846
+ if bins[0] == bins[-1]:
847
+ opacity_show = 1.0
848
+ else:
849
+ alpha = np.exp(np.linspace(0.1, 1.0, num=(len(bins)-1))) - 1
850
+ alpha = alpha / max(alpha)
851
+ opacity_show = pd.cut(trait_values, bins=bins, labels=alpha, include_lowest=True).tolist()
852
+
853
+ # Set the clim (median of top 20 spots)
854
+ sorted_vals = np.sort(trait_values)[::-1]
855
+ max_v = np.round(np.median(sorted_vals[0:20])) if len(sorted_vals) >= 20 else sorted_vals[0]
856
+ if max_v <= 0: max_v = 1.0
857
+
858
+ safe_trait = "".join(c if c.isalnum() else "_" for c in trait)
859
+ trait_plot_name = three_d_web_dir / f"spatial_3d_trait_{safe_trait}"
860
+
861
+ plotter_trait = three_d_plot(
862
+ adata=adata_vis,
863
+ spatial_key='spatial',
864
+ keys=[trait],
865
+ cmaps=[P_COLOR],
866
+ point_size=point_size,
867
+ opacity=opacity_show,
868
+ clim=[0, max_v],
869
+ scalar_bar_titles=["-log10(p)"],
870
+ texts=[trait],
871
+ off_screen=True,
872
+ jupyter=False,
873
+ background='white',
874
+ legend_kwargs=legend_kwargs,
875
+ text_kwargs=text_kwargs,
876
+ window_size=(win_w, win_h)
877
+ )
878
+ three_d_plot_save(plotter_trait, filename=str(trait_plot_name))
879
+
880
+ # Return the relative path of the first available plot
881
+ if annotation_cols:
882
+ safe_first = "".join(c if c.isalnum() else "_" for c in annotation_cols[0])
883
+ return f"spatial_3d/spatial_3d_anno_{safe_first}.html"
884
+ elif continuous_cols:
885
+ safe_first_trait = "".join(c if c.isalnum() else "_" for c in continuous_cols[0])
886
+ return f"spatial_3d/spatial_3d_trait_{safe_first_trait}.html"
887
+
888
+ return None
889
+
890
+ except Exception as e:
891
+ logger.warning(f"Failed to create 3D visualization: {e}")
892
+ import traceback
893
+ traceback.print_exc()
894
+ return None
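
A worked instance of the coverage-based point-size formula used above, S = sqrt(W * H * k / N), with an assumed count of 100,000 rendered spots:

import numpy as np

win_w, win_h, k_coverage, n_points = 1200, 1008, 0.25, 100_000
point_size = np.sqrt(win_w * win_h * k_coverage / n_points)  # sqrt(3.024) ~= 1.74 px
point_size = max(0.5, min(5.0, point_size))                  # clamped to [0.5, 5.0]
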
895
+
896
+
897
+ # =============================================================================
898
+ # Data Loading Functions
899
+ # =============================================================================
900
+
901
+ def _load_ldsc_results(report_config: QuickModeConfig) -> tuple[pd.DataFrame, list[str], list[str]]:
902
+ """Load combined LDSC results and extract traits/samples."""
903
+ logger.info(f"Loading combined LDSC results from {report_config.ldsc_combined_parquet_path}")
904
+
905
+ if not report_config.ldsc_combined_parquet_path.exists():
906
+ raise FileNotFoundError(
907
+ f"Combined LDSC parquet not found at {report_config.ldsc_combined_parquet_path}. "
908
+ "Please run Cauchy combination first."
909
+ )
910
+
912
+ ldsc_results = pd.read_parquet(report_config.ldsc_combined_parquet_path)
913
+ if 'spot' in ldsc_results.columns:
914
+ ldsc_results.set_index('spot', inplace=True)
915
+
916
+ assert 'sample_name' in ldsc_results.columns, "LDSC combined results must have 'sample_name' column."
917
+
918
+ traits = report_config.trait_name_list
919
+
920
+ # Explicitly use sample order from report_config.sample_h5ad_dict
921
+ sample_names = list(report_config.sample_h5ad_dict.keys())
922
+ actual_samples_in_data = set(ldsc_results['sample_name'].unique())
923
+
924
+ # Assert all samples in config are actually in the data
925
+ missing = [s for s in sample_names if s not in actual_samples_in_data]
926
+ assert not missing, f"Samples {missing} in config are not present in LDSC results."
927
+
928
+ return ldsc_results, traits, sample_names
929
+
930
+
931
+ def _load_coordinates(report_config: QuickModeConfig) -> tuple[pd.DataFrame, bool, int | None]:
932
+ """
933
+ Load spatial coordinates from concatenated adata.
934
+
935
+ Returns:
936
+ Tuple of (coords DataFrame, is_3d flag, z_axis index if 3D)
937
+ """
938
+ from gsMap.config import DatasetType
939
+
940
+ logger.info(f"Loading coordinates from {report_config.concatenated_latent_adata_path}")
941
+
942
+ adata_concat = ad.read_h5ad(report_config.concatenated_latent_adata_path, backed='r')
943
+
944
+ assert report_config.spatial_key in adata_concat.obsm, f"Spatial key '{report_config.spatial_key}' not found in adata.obsm"
945
+ coords_data = adata_concat.obsm[report_config.spatial_key]
946
+ sample_info = adata_concat.obs[['sample_name']]
947
+
948
+ # Check if dataset type is 3D
949
+ is_3d_type = (
950
+ hasattr(report_config, 'dataset_type') and
951
+ report_config.dataset_type == DatasetType.SPATIAL_3D
952
+ )
953
+ has_3d_coords = coords_data.shape[1] >= 3
954
+
955
+ z_axis = None
956
+ if is_3d_type and has_3d_coords:
957
+ # True 3D coordinates
958
+ logger.info("Detected 3D spatial coordinates")
959
+ coords_3d = coords_data[:, :3]
960
+
961
+ # Detect Z axis
962
+ z_axis = _detect_z_axis(coords_3d, adata_concat.obs['sample_name'])
963
+
964
+ # Reorder coordinates so Z is last
965
+ reordered_coords, col_names = _reorder_coords_for_3d(coords_3d, z_axis)
966
+ coords = pd.DataFrame(reordered_coords, columns=col_names, index=adata_concat.obs_names)
967
+
968
+ # Also keep 2D coords for compatibility (use x and y from reordered)
969
+ coords['sx'] = coords['3d_x']
970
+ coords['sy'] = coords['3d_y']
971
+
972
+ elif is_3d_type and not has_3d_coords:
973
+ # 3D dataset type but only 2D coordinates - create pseudo Z axis
974
+ logger.info("3D dataset type with 2D coordinates - creating pseudo Z axis based on sample_name")
975
+ coords_2d = coords_data[:, :2]
976
+ coords = pd.DataFrame(coords_2d, columns=['sx', 'sy'], index=adata_concat.obs_names)
977
+
978
+ # Calculate the span for pseudo Z axis
979
+ sx_span = coords['sx'].max() - coords['sx'].min()
980
+ sy_span = coords['sy'].max() - coords['sy'].min()
981
+ z_span = max(sx_span, sy_span)
982
+
983
+ # Assign Z values based on sample order from report_config
984
+ sample_names_all = adata_concat.obs['sample_name']
985
+ actual_samples_in_data = set(sample_names_all.unique())
986
+
987
+ # Explicitly use report_config order
988
+ ordered_samples = list(report_config.sample_h5ad_dict.keys())
989
+ missing = [s for s in ordered_samples if s not in actual_samples_in_data]
990
+ assert not missing, f"Samples {missing} in config are not present in coordinates data."
991
+
992
+ n_samples = len(ordered_samples)
993
+
994
+ # Create evenly spaced Z values for each sample in the specific order
995
+ if n_samples > 1:
996
+ z_values_per_sample = {
997
+ sample: z_span * i / (n_samples - 1)
998
+ for i, sample in enumerate(ordered_samples)
999
+ }
1000
+ else:
1001
+ z_values_per_sample = {ordered_samples[0]: 0.0}
1002
+
1003
+ # Assign Z values to each spot
1004
+ pseudo_z = sample_names_all.map(z_values_per_sample).values
1005
+
1006
+ coords['3d_x'] = coords['sx']
1007
+ coords['3d_y'] = coords['sy']
1008
+ coords['3d_z'] = pseudo_z
1009
+ z_axis = 2 # Pseudo Z is always the last dimension
1010
+
1011
+ logger.info(f"Created pseudo Z axis with span {z_span:.2f} for {n_samples} samples")
1012
+
1013
+ else:
1014
+ # 2D coordinates (non-3D dataset type)
1015
+ coords_2d = coords_data[:, :2]
1016
+ coords = pd.DataFrame(coords_2d, columns=['sx', 'sy'], index=adata_concat.obs_names)
1017
+
1018
+ coords = pd.concat([coords, sample_info], axis=1)
1019
+
1020
+ # Close backed file
1021
+ if adata_concat.isbacked:
1022
+ adata_concat.file.close()
1023
+
1024
+ # Return is_3d as True if dataset type is 3D (regardless of coord dimensions)
1025
+ return coords, is_3d_type, z_axis
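
The pseudo-Z branch above spreads samples evenly over the larger in-plane span; a toy illustration with three invented slices and a 10-unit span:

ordered_samples = ["s1", "s2", "s3"]
z_span = 10.0
n = len(ordered_samples)
z_values = {s: z_span * i / (n - 1) for i, s in enumerate(ordered_samples)}
# -> {'s1': 0.0, 's2': 5.0, 's3': 10.0}; every spot inherits its sample's Z value
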
1026
+
1027
+
1028
+ def _load_gss_and_calculate_stats_base(
1029
+ report_config: QuickModeConfig,
1030
+ ldsc_results: pd.DataFrame,
1031
+ coords: pd.DataFrame,
1032
+ report_dir: Path
1033
+ ) -> tuple[np.ndarray, ad.AnnData, pd.DataFrame, np.ndarray]:
1034
+ """Load GSS data, calculate general stats, and return analysis results."""
1035
+ logger.info("Loading GSS data...")
1036
+
1037
+ gss_adata = load_marker_scores_memmap_format(report_config)
1038
+ common_spots = np.intersect1d(gss_adata.obs_names, ldsc_results.index)
1039
+ logger.info(f"Common spots (gss & ldsc): {len(common_spots)}")
1040
+ assert len(common_spots) > 0, "No common spots found between GSS and LDSC results."
1041
+
1042
+ # Stratified subsample if requested
1043
+ analysis_spots = common_spots
1044
+ if report_config.downsampling_n_spots_pcc and len(common_spots) > report_config.downsampling_n_spots_pcc:
1045
+ sample_names = ldsc_results.loc[common_spots, 'sample_name']
1046
+ analysis_spots = _stratified_subsample(
1047
+ common_spots, sample_names, report_config.downsampling_n_spots_pcc
1048
+ )
1049
+ logger.info(f"Stratified subsampled to {len(analysis_spots)} spots for PCC calculation.")
1050
+
1051
+ # Filter to high expression genes
1052
+ exp_frac = pd.read_parquet(report_config.mean_frac_path)
1053
+ high_expr_genes = exp_frac[exp_frac['frac'] > 0.01].index.tolist()
1054
+ logger.info(f"Using {len(high_expr_genes)} high expression genes for PCC calculation.")
1055
+
1056
+ gss_adata_sub = gss_adata[analysis_spots, high_expr_genes]
1057
+ gss_matrix = gss_adata_sub.X
1058
+ gene_names = gss_adata_sub.var_names.tolist()
1059
+ pd.DataFrame({'gene': gene_names}).to_csv(report_dir / "gene_list.csv", index=False)
1060
+
1061
+ # Calculate gene annotation stats
1062
+ adata_concat = ad.read_h5ad(report_config.concatenated_latent_adata_path, backed='r')
1063
+ anno_col = report_config.annotation_list[0]
1064
+ annotations = adata_concat.obs.loc[analysis_spots, anno_col]
+ if adata_concat.isbacked:
+ adata_concat.file.close()
1065
+
1066
+ if hasattr(gss_matrix, 'toarray'):
1067
+ gss_matrix = gss_matrix.toarray()
1068
+
1069
+ gss_df_temp = pd.DataFrame(gss_matrix, index=analysis_spots, columns=gene_names)
1070
+ grouped_gss = gss_df_temp.groupby(annotations).median()
1071
+
1072
+ gene_stats = pd.DataFrame({
1073
+ 'gene': grouped_gss.idxmax().index,
1074
+ 'Annotation': grouped_gss.idxmax().values,
1075
+ 'Median_GSS': grouped_gss.max().values
1076
+ })
1077
+ gene_stats.dropna(subset=['Median_GSS'], inplace=True)
1078
+
1079
+ return common_spots, gss_adata, gene_stats, analysis_spots
1080
+
1081
+
1082
+ def _calculate_pcc_for_single_trait_fast(
1083
+ trait: str,
1084
+ ldsc_results: pd.DataFrame,
1085
+ analysis_spots: np.ndarray,
1086
+ gss_centered: np.ndarray,
1087
+ gss_ssq: np.ndarray,
1088
+ gene_names: list[str],
1089
+ gene_stats: pd.DataFrame,
1090
+ report_config: QuickModeConfig,
1091
+ report_dir: Path,
1092
+ gss_dir: Path
1093
+ ) -> pd.DataFrame | None:
1094
+ """Calculate PCC for a single trait using pre-calculated centered GSS matrix."""
1095
+ if trait not in ldsc_results.columns:
1096
+ logger.warning(f"Trait {trait} not found in LDSC combined results. Skipping PCC calculation.")
1097
+ return None
1098
+
1099
+ def fast_corr(centered_matrix, ssq_matrix, vector):
1100
+ v_centered = vector - vector.mean()
1101
+ numerator = np.dot(v_centered, centered_matrix)
1102
+ denominator = np.sqrt(np.sum(v_centered ** 2) * ssq_matrix)
1103
+ return numerator / (denominator + 1e-12)
1104
+
1105
+ logger.info(f"Processing PCC for trait: {trait}")
1106
+
1107
+ logp_vec = ldsc_results.loc[analysis_spots, trait].values.astype(np.float32)
1108
+ assert not np.any(np.isnan(logp_vec)), f"NaN values found in LDSC results for trait {trait}."
1109
+ pccs = fast_corr(gss_centered, gss_ssq, logp_vec)
1110
+
1111
+ trait_pcc = pd.DataFrame({
1112
+ 'gene': gene_names,
1113
+ 'PCC': pccs,
1114
+ 'trait': trait
1115
+ })
1116
+
1117
+ if gene_stats is not None:
1118
+ trait_pcc = trait_pcc.merge(gene_stats, on='gene', how='left')
1119
+
1120
+ trait_pcc_sorted = trait_pcc.sort_values('PCC', ascending=False)
1121
+
1122
+ # Save the per-trait result CSV to the diagnostic-info path (gss_stats subfolder)
1123
1124
+ diag_info_path = report_config.get_gene_diagnostic_info_save_path(trait)
1125
+ trait_pcc_sorted.to_csv(diag_info_path, index=False)
1126
+
1127
+ return trait_pcc_sorted
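
fast_corr is a vectorized Pearson correlation of one trait vector against every gene column at once; a quick self-check of that algebra against np.corrcoef on random data (shapes invented):

import numpy as np

rng = np.random.default_rng(0)
gss = rng.normal(size=(200, 8)).astype(np.float32)  # spots x genes
logp = rng.normal(size=200).astype(np.float32)      # -log10(p) per spot

centered = gss - gss.mean(axis=0)
ssq = np.sum(centered ** 2, axis=0)
vc = logp - logp.mean()
pcc = vc @ centered / (np.sqrt(np.sum(vc ** 2) * ssq) + 1e-12)

expected = np.array([np.corrcoef(logp, gss[:, j])[0, 1] for j in range(8)])
assert np.allclose(pcc, expected, atol=1e-4)
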
1128
+
1129
+
1130
+ def _export_single_manhattan_js(trait: str, csv_path: Path, js_data_dir: Path):
1131
+ """Export single trait Manhattan data to JS."""
1132
+ try:
1133
+ df = pd.read_csv(csv_path)
1134
+ if 'P' in df.columns and 'logp' not in df.columns:
1135
+ df['logp'] = -np.log10(df['P'] + 1e-300)
1136
+
1137
+ data_struct = {
1138
+ 'x': df['BP_cum'].tolist(),
1139
+ 'y': df['logp'].tolist(),
1140
+ 'gene': df['GENE'].fillna("").tolist(),
1141
+ 'chr': df['CHR'].astype(int).tolist(),
1142
+ 'snp': df['SNP'].tolist() if 'SNP' in df.columns else [],
1143
+ 'is_top': df['is_top_pcc'].astype(int).tolist() if 'is_top_pcc' in df.columns else [],
1144
+ 'bp': df['BP'].astype(int).tolist() if 'BP' in df.columns else []
1145
+ }
1146
+
1147
+ json_str = json.dumps(data_struct, separators=(',', ':'))
1148
+ safe_trait = "".join(c if c.isalnum() else "_" for c in trait)
1149
+ js_content = f"window.GSMAP_MANHATTAN_{safe_trait} = {json_str};"
1150
+
1151
+ with open(js_data_dir / f"manhattan_{trait}.js", "w", encoding='utf-8') as f:
1152
+ f.write(js_content)
1153
+ except Exception as e:
1154
+ logger.warning(f"Failed to export Manhattan JS for {trait}: {e}")
1155
+
1156
+
1157
+ # =============================================================================
1158
+ # Data Preparation Functions
1159
+ # =============================================================================
1160
+
1161
+ def _save_metadata(
1162
+ ldsc_results: pd.DataFrame,
1163
+ coords: pd.DataFrame,
1164
+ report_dir: Path
1165
+ ) -> pd.DataFrame:
1166
+ """Save metadata and coordinates to CSV."""
1167
+ logger.info("Saving metadata and coordinates...")
1168
+
1169
+ common_indices = ldsc_results.index.intersection(coords.index)
1170
+ ldsc_subset = ldsc_results.loc[common_indices]
1171
+ cols_to_use = ldsc_subset.columns.difference(coords.columns)
1172
+ metadata = pd.concat([coords.loc[common_indices], ldsc_subset[cols_to_use]], axis=1)
1173
+
1174
+ metadata.index.name = 'spot'
1175
+ metadata = metadata.reset_index()
1176
+ metadata = metadata.loc[:, ~metadata.columns.duplicated()]
1177
+ metadata.to_csv(report_dir / "spot_metadata.csv", index=False)
1178
+
1179
+ return metadata
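
The columns.difference step above is what keeps shared fields such as sample_name from being duplicated when LDSC results and coordinates are concatenated; a minimal sketch:

import pandas as pd

coords = pd.DataFrame({'sx': [0, 1], 'sample_name': ['s1', 's1']}, index=['a', 'b'])
ldsc = pd.DataFrame({'IQ': [2.5, 3.1], 'sample_name': ['s1', 's1']}, index=['a', 'b'])

cols = ldsc.columns.difference(coords.columns)    # Index(['IQ']) only
merged = pd.concat([coords, ldsc[cols]], axis=1)  # sample_name appears once
assert list(merged.columns) == ['sx', 'sample_name', 'IQ']
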
1180
+
1181
+
1182
+
1183
+
1184
+ def _prepare_manhattan_for_trait(
1185
+ report_config: QuickModeConfig,
1186
+ trait: str,
1187
+ weight_adata: ad.AnnData,
1188
+ trait_pcc_df: pd.DataFrame | None,
1189
+ gene_names_ref: list[str],
1190
+ manhattan_dir: Path
1191
+ ) -> dict:
1192
+ """Prepare Manhattan data for a single trait."""
1193
+ from gsMap.utils.manhattan_plot import _ManhattanPlot
1194
+
1195
+ sumstats_file = report_config.sumstats_config_dict.get(trait)
1196
+ if not sumstats_file or not Path(sumstats_file).exists():
1197
+ return {}
1198
+
1199
+ logger.info(f"Processing Manhattan for {trait}...")
1200
+ gwas_data = load_gwas_data(sumstats_file)
1201
+
1202
+ common_snps = weight_adata.obs_names[weight_adata.obs_names.isin(gwas_data["SNP"])]
1203
+ gwas_subset = gwas_data.set_index("SNP").loc[common_snps].reset_index()
1204
+
1205
+ gwas_subset = gwas_subset.drop(columns=[c for c in ["CHR", "BP"] if c in gwas_subset.columns])
1206
+ gwas_subset = gwas_subset.set_index("SNP").join(weight_adata.obs[["CHR", "BP"]]).reset_index()
1207
+
1208
+ snps2plot_ids = filter_snps(gwas_subset.sort_values("P"), SUBSAMPLE_SNP_NUMBER=50000)
1209
+ gwas_plot_data = gwas_subset[gwas_subset["SNP"].isin(snps2plot_ids)].copy()
1210
+
1211
+ gwas_plot_data["CHR"] = pd.to_numeric(gwas_plot_data["CHR"], errors='coerce')
1212
+ gwas_plot_data["BP"] = pd.to_numeric(gwas_plot_data["BP"], errors='coerce')
1213
+ gwas_plot_data = gwas_plot_data.dropna(subset=["CHR", "BP"])
1214
+
1215
+ # Gene assignment
1216
+ target_genes = [g for g in weight_adata.var_names if g in gene_names_ref and g != "unmapped"]
1217
+ if target_genes:
1218
+ sub_weight = weight_adata[gwas_plot_data["SNP"], target_genes].to_memory()
1219
+ weights_matrix = sub_weight.X
1220
+
1221
+ if sp.issparse(weights_matrix):
1222
+ max_idx = np.array(weights_matrix.argmax(axis=1)).ravel()
1223
+ max_val = np.array(weights_matrix.max(axis=1).toarray()).ravel()
1224
+ else:
1225
+ max_idx = np.argmax(weights_matrix, axis=1)
1226
+ max_val = np.max(weights_matrix, axis=1)
1227
+
1228
+ gene_map = np.where(max_val > 1, np.array(target_genes)[max_idx], "None")
1229
+ gwas_plot_data["GENE"] = gene_map
1230
+
1231
+ if trait_pcc_df is not None:
1232
+ top_n = getattr(report_config, 'top_corr_genes', 50)
1233
+ trait_top_genes = trait_pcc_df.sort_values('PCC', ascending=False).head(top_n)['gene'].tolist()
1234
+ gwas_plot_data["is_top_pcc"] = gwas_plot_data["GENE"].isin(trait_top_genes)
1235
+ else:
1236
+ gwas_plot_data["is_top_pcc"] = False
1237
+
1238
+ if "GENE" not in gwas_plot_data.columns:
1239
+ gwas_plot_data["GENE"] = "None"
1240
+ if "is_top_pcc" not in gwas_plot_data.columns:
1241
+ gwas_plot_data["is_top_pcc"] = False
1242
+
1243
+ # Calculate cumulative positions
1244
+ chrom_ticks = {}
1245
+ try:
1246
+ mp_helper = _ManhattanPlot(gwas_plot_data)
1247
+ gwas_plot_data["BP_cum"] = mp_helper.data["POSITION"].values
1248
+ gwas_plot_data["CHR_INDEX"] = mp_helper.data["INDEX"].values
1249
+
1250
+ chrom_groups = gwas_plot_data.groupby("CHR")["BP_cum"]
1251
+ chrom_ticks = {int(chrom): float(positions.median()) for chrom, positions in chrom_groups}
1252
+ except Exception as e:
1253
+ logger.warning(f"Failed to calculate Manhattan coordinates: {e}")
1254
+ gwas_plot_data["BP_cum"] = np.arange(len(gwas_plot_data))
1255
+ gwas_plot_data["CHR_INDEX"] = gwas_plot_data["CHR"] % 2
1256
+
1257
+ gwas_plot_data.to_csv(manhattan_dir / f"{trait}_manhattan.csv", index=False)
1258
+ return chrom_ticks
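
The SNP-to-gene assignment above takes, per SNP, the gene with the largest weight and keeps it only when that weight exceeds 1; a toy sparse-matrix sketch of the same argmax logic:

import numpy as np
import scipy.sparse as sp

# rows = SNPs, cols = genes, entries = SNP-gene weights
W = sp.csr_matrix(np.array([[0.0, 2.0, 0.0],
                            [0.5, 0.0, 0.9]]))
genes = np.array(["g1", "g2", "g3"])

max_idx = np.array(W.argmax(axis=1)).ravel()         # best gene per SNP
max_val = np.array(W.max(axis=1).toarray()).ravel()  # its weight
assigned = np.where(max_val > 1, genes[max_idx], "None")
# -> ['g2', 'None']: the second SNP's best weight (0.9) misses the threshold
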
1259
+
1260
+
1261
+ # =============================================================================
1262
+ # Plot Rendering Functions
1263
+ # =============================================================================
1264
+
1265
+
1266
+
1267
+ def _render_gene_diagnostic_plots_refactored(
1268
+ report_config: QuickModeConfig,
1269
+ metadata: pd.DataFrame,
1270
+ common_spots: np.ndarray,
1271
+ gss_adata: ad.AnnData,
1272
+ n_rows: int,
1273
+ n_cols: int,
1274
+ web_report_dir: Path,
1275
+ report_data_dir: Path,
1276
+ is_step_complete_func
1277
+ ):
1278
+ """Render gene expression and GSS diagnostic plots with completeness check."""
1279
+ gene_plot_dir = web_report_dir / "gene_diagnostic_plots"
1280
+ gene_plot_dir.mkdir(exist_ok=True)
1281
+
1282
+ if not report_config.sample_h5ad_dict:
1283
+ logger.warning("Skipping gene diagnostic plots: missing h5ad dict")
1284
+ return
1285
+
1286
+ top_n = report_config.top_corr_genes
1287
+ # 1. Identify all genes that need to be plotted from per-trait PCC files
1288
+ gss_stats_dir = report_data_dir / "gss_stats"
1289
+ trait_top_genes = {}
1290
+ all_top_genes_set = set()
1291
+
1292
+ # We need to know which traits we have
1293
+ traits = [p.name.replace("gene_trait_correlation_", "").replace(".csv", "")
1294
+ for p in gss_stats_dir.glob("gene_trait_correlation_*.csv")]
1295
+
1296
+ for trait in traits:
1297
+ pcc_path = report_config.get_gene_diagnostic_info_save_path(trait)
1298
+ if pcc_path.exists():
1299
+ group = pd.read_csv(pcc_path)
1300
+ genes = group.sort_values('PCC', ascending=False).head(top_n)['gene'].tolist()
1301
+ trait_top_genes[trait] = genes
1302
+ all_top_genes_set.update(genes)
1303
+
1304
+ # all_top_genes_set now holds the de-duplicated top genes across traits
1306
+ sample_names_sorted = list(report_config.sample_h5ad_dict.keys())
1307
+
1308
+ # Pre-filter metadata for common spots once
1309
+ metadata_common = metadata[metadata['spot'].isin(common_spots)]
1310
+
1311
+ logger.info("Checking completeness for diagnostic plots...")
1312
+
1313
+ # Filter out combinations that already exist
1314
+ tasks_to_run = []
1315
+
1316
+ for sample_name in sample_names_sorted:
1317
+ safe_sample = "".join(c if c.isalnum() else "_" for c in sample_name)
1318
+ for trait in traits:
1319
+ if trait not in trait_top_genes: continue
1320
+ for gene in trait_top_genes[trait]:
1321
+ exp_path = gene_plot_dir / f"gene_{trait}_{gene}_{safe_sample}_exp.png"
1322
+ gss_path = gene_plot_dir / f"gene_{trait}_{gene}_{safe_sample}_gss.png"
1323
+ if not is_step_complete_func([exp_path, gss_path]):
1324
+ tasks_to_run.append((sample_name, trait, gene))
1325
+
1326
+ if not tasks_to_run:
1327
+ logger.info("All gene diagnostic plots already exist. Skipping.")
1328
+ return
1329
+
1330
+ logger.info(f"Rendering {len(tasks_to_run)} diagnostic plots...")
1331
+
1332
+ # Process sample by sample to save memory
1333
+ all_futures = []
1334
+ futures_lock = threading.Lock()
1335
+ max_workers = 20
1336
+
1337
+ # Use semaphore to limit the number of pending tasks in the executor queue.
1338
+ max_pending_tasks = max_workers * 4 # e.g., 80 pending tasks
1339
+ task_semaphore = threading.Semaphore(max_pending_tasks)
1340
+
1341
+ max_loading_threads = min(4, len(sample_names_sorted))
1342
+
1343
+ # Group tasks by sample for efficient processing
1344
+ tasks_by_sample = {}
1345
+ for sample, trait, gene in tasks_to_run:
1346
+ tasks_by_sample.setdefault(sample, []).append((trait, gene))
1348
+
1349
+ with ProcessPoolExecutor(max_workers=max_workers) as executor:
1350
+ def process_sample(sample_name):
1351
+ if sample_name not in tasks_by_sample: return
1352
+
1353
+ h5ad_path = report_config.sample_h5ad_dict[sample_name]
1354
+ sample_metadata = metadata_common[metadata_common['sample_name'] == sample_name]
1355
+ if sample_metadata.empty: return
1356
+
1357
+ sample_spots = sample_metadata['spot'].values
1358
+ coords = sample_metadata[['sx', 'sy']].values
1359
+
1360
+ try:
1361
+ adata_rep = ad.read_h5ad(h5ad_path)
1362
+ suffix = f"|{sample_name}"
1363
+ if not str(adata_rep.obs_names[0]).endswith(suffix):
1364
+ adata_rep.obs_names = adata_rep.obs_names.astype(str) + suffix
1365
+
1366
+ if 'log1p' in adata_rep.uns and adata_rep.X is not None:
1367
+ is_count = False
1368
+ else:
1369
+ is_count, _ = setup_data_layer(adata_rep, report_config.data_layer, verbose=False)
1370
+
1371
+ if is_count:
1372
+ sc.pp.normalize_total(adata_rep, target_sum=1e4)
1373
+ sc.pp.log1p(adata_rep)
1374
+
1375
+ sample_genes = [g for trait, g in tasks_by_sample[sample_name]]
1376
+ unique_sample_genes = list(set(sample_genes))
1377
+ available_genes = [g for g in unique_sample_genes if g in adata_rep.var_names]
1378
+
1379
+ if not available_genes:
1380
+ del adata_rep
1381
+ return
1382
+
1383
+ # Extract to memory and then delete the large AnnData object to save RAM
1384
+ adata_sample_exp = adata_rep[sample_spots, available_genes].copy()
1385
+ adata_sample_gss = gss_adata[sample_spots, available_genes].to_memory()
1386
+
1387
+ del adata_rep
1388
+ gc.collect()
1389
+
1390
+ except Exception as e:
1391
+ logger.error(f"Failed to load data for {sample_name}: {e}")
1392
+ return
1393
+
1394
+ safe_sample = "".join(c if c.isalnum() else "_" for c in sample_name)
1395
+
1396
+ for trait, gene in tasks_by_sample[sample_name]:
1397
+ if gene not in available_genes: continue
1398
+
1399
+ exp_vals = np.ravel(adata_sample_exp[:, gene].X.toarray() if sp.issparse(adata_sample_exp[:, gene].X) else adata_sample_exp[:, gene].X).astype(np.float32)
1400
+ gss_vals = np.ravel(adata_sample_gss[:, gene].X.toarray() if sp.issparse(adata_sample_gss[:, gene].X) else adata_sample_gss[:, gene].X).astype(np.float32)
1401
+
1402
+ # Submit Expression task with semaphore throttling
1403
+ task_semaphore.acquire()
1404
+ try:
1405
+ f_exp = executor.submit(_render_single_sample_gene_plot_task, {
1406
+ 'gene': gene, 'sample_name': sample_name, 'trait': trait,
1407
+ 'plot_type': 'exp', 'coords': coords.copy(), 'values': exp_vals,
1408
+ 'output_path': gene_plot_dir / f"gene_{trait}_{gene}_{safe_sample}_exp.png",
1409
+ 'plot_origin': report_config.plot_origin, 'fig_width': 6.0, 'dpi': 150,
1410
+ })
1411
+ f_exp.add_done_callback(lambda _: task_semaphore.release())
1412
+ with futures_lock:
1413
+ all_futures.append(f_exp)
1414
+ except Exception:
1415
+ task_semaphore.release()
1416
+ raise
1417
+
1418
+ # Submit GSS task with semaphore throttling
1419
+ task_semaphore.acquire()
1420
+ try:
1421
+ f_gss = executor.submit(_render_single_sample_gene_plot_task, {
1422
+ 'gene': gene, 'sample_name': sample_name, 'trait': trait,
1423
+ 'plot_type': 'gss', 'coords': coords.copy(), 'values': gss_vals,
1424
+ 'output_path': gene_plot_dir / f"gene_{trait}_{gene}_{safe_sample}_gss.png",
1425
+ 'plot_origin': report_config.plot_origin, 'fig_width': 6.0, 'dpi': 150,
1426
+ })
1427
+ f_gss.add_done_callback(lambda _: task_semaphore.release())
1428
+ with futures_lock:
1429
+ all_futures.append(f_gss)
1430
+ except Exception:
1431
+ task_semaphore.release()
1432
+ raise
1433
+
1434
+ # Cleanup sample slices
1435
+ del adata_sample_exp, adata_sample_gss
1436
+ gc.collect()
1437
+
1438
+ with ThreadPoolExecutor(max_workers=max_loading_threads) as loader:
1439
+ loader.map(process_sample, tasks_by_sample.keys())
1440
+
1441
+ # Wait for all
1442
+ logger.info(f"Waiting for {len(all_futures)} diagnostic plot tasks to complete...")
1443
+ for future in as_completed(all_futures):
1444
+ future.result()
1445
+
1446
+ logger.info("Successfully finished rendering diagnostic plots.")
1447
+
1448
+
1449
+
1450
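The submit loop above applies backpressure: each payload carries full coordinate and value arrays, so letting thousands of them queue up in the `ProcessPoolExecutor` would exhaust RAM. The same semaphore-plus-`add_done_callback` idiom in isolation, as a runnable sketch with a stand-in `work` function (not gsMap code):

```python
import threading
from concurrent.futures import ProcessPoolExecutor, as_completed

def work(i: int) -> int:
    return i * i  # stand-in for an expensive rendering task

def run_throttled(n_tasks: int, max_workers: int = 4) -> list[int]:
    sem = threading.Semaphore(max_workers * 4)  # at most 4x workers in flight
    futures = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        for i in range(n_tasks):
            sem.acquire()  # blocks the producer once the queue is full
            try:
                f = executor.submit(work, i)
                f.add_done_callback(lambda _f: sem.release())  # free the slot on completion
                futures.append(f)
            except Exception:
                sem.release()  # do not leak the slot if submit() itself fails
                raise
        return [f.result() for f in as_completed(futures)]

if __name__ == "__main__":  # ProcessPoolExecutor needs this guard when run as a script
    print(sum(run_throttled(100)))
```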
+ # =============================================================================
+ # Results Collection Functions
+ # =============================================================================
+
+ def _collect_cauchy_results(
+     report_config: QuickModeConfig,
+     traits: list[str],
+     report_dir: Path,
+ ):
+     """Collect and save Cauchy combination results."""
+     logger.info("Collecting Cauchy combination results...")
+     all_cauchy = []
+
+     for trait in traits:
+         for annotation in report_config.annotation_list:
+             # Aggregated results
+             cauchy_file_all = report_config.get_cauchy_result_file(trait, annotation=annotation, all_samples=True)
+             if cauchy_file_all.exists():
+                 try:
+                     df = pd.read_csv(cauchy_file_all)
+                     df['trait'] = trait
+                     df['annotation_name'] = annotation
+                     df['type'] = 'aggregated'
+                     if 'sample_name' not in df.columns:
+                         df['sample_name'] = 'All Samples'
+
+                     # Convert to -log10 scale for report
+                     df['mlog10_p_cauchy'] = -np.log10(df['p_cauchy'].clip(lower=1e-300))
+                     df['mlog10_p_median'] = -np.log10(df['p_median'].clip(lower=1e-300))
+
+                     all_cauchy.append(df)
+                 except Exception as e:
+                     logger.warning(f"Failed to load aggregated Cauchy result {cauchy_file_all}: {e}")
+
+             # Per-sample results
+             cauchy_file = report_config.get_cauchy_result_file(trait, annotation=annotation, all_samples=False)
+             if cauchy_file.exists():
+                 try:
+                     df = pd.read_csv(cauchy_file)
+                     df['trait'] = trait
+                     df['annotation_name'] = annotation
+                     df['type'] = 'sample'
+
+                     # Convert to -log10 scale for report
+                     df['mlog10_p_cauchy'] = -np.log10(df['p_cauchy'].clip(lower=1e-300))
+                     df['mlog10_p_median'] = -np.log10(df['p_median'].clip(lower=1e-300))
+
+                     all_cauchy.append(df)
+                 except Exception as e:
+                     logger.warning(f"Failed to load sample Cauchy result {cauchy_file}: {e}")
+
+     if all_cauchy:
+         combined_cauchy = pd.concat(all_cauchy, ignore_index=True)
+         if 'sample' in combined_cauchy.columns and 'sample_name' not in combined_cauchy.columns:
+             combined_cauchy = combined_cauchy.rename(columns={'sample': 'sample_name'})
+
+         cauchy_save_path = report_dir / "cauchy_results.csv"
+         combined_cauchy.to_csv(cauchy_save_path, index=False)
+         logger.info(f"Saved {len(combined_cauchy)} Cauchy results to {cauchy_save_path}")
+     else:
+         pd.DataFrame(columns=['trait', 'annotation_name', 'mlog10_p_cauchy', 'mlog10_p_median', 'top_95_quantile', 'type', 'sample_name']).to_csv(
+             report_dir / "cauchy_results.csv", index=False
+         )
+         logger.warning("No Cauchy results found to save.")
+
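For context on the `p_cauchy` values collected above: the Cauchy combination (ACAT-style) test aggregates many, possibly correlated, spot-level p-values into a single p-value per annotation. The implementation gsMap actually uses lives in gsMap/cauchy_combination_test.py; this is only a sketch of the standard equal-weight formula:

```python
import numpy as np

def cauchy_combination(p_values: np.ndarray) -> float:
    """Equal-weight ACAT: map each p to a Cauchy variate, average, map back."""
    p = np.clip(np.asarray(p_values, dtype=np.float64), 1e-300, 1 - 1e-16)
    t = np.mean(np.tan((0.5 - p) * np.pi))  # a mean of standard Cauchy variates is still Cauchy
    return float(0.5 - np.arctan(t) / np.pi)

# One strong signal dominates even among many null p-values:
# cauchy_combination(np.array([1e-8, 0.3, 0.5, 0.9])) -> ~4e-8
```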
+
+ def _save_report_metadata(
+     report_config: QuickModeConfig,
+     traits: list[str],
+     sample_names: list[str],
+     report_dir: Path,
+     chrom_tick_positions: dict | None = None,
+     umap_info: dict | None = None,
+     is_3d: bool = False,
+     spatial_3d_path: str | None = None,
+ ):
+     """Save report configuration metadata as JSON."""
+     logger.info("Saving report configuration metadata...")
+     report_meta = report_config.to_dict_with_paths_as_strings()
+
+     report_meta['traits'] = traits
+     report_meta['samples'] = sample_names
+     report_meta['annotations'] = report_config.annotation_list
+     report_meta['top_corr_genes'] = report_config.top_corr_genes
+     report_meta['plot_origin'] = report_config.plot_origin
+     report_meta['legend_marker_size'] = report_config.legend_marker_size
+     report_meta['downsampling_n_spots_2d'] = getattr(report_config, 'downsampling_n_spots_2d', 250000)
+
+     # Add chromosome tick positions for Manhattan plot
+     if chrom_tick_positions:
+         report_meta['chrom_tick_positions'] = chrom_tick_positions
+
+     # Add UMAP info
+     if umap_info:
+         report_meta['umap_info'] = umap_info
+
+     # Add 3D visualization info
+     report_meta['is_3d'] = is_3d
+     report_meta['has_3d_widget'] = spatial_3d_path is not None
+     report_meta['spatial_3d_widget_path'] = spatial_3d_path
+
+     # Add dataset type info for conditional section rendering
+     dataset_type_value = report_config.dataset_type
+     if hasattr(dataset_type_value, 'value'):
+         dataset_type_value = dataset_type_value.value
+     report_meta['dataset_type'] = dataset_type_value
+     report_meta['dataset_type_label'] = {
+         'spatial3D': 'Spatial 3D',
+         'spatial2D': 'Spatial 2D',
+         'scRNA': 'Single Cell RNA-seq',
+     }.get(dataset_type_value, dataset_type_value)
+
+     with open(report_dir / "report_meta.json", "w") as f:
+         json.dump(report_meta, f)
+
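`to_dict_with_paths_as_strings` comes from the config side (see gsMap/config/), and the `hasattr(..., 'value')` check above plays the same role for Enum members. A sketch of what such JSON-safe flattening typically looks like, using a hypothetical dataclass rather than the real QuickModeConfig:

```python
import dataclasses
import json
from enum import Enum
from pathlib import Path

@dataclasses.dataclass
class DemoConfig:  # hypothetical stand-in for illustration only
    workdir: Path
    project_name: str

def to_jsonable_dict(cfg) -> dict:
    out = {}
    for field in dataclasses.fields(cfg):
        value = getattr(cfg, field.name)
        if isinstance(value, Path):
            value = str(value)    # Path is not JSON-serializable
        elif isinstance(value, Enum):
            value = value.value   # same unwrapping as dataset_type above
        out[field.name] = value
    return out

print(json.dumps(to_jsonable_dict(DemoConfig(Path("/tmp/run"), "demo"))))
```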
+
+ def _copy_js_assets(report_dir: Path):
+     """Copy bundled JS assets for local usage (no network required)."""
+     logger.info("Copying JS assets for local usage...")
+     js_lib_dir = report_dir / "js_lib"
+     js_lib_dir.mkdir(exist_ok=True)
+
+     # Copy bundled Alpine.js and Tailwind.js from static folder
+     static_js_dir = Path(__file__).parent / "static" / "js_lib"
+     if static_js_dir.exists():
+         for js_file in static_js_dir.glob("*.js"):
+             dest = js_lib_dir / js_file.name
+             if not dest.exists():
+                 shutil.copy2(js_file, dest)
+                 logger.info(f"Copied {js_file.name}")
+
+     # Copy plotly.min.js from installed plotly Python package
+     plotly_js_src = Path(plotly.__file__).parent / "package_data" / "plotly.min.js"
+     plotly_dest = js_lib_dir / "plotly.min.js"
+     if plotly_js_src.exists() and not plotly_dest.exists():
+         shutil.copy2(plotly_js_src, plotly_dest)
+         logger.info("Copied plotly.min.js from plotly package")
+
+
+ # =============================================================================
+ # JS Export Functions
+ # =============================================================================
+
+ def export_data_as_js_modules(data_dir: Path):
+     """
+     Convert the CSV data in the report directory into JavaScript modules (.js files)
+     that assign the data to window global variables.
+     This allows loading data via <script> tags locally without CORS issues.
+     """
+     logger.info("Exporting data as JS modules...")
+     js_data_dir = data_dir / "js_data"
+     js_data_dir.mkdir(exist_ok=True)
+
+     meta = {}
+     meta_file = data_dir / "report_meta.json"
+     if meta_file.exists():
+         try:
+             with open(meta_file) as f:
+                 meta = json.load(f)
+         except Exception as e:
+             logger.warning(f"Failed to load report_meta.json: {e}")
+
+     # Use per-sample export for efficient on-demand loading (instead of a monolithic _export_metadata_js)
+     _export_per_sample_spatial_js(data_dir, js_data_dir, meta)
+     _export_cauchy_js(data_dir, js_data_dir)
+     _export_manhattan_js(data_dir, js_data_dir)
+     _export_umap_js(data_dir, js_data_dir, meta)
+     _export_report_meta_js(data_dir, js_data_dir, meta)
+
+     logger.info(f"JS modules exported to {js_data_dir}")
+
+
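The point of assigning data to `window` globals, as the docstring notes, is that `<script src=...>` executes under the file:// protocol, whereas `fetch()` of a local CSV is blocked by browser security rules. A toy end-to-end demonstration of the same pattern (hypothetical file names, not the real report layout):

```python
import json
from pathlib import Path

out = Path("demo_report")
(out / "js_data").mkdir(parents=True, exist_ok=True)

# Bake the data into a JS module instead of a CSV that would need fetch()
records = [{"spot": "s1", "score": 1.3}, {"spot": "s2", "score": 0.7}]
module = f"window.DEMO_DATA = {json.dumps(records, separators=(',', ':'))};"
(out / "js_data" / "demo_data.js").write_text(module, encoding="utf-8")

# Opening this page directly from disk still finds window.DEMO_DATA
(out / "index.html").write_text(
    "<script src='js_data/demo_data.js'></script>\n"
    "<script>console.log(window.DEMO_DATA.length, 'records');</script>\n",
    encoding="utf-8",
)
```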
+ def _export_per_sample_spatial_js(data_dir: Path, js_data_dir: Path, meta: dict):
+     """
+     Export spatial metadata as per-sample JS files for efficient on-demand loading.
+
+     This creates:
+     - sample_index.js: Lightweight index mapping sample names to file info
+     - sample_{name}_spatial.js: Per-sample data with coordinates, traits, annotations
+
+     This approach is much more efficient than loading all samples at once,
+     especially for datasets with millions of spots.
+     """
+     metadata_file = data_dir / "spot_metadata.csv"
+     if not metadata_file.exists():
+         logger.warning("spot_metadata.csv not found, skipping per-sample spatial export")
+         return
+
+     logger.info("Exporting per-sample spatial data as JS modules...")
+     df = pd.read_csv(metadata_file)
+
+     samples = meta.get('samples', [])
+     traits = meta.get('traits', [])
+     annotations = meta.get('annotations', [])
+
+     if not samples:
+         # Fallback to unique sample names from data
+         samples = df['sample_name'].unique().tolist() if 'sample_name' in df.columns else []
+
+     sample_index = {}
+     max_spots_2d = meta.get('downsampling_n_spots_2d', 250000)
+
+     for sample_name in samples:
+         sample_df = df[df['sample_name'] == sample_name]
+
+         if len(sample_df) == 0:
+             logger.warning(f"No data found for sample: {sample_name}")
+             continue
+
+         # Track original count before downsampling
+         n_spots_original = len(sample_df)
+
+         # Downsample if needed for 2D visualization performance
+         if len(sample_df) > max_spots_2d:
+             logger.info(f"Downsampling {sample_name} from {len(sample_df):,} to {max_spots_2d:,} spots for 2D visualization")
+             sample_df = sample_df.sample(n=max_spots_2d, random_state=42)
+
+         # Calculate point size for this sample
+         _, point_size = estimate_plotly_point_size(sample_df[['sx', 'sy']], DEFAULT_PIXEL_WIDTH=560)
+
+         # Build columnar data structure for efficient ScatterGL rendering
+         data_struct = {
+             'spot': sample_df['spot'].tolist(),
+             'sx': sample_df['sx'].tolist(),
+             'sy': sample_df['sy'].tolist(),
+             'point_size': round(point_size, 2),
+         }
+
+         # Add 3D coordinates if present
+         for coord in ['3d_x', '3d_y', '3d_z']:
+             if coord in sample_df.columns:
+                 data_struct[coord] = sample_df[coord].tolist()
+
+         # Add all annotation columns
+         for anno in annotations:
+             if anno in sample_df.columns:
+                 data_struct[anno] = sample_df[anno].tolist()
+
+         # Add all trait columns (round to 1 decimal to reduce file size)
+         for trait in traits:
+             if trait in sample_df.columns:
+                 data_struct[trait] = [
+                     round(v, 1) if pd.notna(v) else None
+                     for v in sample_df[trait]
+                 ]
+
+         # Create safe filename (replace non-alphanumeric chars)
+         safe_name = "".join(c if c.isalnum() else "_" for c in sample_name)
+         var_name = f"GSMAP_SAMPLE_{safe_name}"
+         file_name = f"sample_{safe_name}_spatial.js"
+
+         # Write per-sample JS file
+         js_content = f"window.{var_name} = {json.dumps(data_struct, separators=(',', ':'))};"
+         with open(js_data_dir / file_name, "w", encoding='utf-8') as f:
+             f.write(js_content)
+
+         # Track in sample index
+         sample_index[sample_name] = {
+             'var_name': var_name,
+             'file': file_name,
+             'n_spots': len(sample_df),
+             'n_spots_original': n_spots_original,
+             'point_size': round(point_size, 2),
+         }
+
+         if n_spots_original > len(sample_df):
+             logger.info(f"  Exported {sample_name}: {len(sample_df):,} spots (downsampled from {n_spots_original:,}) -> {file_name}")
+         else:
+             logger.info(f"  Exported {sample_name}: {len(sample_df):,} spots -> {file_name}")
+
+     # Export lightweight sample index (loaded upfront)
+     js_content = f"window.GSMAP_SAMPLE_INDEX = {json.dumps(sample_index, separators=(',', ':'))};"
+     with open(js_data_dir / "sample_index.js", "w", encoding='utf-8') as f:
+         f.write(js_content)
+
+     logger.info(f"Exported {len(sample_index)} per-sample spatial JS files + sample_index.js")
+
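`estimate_plotly_point_size` is defined elsewhere in the report package and is not shown here. A plausible sketch of such a heuristic (hypothetical, not the actual helper): scale the marker so the expected spot spacing maps to roughly one marker at the target pixel width.

```python
import numpy as np
import pandas as pd

def estimate_point_size(coords: pd.DataFrame, pixel_width: int = 560) -> float:
    """Hypothetical stand-in: marker size from expected nearest-neighbor spacing."""
    xy = coords.to_numpy(dtype=float)
    span = max(np.ptp(xy[:, 0]), np.ptp(xy[:, 1])) or 1.0
    spacing = span / max(np.sqrt(len(xy)), 1.0)  # spacing if spots tiled the box uniformly
    px_per_unit = pixel_width / span
    return float(np.clip(spacing * px_per_unit, 0.5, 10.0))
```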
+
+ def _export_cauchy_js(data_dir: Path, js_data_dir: Path):
+     """Export Cauchy results as JS module."""
+     cauchy_file = data_dir / "cauchy_results.csv"
+     if cauchy_file.exists():
+         df = pd.read_csv(cauchy_file)
+         data_json = df.to_json(orient='records')
+         js_content = f"window.GSMAP_CAUCHY = {data_json};"
+         with open(js_data_dir / "cauchy_results.js", "w", encoding='utf-8') as f:
+             f.write(js_content)
+
+
+ def _export_manhattan_js(data_dir: Path, js_data_dir: Path):
+     """Export Manhattan data as JS modules (one per trait)."""
+     manhattan_dir = data_dir / "manhattan_data"
+     if not manhattan_dir.exists():
+         return
+
+     for csv_file in manhattan_dir.glob("*_manhattan.csv"):
+         trait = csv_file.name.replace("_manhattan.csv", "")
+         try:
+             df = pd.read_csv(csv_file)
+             if 'P' in df.columns and 'logp' not in df.columns:
+                 df['logp'] = -np.log10(df['P'] + 1e-300)
+
+             data_struct = {
+                 'x': df['BP_cum'].tolist(),
+                 'y': df['logp'].tolist(),
+                 'gene': df['GENE'].fillna("").tolist(),
+                 'chr': df['CHR'].astype(int).tolist(),
+                 'snp': df['SNP'].tolist() if 'SNP' in df.columns else [],
+                 'is_top': df['is_top_pcc'].astype(int).tolist() if 'is_top_pcc' in df.columns else [],
+                 'bp': df['BP'].astype(int).tolist() if 'BP' in df.columns else [],
+             }
+
+             json_str = json.dumps(data_struct, separators=(',', ':'))
+             safe_trait = "".join(c if c.isalnum() else "_" for c in trait)
+             js_content = f"window.GSMAP_MANHATTAN_{safe_trait} = {json_str};"
+
+             with open(js_data_dir / f"manhattan_{trait}.js", "w", encoding='utf-8') as f:
+                 f.write(js_content)
+
+         except Exception as e:
+             logger.warning(f"Failed to export Manhattan JS for {trait}: {e}")
+
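These exporters all use a columnar layout (one array per field, compact separators) rather than record-oriented JSON: key names are stored once instead of once per row, and the arrays drop straight into Plotly ScatterGL traces. A tiny illustration of the size difference:

```python
import json

records = [{"x": i, "y": round(i * 0.1, 1), "chr": 1} for i in range(1000)]
columnar = {k: [r[k] for r in records] for k in ("x", "y", "chr")}

print(len(json.dumps(records)))                          # repeats every key name
print(len(json.dumps(columnar, separators=(',', ':'))))  # keys stored once
```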
+
+ def _export_report_meta_js(data_dir: Path, js_data_dir: Path, meta: dict):
+     """Export report metadata as JS module."""
+     if meta:
+         js_content = f"window.GSMAP_REPORT_META = {json.dumps(meta, separators=(',', ':'))};"
+         with open(js_data_dir / "report_meta.js", "w", encoding='utf-8') as f:
+             f.write(js_content)
+
+
+ def _export_umap_js(data_dir: Path, js_data_dir: Path, meta: dict):
+     """Export UMAP data as JS module."""
+     umap_file = data_dir / "umap_data.csv"
+     if umap_file.exists():
+         df = pd.read_csv(umap_file)
+
+         # Build efficient columnar structure for ScatterGL
+         data_struct = {
+             'spot': df['spot'].tolist(),
+             'sample_name': df['sample_name'].tolist(),
+             'umap_cell_x': df['umap_cell_x'].tolist(),
+             'umap_cell_y': df['umap_cell_y'].tolist(),
+             'point_size_cell': meta.get('umap_info', {}).get('point_size_cell', 4),
+             'default_opacity': meta.get('umap_info', {}).get('default_opacity', 0.8),
+         }
+
+         # Add niche UMAP if available
+         if 'umap_niche_x' in df.columns:
+             data_struct['umap_niche_x'] = df['umap_niche_x'].tolist()
+             data_struct['umap_niche_y'] = df['umap_niche_y'].tolist()
+             data_struct['point_size_niche'] = meta.get('umap_info', {}).get('point_size_niche', 4)
+             data_struct['has_niche'] = True
+         else:
+             data_struct['has_niche'] = False
+
+         # Add all annotation columns (excluding known columns)
+         known_cols = {'spot', 'sample_name', 'umap_cell_x', 'umap_cell_y', 'umap_niche_x', 'umap_niche_y'}
+         annotation_cols = [c for c in df.columns if c not in known_cols]
+         data_struct['annotation_columns'] = annotation_cols
+
+         from gsMap.report.visualize import _create_color_map
+         P_COLOR = ['#313695', '#4575b4', '#74add1', '#fee090', '#fdae61', '#f46d43', '#d73027', '#a50026']
+         color_maps = {}
+
+         for col in annotation_cols:
+             data_struct[col] = df[col].tolist()
+             # Generate color map for each column
+             if pd.api.types.is_numeric_dtype(df[col]):
+                 color_maps[col] = P_COLOR
+             else:
+                 # Convert to string before sorting to avoid TypeError with mixed float/str (e.g. NaN)
+                 unique_vals = sorted(df[col].astype(str).unique().tolist())
+                 color_maps[col] = _create_color_map(unique_vals, hex=True, rng=42)
+
+         data_struct['color_maps'] = color_maps
+
+         js_content = f"window.GSMAP_UMAP = {json.dumps(data_struct, separators=(',', ':'))};"
+         with open(js_data_dir / "umap_data.js", "w", encoding='utf-8') as f:
+             f.write(js_content)
+         logger.info(f"Exported UMAP data with {len(df)} points")