gsMap 1.67__py3-none-any.whl → 1.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gsMap/diagnosis.py CHANGED
@@ -1,273 +1,273 @@
1
- import logging
2
- import warnings
3
- from pathlib import Path
4
-
5
- import numpy as np
6
- import pandas as pd
7
- import scanpy as sc
8
- from scipy.stats import norm
9
-
10
- from gsMap.config import DiagnosisConfig
11
- from gsMap.utils.manhattan_plot import ManhattanPlot
12
- from gsMap.visualize import draw_scatter, load_st_coord, estimate_point_size_for_plot
13
-
14
-
15
- warnings.filterwarnings("ignore", category=FutureWarning)
16
- logger = logging.getLogger(__name__)
17
-
18
-
19
- def convert_z_to_p(gwas_data):
20
- """Convert Z-scores to P-values."""
21
- gwas_data['P'] = norm.sf(abs(gwas_data['Z'])) * 2
22
- min_p_value = 1e-300
23
- gwas_data['P'] = gwas_data['P'].clip(lower=min_p_value)
24
- return gwas_data
25
-
26
-
27
- def load_ldsc(ldsc_input_file):
28
- """Load LDSC data and calculate logp."""
29
- ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
30
- ldsc['spot'] = ldsc['spot'].astype(str).replace('\.0', '', regex=True)
31
- ldsc.set_index('spot', inplace=True)
32
- ldsc['logp'] = -np.log10(ldsc['p'])
33
- return ldsc
34
-
35
-
36
- def load_gene_diagnostic_info(config:DiagnosisConfig):
37
- """Load or compute gene diagnostic info."""
38
- gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
39
- if gene_diagnostic_info_save_path.exists():
40
- logger.info(f'Loading gene diagnostic information from {gene_diagnostic_info_save_path}...')
41
- return pd.read_csv(gene_diagnostic_info_save_path)
42
- else:
43
- logger.info('Gene diagnostic information not found. Calculating gene diagnostic information...')
44
- return compute_gene_diagnostic_info(config)
45
-
46
-
47
- def compute_gene_diagnostic_info(config: DiagnosisConfig):
48
- """Calculate gene diagnostic info and save it to adata."""
49
- logger.info('Loading ST data and LDSC results...')
50
- # adata = sc.read_h5ad(config.hdf5_with_latent_path, backed='r')
51
- mk_score = pd.read_feather(config.mkscore_feather_path)
52
- mk_score.set_index('HUMAN_GENE_SYM', inplace=True)
53
- mk_score = mk_score.T
54
- trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
55
-
56
- # Align marker scores with trait LDSC results
57
- mk_score = mk_score.loc[trait_ldsc_result.index]
58
- mk_score = mk_score.loc[:, mk_score.sum(axis=0) != 0]
59
-
60
- logger.info('Calculating correlation between gene marker scores and trait logp-values...')
61
- corr = mk_score.corrwith(trait_ldsc_result['logp'])
62
- corr.name = 'PCC'
63
-
64
- grouped_mk_score = mk_score.groupby(adata.obs[config.annotation]).median()
65
- max_annotations = grouped_mk_score.idxmax()
66
-
67
- high_GSS_Gene_annotation_pair = pd.DataFrame({
68
- 'Gene': max_annotations.index,
69
- 'Annotation': max_annotations.values,
70
- 'Median_GSS': grouped_mk_score.max().values
71
- })
72
-
73
- # Filter based on median GSS score
74
- high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair[high_GSS_Gene_annotation_pair['Median_GSS'] >= 1.0]
75
- high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair.merge(corr, left_on='Gene', right_index=True)
76
-
77
- # Prepare the final gene diagnostic info dataframe
78
- gene_diagnostic_info_cols = ['Gene', 'Annotation', 'Median_GSS', 'PCC']
79
- gene_diagnostic_info = high_GSS_Gene_annotation_pair[gene_diagnostic_info_cols].drop_duplicates().dropna(
80
- subset=['Gene'])
81
- gene_diagnostic_info.sort_values('PCC', ascending=False, inplace=True)
82
-
83
- # Save gene diagnostic info to a file
84
- gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
85
- gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
86
- logger.info(f'Gene diagnostic information saved to {gene_diagnostic_info_save_path}.')
87
-
88
- # Save to adata.var with the trait_name prefix
89
- logger.info('Saving gene diagnostic info to adata.var...')
90
- gene_diagnostic_info.set_index('Gene', inplace=True) # Use 'Gene' as the index to align with adata.var
91
- adata.var[f'{config.trait_name}_Annotation'] = gene_diagnostic_info['Annotation']
92
- adata.var[f'{config.trait_name}_Median_GSS'] = gene_diagnostic_info['Median_GSS']
93
- adata.var[f'{config.trait_name}_PCC'] = gene_diagnostic_info['PCC']
94
-
95
- # Save trait_ldsc_result to adata.obs
96
- logger.info(f'Saving trait LDSC results to adata.obs as gsMap_{config.trait_name}_p_value...')
97
- adata.obs[f'gsMap_{config.trait_name}_p_value'] = trait_ldsc_result['p']
98
- adata.write(config.hdf5_with_latent_path, )
99
-
100
- return gene_diagnostic_info.reset_index()
101
-
102
-
103
- def load_gwas_data(config:DiagnosisConfig):
104
- """Load and process GWAS data."""
105
- logger.info('Loading and processing GWAS data...')
106
- gwas_data = pd.read_csv(config.sumstats_file, compression='gzip', sep='\t')
107
- return convert_z_to_p(gwas_data)
108
-
109
-
110
- def load_snp_gene_pairs(config:DiagnosisConfig):
111
- """Load SNP-gene pairs from multiple chromosomes."""
112
- ldscore_save_dir = Path(config.ldscore_save_dir)
113
- return pd.concat([
114
- pd.read_feather(ldscore_save_dir / f'SNP_gene_pair/SNP_gene_pair_chr{chrom}.feather')
115
- for chrom in range(1, 23)
116
- ])
117
-
118
-
119
- def filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER):
120
- """Filter the SNPs based on significance levels."""
121
- pass_suggestive_line_mask = gwas_data_with_gene_annotation_sort['P'] < 1e-5
122
- pass_suggestive_line_number = pass_suggestive_line_mask.sum()
123
-
124
- if pass_suggestive_line_number > SUBSAMPLE_SNP_NUMBER:
125
- snps2plot = gwas_data_with_gene_annotation_sort[pass_suggestive_line_mask].SNP
126
- logger.info(f'To reduce the number of SNPs to plot, only {snps2plot.shape[0]} SNPs with P < 1e-5 are plotted.')
127
- else:
128
- snps2plot = gwas_data_with_gene_annotation_sort.head(SUBSAMPLE_SNP_NUMBER).SNP
129
- logger.info(
130
- f'To reduce the number of SNPs to plot, only {SUBSAMPLE_SNP_NUMBER} SNPs with the smallest P-values are plotted.')
131
-
132
- return snps2plot
133
-
134
-
135
- def generate_manhattan_plot(config: DiagnosisConfig):
136
- """Generate Manhattan plot."""
137
- report_save_dir = config.get_report_dir(config.trait_name)
138
- gwas_data = load_gwas_data(config)
139
- snp_gene_pair = load_snp_gene_pairs(config)
140
- gwas_data_with_gene = snp_gene_pair.merge(gwas_data, on='SNP', how='inner').rename(columns={'gene_name': 'GENE'})
141
- gene_diagnostic_info = load_gene_diagnostic_info(config)
142
- gwas_data_with_gene_annotation = gwas_data_with_gene.merge(gene_diagnostic_info, left_on='GENE', right_on='Gene',
143
- how='left')
144
-
145
- gwas_data_with_gene_annotation = gwas_data_with_gene_annotation[
146
- ~gwas_data_with_gene_annotation['Annotation'].isna()]
147
- gwas_data_with_gene_annotation_sort = gwas_data_with_gene_annotation.sort_values('P')
148
-
149
- snps2plot = filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER=100_000)
150
- gwas_data_to_plot = gwas_data_with_gene_annotation[
151
- gwas_data_with_gene_annotation['SNP'].isin(snps2plot)].reset_index(drop=True)
152
- gwas_data_to_plot['Annotation_text'] = 'PCC: ' + gwas_data_to_plot['PCC'].round(2).astype(
153
- str) + '<br>' + 'Annotation: ' + gwas_data_to_plot['Annotation'].astype(str)
154
-
155
- fig = ManhattanPlot(
156
- dataframe=gwas_data_to_plot,
157
- title='gsMap Diagnosis Manhattan Plot',
158
- point_size=3,
159
- highlight_gene_list=config.selected_genes or gene_diagnostic_info.Gene.iloc[:config.top_corr_genes].tolist(),
160
- suggestiveline_value=-np.log10(1e-5),
161
- annotation='Annotation_text',
162
- )
163
-
164
- save_manhattan_plot_path = config.get_manhattan_html_plot_path(config.trait_name)
165
- fig.write_html(save_manhattan_plot_path)
166
- logger.info(f'Diagnostic Manhattan Plot saved to {save_manhattan_plot_path}.')
167
-
168
-
169
- def generate_GSS_distribution(config: DiagnosisConfig):
170
- """Generate GSS distribution plots."""
171
- # logger.info('Loading ST data...')
172
- # adata = sc.read_h5ad(config.hdf5_with_latent_path)
173
- mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM').T
174
-
175
- plot_genes = config.selected_genes or load_gene_diagnostic_info(config).Gene.iloc[:config.top_corr_genes].tolist()
176
- if config.selected_genes is not None:
177
- logger.info(f'Generating GSS & Expression distribution plot for selected genes: {plot_genes}...')
178
- else:
179
- logger.info(f'Generating GSS & Expression distribution plot for top {config.top_corr_genes} correlated genes...')
180
-
181
- if config.customize_fig:
182
- pixel_width, pixel_height, point_size = config.fig_width, config.fig_height, config.point_size
183
- else:
184
- (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(adata.obsm['spatial'])
185
- sub_fig_save_dir = config.get_GSS_plot_dir(config.trait_name)
186
-
187
- # save plot gene list
188
- config.get_GSS_plot_select_gene_file(config.trait_name).write_text('\n'.join(plot_genes))
189
-
190
- for selected_gene in plot_genes:
191
- expression_series = pd.Series(adata[:, selected_gene].X.toarray().flatten(), index=adata.obs.index,name='Expression')
192
- threshold = np.quantile(expression_series,0.9999)
193
- expression_series[expression_series > threshold] = threshold
194
- generate_and_save_plots(adata, mk_score, expression_series, selected_gene, point_size, pixel_width,
195
- pixel_height, sub_fig_save_dir, config.sample_name, config.annotation)
196
-
197
-
198
- def generate_and_save_plots(adata, mk_score, expression_series, selected_gene, point_size, pixel_width, pixel_height,
199
- sub_fig_save_dir, sample_name, annotation):
200
- """Generate and save the plots."""
201
- select_gene_expression_with_space_coord = load_st_coord(adata, expression_series, annotation)
202
- sub_fig_1 = draw_scatter(select_gene_expression_with_space_coord, title=f'{selected_gene} (Expression)',
203
- annotation='annotation', color_by='Expression', point_size=point_size, width=pixel_width,
204
- height=pixel_height)
205
- save_plot(sub_fig_1, sub_fig_save_dir, sample_name, selected_gene, 'Expression')
206
-
207
- select_gene_GSS_with_space_coord = load_st_coord(adata, mk_score[selected_gene].rename('GSS'), annotation)
208
- sub_fig_2 = draw_scatter(select_gene_GSS_with_space_coord, title=f'{selected_gene} (GSS)', annotation='annotation',
209
- color_by='GSS', point_size=point_size, width=pixel_width, height=pixel_height)
210
- save_plot(sub_fig_2, sub_fig_save_dir, sample_name, selected_gene, 'GSS')
211
-
212
- # combined_fig = make_subplots(rows=1, cols=2,
213
- # subplot_titles=(f'{selected_gene} (Expression)', f'{selected_gene} (GSS)'))
214
- # for trace in sub_fig_1.data:
215
- # combined_fig.add_trace(trace, row=1, col=1)
216
- # for trace in sub_fig_2.data:
217
- # combined_fig.add_trace(trace, row=1, col=2)
218
- #
219
-
220
- def save_plot(sub_fig, sub_fig_save_dir, sample_name, selected_gene, plot_type):
221
- """Save the plot to HTML and PNG."""
222
- save_sub_fig_path = sub_fig_save_dir / f'{sample_name}_{selected_gene}_{plot_type}_Distribution.html'
223
- # sub_fig.write_html(str(save_sub_fig_path))
224
- sub_fig.update_layout(showlegend=False)
225
- sub_fig.write_image(str(save_sub_fig_path).replace('.html', '.png'))
226
-
227
-
228
- def generate_gsMap_plot(config: DiagnosisConfig):
229
- """Generate gsMap plot."""
230
- logger.info('Creating gsMap plot...')
231
-
232
- trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
233
- space_coord_concat = load_st_coord(adata, trait_ldsc_result, annotation=config.annotation)
234
-
235
- if config.customize_fig:
236
- pixel_width, pixel_height, point_size = config.fig_width, config.fig_height, config.point_size
237
- else:
238
- (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(adata.obsm['spatial'])
239
- fig = draw_scatter(space_coord_concat,
240
- title=f'{config.trait_name} (gsMap)',
241
- point_size=point_size,
242
- width=pixel_width,
243
- height=pixel_height,
244
- annotation=config.annotation
245
- )
246
-
247
- output_dir = config.get_gsMap_plot_save_dir(config.trait_name)
248
- output_file_html = config.get_gsMap_html_plot_save_path(config.trait_name)
249
- output_file_png = output_file_html.with_suffix('.png')
250
- output_file_csv = output_file_html.with_suffix('.csv')
251
-
252
- fig.write_html(output_file_html)
253
- fig.write_image(output_file_png)
254
- space_coord_concat.to_csv(output_file_csv)
255
-
256
- logger.info(f'gsMap plot created and saved in {output_dir}.')
257
-
258
-
259
- def run_Diagnosis(config: DiagnosisConfig):
260
- """Main function to run the diagnostic plot generation."""
261
- global adata
262
- adata = sc.read_h5ad(config.hdf5_with_latent_path)
263
- if 'log1p' not in adata.uns.keys() and adata.X.max() > 14:
264
- sc.pp.normalize_total(adata, target_sum=1e4)
265
- sc.pp.log1p(adata)
266
-
267
- if config.plot_type in ['manhattan', 'all']:
268
- generate_manhattan_plot(config)
269
- if config.plot_type in ['GSS', 'all']:
270
- generate_GSS_distribution(config)
271
- if config.plot_type in ['gsMap', 'all']:
272
- generate_gsMap_plot(config)
273
-
1
+ import logging
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scanpy as sc
8
+ from scipy.stats import norm
9
+
10
+ from gsMap.config import DiagnosisConfig
11
+ from gsMap.utils.manhattan_plot import ManhattanPlot
12
+ from gsMap.visualize import draw_scatter, load_st_coord, estimate_point_size_for_plot
13
+
14
+
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def convert_z_to_p(gwas_data):
20
+ """Convert Z-scores to P-values."""
21
+ gwas_data['P'] = norm.sf(abs(gwas_data['Z'])) * 2
22
+ min_p_value = 1e-300
23
+ gwas_data['P'] = gwas_data['P'].clip(lower=min_p_value)
24
+ return gwas_data
25
+
26
+
27
+ def load_ldsc(ldsc_input_file):
28
+ """Load LDSC data and calculate logp."""
29
+ ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
30
+ ldsc['spot'] = ldsc['spot'].astype(str).replace('\.0', '', regex=True)
31
+ ldsc.set_index('spot', inplace=True)
32
+ ldsc['logp'] = -np.log10(ldsc['p'])
33
+ return ldsc
34
+
35
+
36
+ def load_gene_diagnostic_info(config:DiagnosisConfig):
37
+ """Load or compute gene diagnostic info."""
38
+ gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
39
+ if gene_diagnostic_info_save_path.exists():
40
+ logger.info(f'Loading gene diagnostic information from {gene_diagnostic_info_save_path}...')
41
+ return pd.read_csv(gene_diagnostic_info_save_path)
42
+ else:
43
+ logger.info('Gene diagnostic information not found. Calculating gene diagnostic information...')
44
+ return compute_gene_diagnostic_info(config)
45
+
46
+
47
+ def compute_gene_diagnostic_info(config: DiagnosisConfig):
48
+ """Calculate gene diagnostic info and save it to adata."""
49
+ logger.info('Loading ST data and LDSC results...')
50
+ # adata = sc.read_h5ad(config.hdf5_with_latent_path, backed='r')
51
+ mk_score = pd.read_feather(config.mkscore_feather_path)
52
+ mk_score.set_index('HUMAN_GENE_SYM', inplace=True)
53
+ mk_score = mk_score.T
54
+ trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
55
+
56
+ # Align marker scores with trait LDSC results
57
+ mk_score = mk_score.loc[trait_ldsc_result.index]
58
+ mk_score = mk_score.loc[:, mk_score.sum(axis=0) != 0]
59
+
60
+ logger.info('Calculating correlation between gene marker scores and trait logp-values...')
61
+ corr = mk_score.corrwith(trait_ldsc_result['logp'])
62
+ corr.name = 'PCC'
63
+
64
+ grouped_mk_score = mk_score.groupby(adata.obs[config.annotation]).median()
65
+ max_annotations = grouped_mk_score.idxmax()
66
+
67
+ high_GSS_Gene_annotation_pair = pd.DataFrame({
68
+ 'Gene': max_annotations.index,
69
+ 'Annotation': max_annotations.values,
70
+ 'Median_GSS': grouped_mk_score.max().values
71
+ })
72
+
73
+ # Filter based on median GSS score
74
+ high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair[high_GSS_Gene_annotation_pair['Median_GSS'] >= 1.0]
75
+ high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair.merge(corr, left_on='Gene', right_index=True)
76
+
77
+ # Prepare the final gene diagnostic info dataframe
78
+ gene_diagnostic_info_cols = ['Gene', 'Annotation', 'Median_GSS', 'PCC']
79
+ gene_diagnostic_info = high_GSS_Gene_annotation_pair[gene_diagnostic_info_cols].drop_duplicates().dropna(
80
+ subset=['Gene'])
81
+ gene_diagnostic_info.sort_values('PCC', ascending=False, inplace=True)
82
+
83
+ # Save gene diagnostic info to a file
84
+ gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
85
+ gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
86
+ logger.info(f'Gene diagnostic information saved to {gene_diagnostic_info_save_path}.')
87
+
88
+ # Save to adata.var with the trait_name prefix
89
+ logger.info('Saving gene diagnostic info to adata.var...')
90
+ gene_diagnostic_info.set_index('Gene', inplace=True) # Use 'Gene' as the index to align with adata.var
91
+ adata.var[f'{config.trait_name}_Annotation'] = gene_diagnostic_info['Annotation']
92
+ adata.var[f'{config.trait_name}_Median_GSS'] = gene_diagnostic_info['Median_GSS']
93
+ adata.var[f'{config.trait_name}_PCC'] = gene_diagnostic_info['PCC']
94
+
95
+ # Save trait_ldsc_result to adata.obs
96
+ logger.info(f'Saving trait LDSC results to adata.obs as gsMap_{config.trait_name}_p_value...')
97
+ adata.obs[f'gsMap_{config.trait_name}_p_value'] = trait_ldsc_result['p']
98
+ adata.write(config.hdf5_with_latent_path, )
99
+
100
+ return gene_diagnostic_info.reset_index()
101
+
102
+
103
+ def load_gwas_data(config:DiagnosisConfig):
104
+ """Load and process GWAS data."""
105
+ logger.info('Loading and processing GWAS data...')
106
+ gwas_data = pd.read_csv(config.sumstats_file, compression='gzip', sep='\t')
107
+ return convert_z_to_p(gwas_data)
108
+
109
+
110
+ def load_snp_gene_pairs(config:DiagnosisConfig):
111
+ """Load SNP-gene pairs from multiple chromosomes."""
112
+ ldscore_save_dir = Path(config.ldscore_save_dir)
113
+ return pd.concat([
114
+ pd.read_feather(ldscore_save_dir / f'SNP_gene_pair/SNP_gene_pair_chr{chrom}.feather')
115
+ for chrom in range(1, 23)
116
+ ])
117
+
118
+
119
+ def filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER):
120
+ """Filter the SNPs based on significance levels."""
121
+ pass_suggestive_line_mask = gwas_data_with_gene_annotation_sort['P'] < 1e-5
122
+ pass_suggestive_line_number = pass_suggestive_line_mask.sum()
123
+
124
+ if pass_suggestive_line_number > SUBSAMPLE_SNP_NUMBER:
125
+ snps2plot = gwas_data_with_gene_annotation_sort[pass_suggestive_line_mask].SNP
126
+ logger.info(f'To reduce the number of SNPs to plot, only {snps2plot.shape[0]} SNPs with P < 1e-5 are plotted.')
127
+ else:
128
+ snps2plot = gwas_data_with_gene_annotation_sort.head(SUBSAMPLE_SNP_NUMBER).SNP
129
+ logger.info(
130
+ f'To reduce the number of SNPs to plot, only {SUBSAMPLE_SNP_NUMBER} SNPs with the smallest P-values are plotted.')
131
+
132
+ return snps2plot
133
+
134
+
135
+ def generate_manhattan_plot(config: DiagnosisConfig):
136
+ """Generate Manhattan plot."""
137
+ report_save_dir = config.get_report_dir(config.trait_name)
138
+ gwas_data = load_gwas_data(config)
139
+ snp_gene_pair = load_snp_gene_pairs(config)
140
+ gwas_data_with_gene = snp_gene_pair.merge(gwas_data, on='SNP', how='inner').rename(columns={'gene_name': 'GENE'})
141
+ gene_diagnostic_info = load_gene_diagnostic_info(config)
142
+ gwas_data_with_gene_annotation = gwas_data_with_gene.merge(gene_diagnostic_info, left_on='GENE', right_on='Gene',
143
+ how='left')
144
+
145
+ gwas_data_with_gene_annotation = gwas_data_with_gene_annotation[
146
+ ~gwas_data_with_gene_annotation['Annotation'].isna()]
147
+ gwas_data_with_gene_annotation_sort = gwas_data_with_gene_annotation.sort_values('P')
148
+
149
+ snps2plot = filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER=100_000)
150
+ gwas_data_to_plot = gwas_data_with_gene_annotation[
151
+ gwas_data_with_gene_annotation['SNP'].isin(snps2plot)].reset_index(drop=True)
152
+ gwas_data_to_plot['Annotation_text'] = 'PCC: ' + gwas_data_to_plot['PCC'].round(2).astype(
153
+ str) + '<br>' + 'Annotation: ' + gwas_data_to_plot['Annotation'].astype(str)
154
+
155
+ fig = ManhattanPlot(
156
+ dataframe=gwas_data_to_plot,
157
+ title='gsMap Diagnosis Manhattan Plot',
158
+ point_size=3,
159
+ highlight_gene_list=config.selected_genes or gene_diagnostic_info.Gene.iloc[:config.top_corr_genes].tolist(),
160
+ suggestiveline_value=-np.log10(1e-5),
161
+ annotation='Annotation_text',
162
+ )
163
+
164
+ save_manhattan_plot_path = config.get_manhattan_html_plot_path(config.trait_name)
165
+ fig.write_html(save_manhattan_plot_path)
166
+ logger.info(f'Diagnostic Manhattan Plot saved to {save_manhattan_plot_path}.')
167
+
168
+
169
+ def generate_GSS_distribution(config: DiagnosisConfig):
170
+ """Generate GSS distribution plots."""
171
+ # logger.info('Loading ST data...')
172
+ # adata = sc.read_h5ad(config.hdf5_with_latent_path)
173
+ mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM').T
174
+
175
+ plot_genes = config.selected_genes or load_gene_diagnostic_info(config).Gene.iloc[:config.top_corr_genes].tolist()
176
+ if config.selected_genes is not None:
177
+ logger.info(f'Generating GSS & Expression distribution plot for selected genes: {plot_genes}...')
178
+ else:
179
+ logger.info(f'Generating GSS & Expression distribution plot for top {config.top_corr_genes} correlated genes...')
180
+
181
+ if config.customize_fig:
182
+ pixel_width, pixel_height, point_size = config.fig_width, config.fig_height, config.point_size
183
+ else:
184
+ (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(adata.obsm['spatial'])
185
+ sub_fig_save_dir = config.get_GSS_plot_dir(config.trait_name)
186
+
187
+ # save plot gene list
188
+ config.get_GSS_plot_select_gene_file(config.trait_name).write_text('\n'.join(plot_genes))
189
+
190
+ for selected_gene in plot_genes:
191
+ expression_series = pd.Series(adata[:, selected_gene].X.toarray().flatten(), index=adata.obs.index,name='Expression')
192
+ threshold = np.quantile(expression_series,0.9999)
193
+ expression_series[expression_series > threshold] = threshold
194
+ generate_and_save_plots(adata, mk_score, expression_series, selected_gene, point_size, pixel_width,
195
+ pixel_height, sub_fig_save_dir, config.sample_name, config.annotation)
196
+
197
+
198
+ def generate_and_save_plots(adata, mk_score, expression_series, selected_gene, point_size, pixel_width, pixel_height,
199
+ sub_fig_save_dir, sample_name, annotation):
200
+ """Generate and save the plots."""
201
+ select_gene_expression_with_space_coord = load_st_coord(adata, expression_series, annotation)
202
+ sub_fig_1 = draw_scatter(select_gene_expression_with_space_coord, title=f'{selected_gene} (Expression)',
203
+ annotation='annotation', color_by='Expression', point_size=point_size, width=pixel_width,
204
+ height=pixel_height)
205
+ save_plot(sub_fig_1, sub_fig_save_dir, sample_name, selected_gene, 'Expression')
206
+
207
+ select_gene_GSS_with_space_coord = load_st_coord(adata, mk_score[selected_gene].rename('GSS'), annotation)
208
+ sub_fig_2 = draw_scatter(select_gene_GSS_with_space_coord, title=f'{selected_gene} (GSS)', annotation='annotation',
209
+ color_by='GSS', point_size=point_size, width=pixel_width, height=pixel_height)
210
+ save_plot(sub_fig_2, sub_fig_save_dir, sample_name, selected_gene, 'GSS')
211
+
212
+ # combined_fig = make_subplots(rows=1, cols=2,
213
+ # subplot_titles=(f'{selected_gene} (Expression)', f'{selected_gene} (GSS)'))
214
+ # for trace in sub_fig_1.data:
215
+ # combined_fig.add_trace(trace, row=1, col=1)
216
+ # for trace in sub_fig_2.data:
217
+ # combined_fig.add_trace(trace, row=1, col=2)
218
+ #
219
+
220
+ def save_plot(sub_fig, sub_fig_save_dir, sample_name, selected_gene, plot_type):
221
+ """Save the plot to HTML and PNG."""
222
+ save_sub_fig_path = sub_fig_save_dir / f'{sample_name}_{selected_gene}_{plot_type}_Distribution.html'
223
+ # sub_fig.write_html(str(save_sub_fig_path))
224
+ sub_fig.update_layout(showlegend=False)
225
+ sub_fig.write_image(str(save_sub_fig_path).replace('.html', '.png'))
226
+
227
+
228
+ def generate_gsMap_plot(config: DiagnosisConfig):
229
+ """Generate gsMap plot."""
230
+ logger.info('Creating gsMap plot...')
231
+
232
+ trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
233
+ space_coord_concat = load_st_coord(adata, trait_ldsc_result, annotation=config.annotation)
234
+
235
+ if config.customize_fig:
236
+ pixel_width, pixel_height, point_size = config.fig_width, config.fig_height, config.point_size
237
+ else:
238
+ (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(adata.obsm['spatial'])
239
+ fig = draw_scatter(space_coord_concat,
240
+ title=f'{config.trait_name} (gsMap)',
241
+ point_size=point_size,
242
+ width=pixel_width,
243
+ height=pixel_height,
244
+ annotation=config.annotation
245
+ )
246
+
247
+ output_dir = config.get_gsMap_plot_save_dir(config.trait_name)
248
+ output_file_html = config.get_gsMap_html_plot_save_path(config.trait_name)
249
+ output_file_png = output_file_html.with_suffix('.png')
250
+ output_file_csv = output_file_html.with_suffix('.csv')
251
+
252
+ fig.write_html(output_file_html)
253
+ fig.write_image(output_file_png)
254
+ space_coord_concat.to_csv(output_file_csv)
255
+
256
+ logger.info(f'gsMap plot created and saved in {output_dir}.')
257
+
258
+
259
+ def run_Diagnosis(config: DiagnosisConfig):
260
+ """Main function to run the diagnostic plot generation."""
261
+ global adata
262
+ adata = sc.read_h5ad(config.hdf5_with_latent_path)
263
+ if 'log1p' not in adata.uns.keys() and adata.X.max() > 14:
264
+ sc.pp.normalize_total(adata, target_sum=1e4)
265
+ sc.pp.log1p(adata)
266
+
267
+ if config.plot_type in ['manhattan', 'all']:
268
+ generate_manhattan_plot(config)
269
+ if config.plot_type in ['GSS', 'all']:
270
+ generate_GSS_distribution(config)
271
+ if config.plot_type in ['gsMap', 'all']:
272
+ generate_gsMap_plot(config)
273
+