gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,134 @@
1
+
2
+ import logging
3
+ import time
4
+
5
+ from gsMap.cauchy_combination_test import run_Cauchy_combination
6
+ from gsMap.config import QuickModeConfig
7
+ from gsMap.config.cauchy_config import check_cauchy_done
8
+ from gsMap.config.find_latent_config import check_find_latent_done
9
+ from gsMap.config.latent2gene_config import check_latent2gene_done
10
+ from gsMap.config.quick_mode_config import check_report_done
11
+ from gsMap.config.spatial_ldsc_config import check_spatial_ldsc_done
12
+ from gsMap.find_latent import run_find_latent_representation
13
+ from gsMap.latent2gene import run_latent_to_gene
14
+ from gsMap.report import run_report
15
+ from gsMap.spatial_ldsc.spatial_ldsc_jax import run_spatial_ldsc_jax
16
+
17
+ logger = logging.getLogger("gsMap.pipeline")
18
+
19
def format_duration(seconds):
    """Format an elapsed duration in seconds as ``'Xh Ym'`` (>= 1 hour) or ``'Xm Ys'``."""
    hrs = int(seconds // 3600)
    mins = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    # Seconds are only shown for sub-hour durations, matching the log style.
    return f"{hrs}h {mins}m" if hrs > 0 else f"{mins}m {secs}s"
26
+
27
+
28
def run_quick_mode(config: QuickModeConfig):
    """
    Run the Quick Mode pipeline.

    Executes up to five sequential steps (find_latent, latent2gene,
    spatial_ldsc, cauchy, report), bounded by ``config.start_step`` and
    ``config.stop_step``. Steps whose outputs already exist (as reported by
    the per-step ``check_*_done`` helpers) are skipped.

    Raises:
        ValueError: if ``start_step``/``stop_step`` is not a known step name,
            or if ``start_step`` comes after ``stop_step``.
    """
    logger.info("Starting Quick Mode pipeline")
    pipeline_start_time = time.time()

    steps = ["find_latent", "latent2gene", "spatial_ldsc", "cauchy", "report"]

    def _step_index(label: str, step_name: str) -> int:
        """Resolve a step name to its pipeline position with a clear error on typos."""
        try:
            return steps.index(step_name)
        except ValueError:
            # `from None`: the underlying list.index failure adds no context.
            raise ValueError(f"Invalid {label}: {step_name}. Must be one of {steps}") from None

    start_idx = _step_index("start_step", config.start_step)
    stop_idx = _step_index("stop_step", config.stop_step) if config.stop_step else len(steps) - 1

    if start_idx > stop_idx:
        raise ValueError(
            f"start_step ({config.start_step}) must be before or equal to stop_step ({config.stop_step})"
        )

    # Step 1: Find Latent Representations
    if start_idx <= 0 <= stop_idx:
        logger.info("=== Step 1: Find Latent Representations ===")
        start_time = time.time()

        if check_find_latent_done(config):
            logger.info(
                f"Find latent representations already done (verified via {config.find_latent_metadata_path}). Skipping..."
            )
        else:
            run_find_latent_representation(config.find_latent_config)

        logger.info(f"Step 1 completed in {format_duration(time.time() - start_time)}")

    # Step 2: Latent to Gene
    if start_idx <= 1 <= stop_idx:
        logger.info("=== Step 2: Latent to Gene Mapping ===")
        start_time = time.time()

        if check_latent2gene_done(config):
            logger.info("Latent to gene mapping already done. Skipping...")
        else:
            run_latent_to_gene(config.latent2gene_config)

        logger.info(f"Step 2 completed in {format_duration(time.time() - start_time)}")

    # Steps 3-5 all consume GWAS summary statistics; warn once up front if
    # none were supplied and any of those steps is in range.
    if not config.sumstats_config_dict and start_idx <= 4 and stop_idx >= 2:
        logger.warning(
            "No summary statistics provided. Steps requiring GWAS data (Spatial LDSC, Cauchy, Report) may fail or do nothing if relying on them."
        )

    traits_to_process = config.sumstats_config_dict

    # Step 3: Spatial LDSC
    if start_idx <= 2 <= stop_idx:
        logger.info("=== Step 3: Spatial LDSC ===")
        start_time = time.time()

        # Only rerun traits whose LDSC results are missing.
        traits_remaining = {}
        for trait_name, sumstats_path in traits_to_process.items():
            if check_spatial_ldsc_done(config, trait_name):
                logger.info(f"Spatial LDSC result already exists for {trait_name}. Skipping...")
            else:
                traits_remaining[trait_name] = sumstats_path

        if not traits_remaining:
            logger.info("All traits have been processed for Spatial LDSC. Skipping step...")
        else:
            sldsc_config = config.spatial_ldsc_config
            # Update config to run only remaining traits
            sldsc_config.sumstats_config_dict = traits_remaining
            run_spatial_ldsc_jax(sldsc_config)

        logger.info(f"Step 3 completed in {format_duration(time.time() - start_time)}")

    # Step 4: Cauchy Combination
    if start_idx <= 3 <= stop_idx:
        logger.info("=== Step 4: Cauchy Combination ===")
        start_time = time.time()

        # Note: with no traits, all() is vacuously True and the step is skipped.
        if all(check_cauchy_done(config, trait_name) for trait_name in traits_to_process):
            logger.info("Cauchy combination already done for all traits. Skipping...")
        else:
            run_Cauchy_combination(config.cauchy_config)

        logger.info(f"Step 4 completed in {format_duration(time.time() - start_time)}")

    # Step 5: Report
    if start_idx <= 4 <= stop_idx:
        logger.info("=== Step 5: Generate Report ===")
        start_time = time.time()

        if check_report_done(config, verbose=True):
            logger.info("Report already exists. Skipping...")
        else:
            run_report(config)

        logger.info(f"Step 5 completed in {format_duration(time.time() - start_time)}")

    logger.info(
        f"Pipeline completed successfully in {format_duration(time.time() - pipeline_start_time)}"
    )
@@ -0,0 +1,2 @@
1
+ from .report import run_report
2
+ from .report_data import ReportDataManager
@@ -0,0 +1,375 @@
1
+ import logging
2
+ import multiprocessing
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+
7
+ import anndata as ad
8
+ import numpy as np
9
+ import pandas as pd
10
+ from scipy.stats import norm
11
+
12
+ from gsMap.config import DiagnosisConfig
13
+ from gsMap.utils.manhattan_plot import ManhattanPlot
14
+ from gsMap.utils.regression_read import _read_chr_files
15
+
16
+ from .visualize import draw_scatter, estimate_plotly_point_size, load_ldsc, load_st_coord
17
+
18
+ warnings.filterwarnings("ignore", category=FutureWarning)
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def convert_z_to_p(gwas_data):
    """Add a two-sided P-value column derived from the Z column (mutates and returns *gwas_data*)."""
    two_sided_p = 2 * norm.sf(gwas_data["Z"].abs())
    gwas_data["P"] = two_sided_p
    # Floor at 1e-300 so downstream -log10 transforms never see zero.
    gwas_data["P"] = gwas_data["P"].clip(lower=1e-300)
    return gwas_data
28
+
29
+
30
def load_gene_diagnostic_info(config: DiagnosisConfig, adata: ad.AnnData | None = None):
    """Return gene diagnostic info, reading the cached CSV if one exists.

    Falls back to computing it (and writing the cache) when the file is absent.
    """
    cache_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
    if not cache_path.exists():
        logger.info(
            "Gene diagnostic information not found. Calculating gene diagnostic information..."
        )
        return compute_gene_diagnostic_info(config, adata=adata)
    logger.info(
        f"Loading gene diagnostic information from {cache_path}..."
    )
    return pd.read_csv(cache_path)
43
+
44
+
45
def compute_gene_diagnostic_info(config: DiagnosisConfig, adata: ad.AnnData | None = None):
    """Compute per-gene diagnostic info (top annotation, median GSS, PCC with trait).

    Builds a table with one row per gene: the annotation where its median
    marker score (GSS) is highest, that median value, and the Pearson
    correlation (PCC) between the gene's marker scores and the trait's
    LDSC -log10 p-values across spots. The table is written to the
    per-trait diagnostic CSV and also returned (sorted by PCC descending).
    """
    logger.info("Loading ST data and LDSC results...")

    # Fall back to reading the ST data from disk when the caller did not
    # already have it loaded.
    if adata is None:
        adata = ad.read_h5ad(config.hdf5_with_latent_path)

    # Marker scores are stored genes x spots; transpose to spots x genes.
    mk_score = pd.read_feather(config.mkscore_feather_path)
    mk_score.set_index("HUMAN_GENE_SYM", inplace=True)
    mk_score = mk_score.T
    trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))

    # Align marker scores with trait LDSC results
    mk_score = mk_score.loc[trait_ldsc_result.index]

    # Filter out genes with no variation
    # (a gene is kept if any spot differs from the first spot's value).
    has_variation = (~mk_score.eq(mk_score.iloc[0], axis=1)).any()
    mk_score = mk_score.loc[:, has_variation]

    logger.info("Calculating correlation between gene marker scores and trait logp-values...")
    # NOTE(review): assumes load_ldsc() provides a "logp" column — confirm.
    corr = mk_score.corrwith(trait_ldsc_result["logp"])
    corr.name = "PCC"

    # For each gene, find the annotation group with the highest median GSS.
    grouped_mk_score = mk_score.groupby(adata.obs[config.annotation]).median()
    max_annotations = grouped_mk_score.idxmax()

    high_GSS_Gene_annotation_pair = pd.DataFrame(
        {
            "Gene": max_annotations.index,
            "Annotation": max_annotations.values,
            # max() and idxmax() walk the same columns, so values align by gene.
            "Median_GSS": grouped_mk_score.max().values,
        }
    )

    # Attach the per-gene PCC (corr is indexed by gene symbol).
    high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair.merge(
        corr, left_on="Gene", right_index=True
    )

    # Prepare the final gene diagnostic info dataframe
    gene_diagnostic_info_cols = ["Gene", "Annotation", "Median_GSS", "PCC"]
    gene_diagnostic_info = (
        high_GSS_Gene_annotation_pair[gene_diagnostic_info_cols]
        .drop_duplicates()
        .dropna(subset=["Gene"])
    )
    gene_diagnostic_info.sort_values("PCC", ascending=False, inplace=True)

    # Save gene diagnostic info to a file
    gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
    gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
    logger.info(f"Gene diagnostic information saved to {gene_diagnostic_info_save_path}.")

    return gene_diagnostic_info.reset_index()
98
+
99
+
100
def load_gwas_data(sumstats_file):
    """Read a gzip-compressed, tab-separated GWAS sumstats file and attach P-values."""
    logger.info("Loading and processing GWAS data...")
    raw_sumstats = pd.read_csv(sumstats_file, sep="\t", compression="gzip")
    return convert_z_to_p(raw_sumstats)
105
+
106
+
107
def load_snp_gene_pairs(config: DiagnosisConfig):
    """Concatenate the per-chromosome SNP-gene pair tables into one DataFrame."""
    prefix = Path(config.ldscore_save_dir) / "SNP_gene_pair/SNP_gene_pair_chr"
    chromosome_files = _read_chr_files(prefix.as_posix(), suffix=".feather")
    frames = [pd.read_feather(chr_file) for chr_file in chromosome_files]
    return pd.concat(frames)
117
+
118
+
119
def filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER):
    """Select the subset of SNPs to plot.

    If more than ``SUBSAMPLE_SNP_NUMBER`` SNPs pass the suggestive line
    (P < 1e-5), exactly those are kept; otherwise the
    ``SUBSAMPLE_SNP_NUMBER`` SNPs with the smallest P-values are kept.

    Args:
        gwas_data_with_gene_annotation_sort: GWAS dataframe sorted ascending
            by "P", with at least "P" and "SNP" columns.
        SUBSAMPLE_SNP_NUMBER: Maximum number of SNPs to keep.

    Returns:
        pandas.Series of SNP identifiers to plot.
    """
    pass_suggestive_line_mask = gwas_data_with_gene_annotation_sort["P"] < 1e-5
    pass_suggestive_line_number = pass_suggestive_line_mask.sum()

    if pass_suggestive_line_number > SUBSAMPLE_SNP_NUMBER:
        snps2plot = gwas_data_with_gene_annotation_sort[pass_suggestive_line_mask].SNP
        logger.info(
            f"To reduce the number of SNPs to plot, only {snps2plot.shape[0]} SNPs with P < 1e-5 are plotted."
        )
    else:
        snps2plot = gwas_data_with_gene_annotation_sort.head(SUBSAMPLE_SNP_NUMBER).SNP
        # Log the actual number kept: head() returns fewer rows than requested
        # when the input is small, so reporting SUBSAMPLE_SNP_NUMBER here
        # would be misleading.
        logger.info(
            f"To reduce the number of SNPs to plot, only {len(snps2plot)} SNPs with the smallest P-values are plotted."
        )

    return snps2plot
136
+
137
+
138
def generate_manhattan_plot(config: DiagnosisConfig, adata: ad.AnnData | None = None):
    """Generate and save an annotated Manhattan plot (HTML) for the configured trait.

    Joins GWAS summary statistics with SNP-gene pairs and per-gene
    diagnostic info, subsamples SNPs for plotting, and writes the plot to
    the trait's HTML path.

    Raises:
        ValueError: if no annotated SNPs remain after merging/filtering.
    """
    # report_save_dir = config.get_report_dir(config.trait_name)
    gwas_data = load_gwas_data(config.sumstats_file)
    snp_gene_pair = load_snp_gene_pairs(config)
    # Inner join keeps only SNPs present in both tables.
    # NOTE(review): assumes the SNP-gene pair table has a "gene_name" column.
    gwas_data_with_gene = snp_gene_pair.merge(gwas_data, on="SNP", how="inner").rename(
        columns={"gene_name": "GENE"}
    )
    gene_diagnostic_info = load_gene_diagnostic_info(config, adata=adata)
    gwas_data_with_gene_annotation = gwas_data_with_gene.merge(
        gene_diagnostic_info, left_on="GENE", right_on="Gene", how="left"
    )

    # Drop SNPs whose gene got no annotation from the left join.
    gwas_data_with_gene_annotation = gwas_data_with_gene_annotation[
        ~gwas_data_with_gene_annotation["Annotation"].isna()
    ]
    gwas_data_with_gene_annotation_sort = gwas_data_with_gene_annotation.sort_values("P")

    snps2plot = filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER=100_000)
    gwas_data_to_plot = gwas_data_with_gene_annotation[
        gwas_data_with_gene_annotation["SNP"].isin(snps2plot)
    ].reset_index(drop=True)
    # Hover text combining PCC and annotation (HTML line break for plotly).
    gwas_data_to_plot["Annotation_text"] = (
        "PCC: "
        + gwas_data_to_plot["PCC"].round(2).astype(str)
        + "<br>"
        + "Annotation: "
        + gwas_data_to_plot["Annotation"].astype(str)
    )

    # Verify data integrity
    if gwas_data_with_gene_annotation_sort.empty:
        raise ValueError("Filtered GWAS data is empty, cannot create Manhattan plot")

    if len(gwas_data_to_plot) == 0:
        raise ValueError("No SNPs passed filtering criteria for Manhattan plot")

    # Log some diagnostic information
    logger.info(f"Creating Manhattan plot with {len(gwas_data_to_plot)} SNPs")
    logger.info(f"Chromosome column values: {gwas_data_to_plot['CHR'].unique()}")

    fig = ManhattanPlot(
        dataframe=gwas_data_to_plot,
        title="gsMap Diagnosis Manhattan Plot",
        point_size=3,
        # Highlight user-selected genes, or the top-PCC genes by default.
        highlight_gene_list=config.selected_genes
        or gene_diagnostic_info.Gene.iloc[: config.top_corr_genes].tolist(),
        suggestiveline_value=-np.log10(1e-5),
        annotation="Annotation_text",
    )

    save_manhattan_plot_path = config.get_manhattan_html_plot_path(config.trait_name)
    fig.write_html(save_manhattan_plot_path)
    logger.info(f"Diagnostic Manhattan Plot saved to {save_manhattan_plot_path}.")
192
+
193
+
194
def generate_GSS_distribution(config: DiagnosisConfig, adata: ad.AnnData):
    """Generate per-gene GSS and expression spatial distribution plots.

    For each selected gene (``config.selected_genes`` if given, otherwise the
    top ``config.top_corr_genes`` genes by PCC), an expression scatter and a
    GSS scatter are rendered in parallel worker processes and saved as PNGs.
    """
    # Marker scores are loaded once here; only the single needed column is
    # shipped to each worker process to limit pickling cost.
    mk_score = pd.read_feather(config.mkscore_feather_path).set_index("HUMAN_GENE_SYM").T

    plot_genes = (
        config.selected_genes
        or load_gene_diagnostic_info(config, adata=adata).Gene.iloc[: config.top_corr_genes].tolist()
    )
    if config.selected_genes is not None:
        logger.info(
            f"Generating GSS & Expression distribution plot for selected genes: {plot_genes}..."
        )
    else:
        logger.info(
            f"Generating GSS & Expression distribution plot for top {config.top_corr_genes} correlated genes..."
        )

    # Figure geometry: user-provided, or estimated from spot density.
    if config.customize_fig:
        pixel_width, pixel_height, point_size = (
            config.fig_width,
            config.fig_height,
            config.point_size,
        )
    else:
        (pixel_width, pixel_height), point_size = estimate_plotly_point_size(
            adata.obsm["spatial"]
        )
    sub_fig_save_dir = config.get_GSS_plot_dir(config.trait_name)

    # save plot gene list
    config.get_GSS_plot_select_gene_file(config.trait_name).write_text("\n".join(plot_genes))

    parallelized_params = []
    for selected_gene in plot_genes:
        expression_series = pd.Series(
            adata[:, selected_gene].X.toarray().flatten(), index=adata.obs.index, name="Expression"
        )
        # Cap at the 99.99th percentile so a few extreme spots do not
        # dominate the color scale.
        threshold = np.quantile(expression_series, 0.9999)
        expression_series[expression_series > threshold] = threshold

        parallelized_params.append(
            (
                adata,
                mk_score[[selected_gene]],  # Pass only needed gene to save memory
                expression_series,
                selected_gene,
                point_size,
                pixel_width,
                pixel_height,
                sub_fig_save_dir,
                config.project_name,
                config.annotation,
                config.plot_origin,
            )
        )

    # Guard against os.cpu_count() returning None (TypeError) or 1
    # (Pool(0) raises ValueError): always use at least one worker.
    n_workers = max(1, (os.cpu_count() or 1) // 2)
    with multiprocessing.Pool(n_workers) as pool:
        # starmap blocks until all plots are written; the context manager
        # handles pool shutdown, so explicit close()/join() is unnecessary.
        pool.starmap(generate_and_save_plots, parallelized_params)
257
+
258
+
259
def generate_and_save_plots(
    adata,
    mk_score,
    expression_series,
    selected_gene,
    point_size,
    pixel_width,
    pixel_height,
    sub_fig_save_dir,
    sample_name,
    annotation,
    plot_origin: str = "upper",
):
    """Render and save the expression and GSS scatter plots for one gene."""
    gss_series = mk_score[selected_gene].rename("GSS")
    # Expression first, then GSS — each follows the same render/save path.
    for values, plot_type in ((expression_series, "Expression"), (gss_series, "GSS")):
        coords = load_st_coord(adata, values, annotation)
        fig = draw_scatter(
            coords,
            title=f"{selected_gene} ({plot_type})",
            annotation="annotation",
            color_by=plot_type,
            point_size=point_size,
            width=pixel_width,
            height=pixel_height,
            plot_origin=plot_origin,
        )
        save_plot(fig, sub_fig_save_dir, sample_name, selected_gene, plot_type)
300
+
301
+
302
def save_plot(sub_fig, sub_fig_save_dir, sample_name, selected_gene, plot_type):
    """Save a figure as a PNG named ``{sample}_{gene}_{type}_Distribution.png``.

    Args:
        sub_fig: Plotly figure to save.
        sub_fig_save_dir: Directory (``Path``) to write into.
        sample_name: Sample/project name used in the file name.
        selected_gene: Gene symbol used in the file name.
        plot_type: "Expression" or "GSS"; used in the file name.

    Raises:
        RuntimeError: if the image file does not exist after writing.
    """
    save_sub_fig_path = (
        sub_fig_save_dir / f"{sample_name}_{selected_gene}_{plot_type}_Distribution.png"
    )
    sub_fig.update_layout(showlegend=False)
    sub_fig.write_image(save_sub_fig_path)
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently hide a failed image export.
    if not save_sub_fig_path.exists():
        raise RuntimeError(f"Failed to save {plot_type} plot for {selected_gene}.")
311
+
312
+
313
def generate_gsMap_plot(config: DiagnosisConfig, adata: ad.AnnData):
    """Render and save the spatial gsMap result plot for the configured trait."""
    logger.info("Creating gsMap plot...")

    ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
    space_coord_concat = load_st_coord(adata, ldsc_result, annotation=config.annotation)

    # Figure geometry: user-provided, or estimated from spot density.
    if config.customize_fig:
        pixel_width = config.fig_width
        pixel_height = config.fig_height
        point_size = config.point_size
    else:
        (pixel_width, pixel_height), point_size = estimate_plotly_point_size(
            adata.obsm["spatial"]
        )

    fig = draw_scatter(
        space_coord_concat,
        title=f"{config.trait_name} (gsMap)",
        point_size=point_size,
        width=pixel_width,
        height=pixel_height,
        annotation=config.annotation,
        plot_origin=config.plot_origin,
    )

    output_dir = config.get_gsMap_plot_save_dir(config.trait_name)
    html_path = config.get_gsMap_html_plot_save_path(config.trait_name)

    # Emit the interactive HTML, a static PNG, and the underlying CSV
    # side by side (same stem, different suffixes).
    fig.write_html(html_path)
    fig.write_image(html_path.with_suffix(".png"))
    space_coord_concat.to_csv(html_path.with_suffix(".csv"))

    logger.info(f"gsMap plot created and saved in {output_dir}.")
350
+
351
+
352
def _normalize_log1p(x_dense):
    """Return log1p of a CPM-style normalization (target sum 1e4 per spot).

    Args:
        x_dense: dense 2-D array of counts, spots x genes.

    Returns:
        numpy array of the same shape with normalized, log1p-transformed values.
    """
    row_sums = x_dense.sum(axis=1)
    row_sums[row_sums == 0] = 1  # Avoid division by zero for empty spots
    x_norm = (x_dense / row_sums.reshape(-1, 1)) * 1e4
    return np.log1p(x_norm)


def run_Diagnosis(config: DiagnosisConfig):
    """Run the diagnostic plot generation (gsMap / Manhattan / GSS plots).

    Which plots are produced is controlled by ``config.plot_type``
    ("gsMap", "manhattan", "GSS", or "all").
    """
    adata = ad.read_h5ad(config.hdf5_with_latent_path)
    if "pcc" not in adata.var.columns:
        # Manual normalization and log1p to avoid scanpy dependency/warnings.
        # Densify first if X is sparse.
        x_dense = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X
        adata.X = _normalize_log1p(x_dense)

    if config.plot_type in ["gsMap", "all"]:
        generate_gsMap_plot(config, adata=adata)
    if config.plot_type in ["manhattan", "all"]:
        generate_manhattan_plot(config, adata=adata)
    if config.plot_type in ["GSS", "all"]:
        generate_GSS_distribution(config, adata=adata)
gsMap/report/report.py ADDED
@@ -0,0 +1,100 @@
1
+ import logging
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ import gsMap
6
+ from gsMap.config import QuickModeConfig
7
+
8
+ from .report_data import ReportDataManager
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
def run_report(config: QuickModeConfig, run_parameters: dict | None = None):
    """
    Main entry point for report generation.
    Prepares data and saves the interactive report as a standalone Alpine+Tailwind HTML folder.

    Args:
        config: Quick-mode configuration describing the project layout.
        run_parameters: Optional run metadata dumped to ``execution_summary.yaml``.

    Output structure:
        project_dir/
        ├── report_data/              # Data files (CSV, h5ad)
        │   ├── spot_metadata.csv
        │   ├── cauchy_results.csv
        │   ├── umap_data.csv
        │   ├── gene_list.csv
        │   ├── gss_stats/
        │   │   └── gene_trait_correlation_{trait}.csv
        │   ├── manhattan_data/
        │   │   └── {trait}_manhattan.csv
        │   └── spatial_3d/
        │       └── spatial_3d.h5ad

        └── gsmap_web_report/         # Web report (self-contained)
            ├── index.html
            ├── report_meta.json
            ├── execution_summary.yaml
            ├── spatial_plots/
            │   └── ldsc_{trait}.png
            ├── gene_diagnostic_plots/
            ├── annotation_plots/
            ├── spatial_3d/
            │   └── *.html
            ├── js_lib/
            └── js_data/
                ├── gss_stats/
                ├── sample_index.js
                ├── sample_{name}_spatial.js
                └── ... (other JS modules)
    """
    logger.info("Running gsMap Report Module (Alpine.js + Tailwind based)")

    # 1. Use ReportDataManager to prepare all data and JS assets
    manager = ReportDataManager(config)
    web_report_dir = manager.run()

    # 2. Save run_parameters for future reference
    if run_parameters:
        import yaml  # local import: only needed when parameters are supplied

        with open(web_report_dir / "execution_summary.yaml", "w") as f:
            yaml.dump(run_parameters, f)

    # 3. Render the Jinja2 template
    template_path = Path(__file__).parent / "static" / "template.html"
    if not template_path.exists():
        logger.error(f"Template file not found at {template_path}")
        return

    try:
        from jinja2 import Template

        with open(template_path, encoding="utf-8") as f:
            template = Template(f.read())

        # Prepare context
        context = {
            "title": f"gsMap Report - {config.project_name}",
            "project_name": config.project_name,
            "generated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "gsmap_version": getattr(gsMap, "__version__", "unknown"),
        }

        rendered_html = template.render(**context)

        report_file = web_report_dir / "index.html"
        with open(report_file, "w", encoding="utf-8") as f:
            f.write(rendered_html)

        from rich import print as rprint

        rprint("\n[bold green]Report generated successfully![/bold green]")
        rprint(f"Web report directory: [cyan]{web_report_dir}[/cyan]")
        rprint(f"Data files directory: [cyan]{config.report_data_dir}[/cyan]\n")

        rprint("[bold]Ways to view the interactive report:[/bold]")
        rprint("1. [bold white]Remote Server:[/bold white] Run the command below to start a temporary web server:")
        rprint(f"   [bold cyan]gsmap report-view {web_report_dir} --port 8080 --no-browser[/bold cyan]")
        rprint(f"\n2. [bold white]Local PC:[/bold white] Copy the [cyan]{web_report_dir.name}[/cyan] folder to your machine and open [cyan]index.html[/cyan].\n")

    except ImportError:
        logger.error("Jinja2 not found. Please install it with 'pip install jinja2'.")
    except Exception:
        # logger.exception records the message *and* the full traceback via
        # the logging system, replacing the previous logger.error +
        # traceback.print_exc() pair (which bypassed log handlers).
        logger.exception("Failed to render report")