gsMap 1.71.2__py3-none-any.whl → 1.72.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,154 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import anndata
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scanpy as sc
8
+ import zarr
9
+ from scipy.stats import rankdata
10
+ from tqdm import tqdm
11
+
12
+ from gsMap.config import CreateSliceMeanConfig
13
+
14
+ # %% Helper functions
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def get_common_genes(h5ad_files, config: CreateSliceMeanConfig):
    """
    Get the sorted list of genes shared by every h5ad file.

    Args:
        h5ad_files: Paths of the h5ad files whose gene names are intersected.
        config: Run configuration. When ``config.species`` is not None, the
            tab-separated ``config.homolog_file`` is read and the result is
            further restricted to the genes listed in its first column.

    Returns:
        Sorted list of gene names present in all files (and, when a species
        is configured, in the first column of the homolog table).

    Raises:
        ValueError: If ``h5ad_files`` is empty, or if the homolog file has
            fewer than two columns.
    """
    common_genes = None
    for file in tqdm(h5ad_files, desc="Finding common genes"):
        adata = sc.read_h5ad(file)
        adata.var_names_make_unique()
        if common_genes is None:
            common_genes = adata.var_names
        else:
            common_genes = common_genes.intersection(adata.var_names)

    # Without this guard an empty input surfaces later as an opaque
    # ``TypeError`` from ``sorted(None)``.
    if common_genes is None:
        raise ValueError("No h5ad files were provided; cannot determine common genes.")

    if config.species is not None:
        homologs = pd.read_csv(config.homolog_file, sep="\t")
        if homologs.shape[1] < 2:
            raise ValueError(
                "Homologs file must have at least two columns: one for the species and one for the human gene symbol."
            )
        # Only the first column (the species gene names) is needed here.
        # Reading it positionally keeps files with extra columns working,
        # which a fixed two-name ``columns`` assignment would reject.
        species_genes = homologs.iloc[:, 0].to_numpy()
        common_genes = np.intersect1d(common_genes, species_genes)

    return sorted(common_genes)
44
+
45
+
46
def calculate_one_slice_mean(
    sample_name, file_path: Path, common_genes, zarr_group_path, data_layer
):
    """
    Compute per-gene rank statistics for one slice and store them in a Zarr group.

    For every spot, the expression values of ``common_genes`` are ranked
    (ties receive the average rank). The geometric mean of those ranks across
    spots is computed per gene as ``exp(mean(log(rank)))``. The number of
    spots with non-zero expression is counted per gene as well.

    The result is written to the group as an array named ``sample_name`` of
    shape ``(n_genes, 2)`` and dtype ``f4``: column 0 holds the geometric mean
    rank, column 1 the non-zero spot count. The number of spots in the slice
    is stored in the array attribute ``spot_number``.

    Args:
        sample_name: Name of the array to create in the Zarr group.
        file_path: Path of the h5ad file to read.
        common_genes: Gene names to keep, in output row order.
        zarr_group_path: Location of the Zarr group (opened in append mode).
        data_layer: Name of a layer in the file, or ``"X"`` to use ``adata.X``.

    Raises:
        ValueError: If ``data_layer`` is neither a layer of the file nor
            ``"X"``, or if the slice contains no spots.
    """
    adata = anndata.read_h5ad(file_path)

    if data_layer in adata.layers:
        adata.X = adata.layers[data_layer]
    elif data_layer != "X":
        raise ValueError(f"Data layer {data_layer} not found in {file_path}")

    adata = adata[:, common_genes].copy()
    n_cells = adata.shape[0]
    if n_cells == 0:
        # An empty slice would store NaN means with spot_number 0, and
        # log(NaN) * 0 would then turn every merged value into NaN.
        raise ValueError(f"No spots found in {file_path}; cannot compute slice mean.")

    # Accumulate the per-gene sum of log ranks instead of materialising a
    # (spots x genes) matrix. float64 keeps the sum accurate for large slices.
    log_rank_sum = np.zeros(adata.n_vars, dtype=np.float64)
    for i in tqdm(range(n_cells), desc=f"Computing log ranks for {sample_name}"):
        row = adata.X[i, :]
        # Sparse rows need densifying; dense arrays have no ``toarray``.
        if hasattr(row, "toarray"):
            row = row.toarray()
        ranks = rankdata(np.asarray(row).ravel(), method="average")
        # Ranks start at 1, so the logarithm is always finite.
        log_rank_sum += np.log(ranks)

    # Geometric mean via the log trick: exp(mean(log(values)))
    gmean = np.exp(log_rank_sum / n_cells).reshape(-1, 1)

    # Number of spots expressing each gene. This is a count, not yet a
    # fraction; it is stored together with ``spot_number``.
    adata_X_bool = adata.X.astype(bool)
    frac = np.asarray(adata_X_bool.sum(axis=0)).flatten().reshape(-1, 1)

    # Open the group only once there is something to write, so a failed run
    # does not leave an empty group behind.
    gmean_zarr_group = zarr.open(zarr_group_path, mode="a")
    gmean_frac = np.concatenate([gmean, frac], axis=1)
    s1_zarr = gmean_zarr_group.array(sample_name, data=gmean_frac, chunks=None, dtype="f4")
    s1_zarr.attrs["spot_number"] = adata.shape[0]
83
+
84
+
85
def merge_zarr_means(zarr_group_path, output_file, common_genes):
    """
    Merge all per-slice arrays of a Zarr group and save the result as Parquet.

    Every array in the group is expected to have one row per gene, with the
    slice's geometric mean in column 0 and the non-zero spot count in
    column 1, plus a ``spot_number`` attribute. The merged geometric mean is
    the spot-weighted one, ``exp(sum(n_i * log(gmean_i)) / sum(n_i))``. The
    merged expression fraction is ``sum(count_i) / sum(n_i)``.

    Args:
        zarr_group_path: Location of the Zarr group to read.
        output_file: Destination of the Parquet file.
        common_genes: Gene names, in the row order of the stored arrays.

    Returns:
        DataFrame indexed by ``gene`` with columns ``G_Mean`` and ``frac``.

    Raises:
        ValueError: If the group contains no spots to merge, or if a stored
            array does not have one row per entry of ``common_genes``.
    """
    # Merging only reads, so the group is not opened writable.
    gmean_zarr_group = zarr.open(zarr_group_path, mode="r")

    n_genes = len(common_genes)
    # float64 accumulators: the stored arrays are float32, and summing
    # spot-weighted logs over many slices in float32 loses precision.
    log_sum = np.zeros(n_genes, dtype=np.float64)
    frac_sum = np.zeros(n_genes, dtype=np.float64)
    total_spot_number = 0

    for key in tqdm(gmean_zarr_group.array_keys(), desc="Merging Zarr arrays"):
        s1 = gmean_zarr_group[key]
        # Read the array once and slice the in-memory copy.
        values = s1[:].astype(np.float64)
        if values.shape[0] != n_genes:
            raise ValueError(
                f"Array '{key}' has {values.shape[0]} rows but {n_genes} common genes were given."
            )
        n = s1.attrs["spot_number"]

        log_sum += np.log(values[:, 0]) * n
        frac_sum += values[:, 1]
        total_spot_number += n

    if total_spot_number == 0:
        raise ValueError(f"No slice means found in {zarr_group_path}; nothing to merge.")

    # Cast back to float32 so the saved columns keep the dtype of the stored arrays.
    final_mean = np.exp(log_sum / total_spot_number).astype(np.float32)
    final_frac = (frac_sum / total_spot_number).astype(np.float32)

    final_df = pd.DataFrame({"gene": common_genes, "G_Mean": final_mean, "frac": final_frac})
    final_df.set_index("gene", inplace=True)
    final_df.to_parquet(output_file)
    return final_df
119
+
120
+
121
def run_create_slice_mean(config: CreateSliceMeanConfig):
    """
    Main entrypoint to create slice means.

    Iterates over ``config.h5ad_dict`` (sample name -> h5ad file), computes
    the per-slice statistics for every sample that is not yet present in the
    Zarr group next to ``config.slice_mean_output_file``, then merges all
    slices into the final table.

    Returns:
        The merged DataFrame produced by ``merge_zarr_means``.
    """
    # Genes shared by every input file
    common_genes = get_common_genes(list(config.h5ad_dict.values()), config)
    logger.info(f"Found {len(common_genes)} common genes across all files.")

    # Per-slice results live in a Zarr group named after the output file
    zarr_group_path = config.slice_mean_output_file.with_suffix(".zarr")

    for sample_name, h5ad_file in config.h5ad_dict.items():
        # A sample whose array already exists in the group is not recomputed
        already_processed = (
            zarr_group_path.exists()
            and sample_name in zarr.open(zarr_group_path.as_posix(), mode="r").array_keys()
        )
        if already_processed:
            logger.info(f"Skipping {sample_name}, already processed.")
            continue

        calculate_one_slice_mean(
            sample_name, h5ad_file, common_genes, zarr_group_path, config.data_layer
        )

    slice_mean_file = config.slice_mean_output_file
    merged_df = merge_zarr_means(zarr_group_path, slice_mean_file, common_genes)

    logger.info(f"Final slice mean and expression fraction saved to {slice_mean_file}")
    return merged_df
gsMap/diagnosis.py CHANGED
@@ -9,8 +9,7 @@ from scipy.stats import norm
9
9
 
10
10
  from gsMap.config import DiagnosisConfig
11
11
  from gsMap.utils.manhattan_plot import ManhattanPlot
12
- from gsMap.visualize import draw_scatter, load_st_coord, estimate_point_size_for_plot
13
-
12
+ from gsMap.visualize import draw_scatter, estimate_point_size_for_plot, load_ldsc, load_st_coord
14
13
 
15
14
  warnings.filterwarnings("ignore", category=FutureWarning)
16
15
  logger = logging.getLogger(__name__)
@@ -18,38 +17,33 @@ logger = logging.getLogger(__name__)
18
17
 
19
18
def convert_z_to_p(gwas_data, min_p_value=1e-300):
    """Convert Z-scores to two-sided P-values.

    Adds (or overwrites) a ``P`` column on ``gwas_data`` in place, computed
    from the ``Z`` column as ``2 * norm.sf(|Z|)``, and returns the same frame.

    Args:
        gwas_data: DataFrame with a numeric ``Z`` column. It is modified in place.
        min_p_value: Floor applied to ``P``. P-values that underflow to 0 for
            very large ``|Z|`` are raised to this value.

    Returns:
        The input DataFrame, now carrying the ``P`` column.
    """
    two_sided_p = norm.sf(gwas_data["Z"].abs()) * 2
    gwas_data["P"] = two_sided_p
    gwas_data["P"] = gwas_data["P"].clip(lower=min_p_value)
    return gwas_data
25
24
 
26
25
 
27
- def load_ldsc(ldsc_input_file):
28
- """Load LDSC data and calculate logp."""
29
- ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
30
- ldsc['spot'] = ldsc['spot'].astype(str).replace('\.0', '', regex=True)
31
- ldsc.set_index('spot', inplace=True)
32
- ldsc['logp'] = -np.log10(ldsc['p'])
33
- return ldsc
34
-
35
-
36
def load_gene_diagnostic_info(config: DiagnosisConfig):
    """Return the gene diagnostic table for ``config.trait_name``.

    Reads the CSV at the path given by
    ``config.get_gene_diagnostic_info_save_path`` when it exists; otherwise
    delegates to ``compute_gene_diagnostic_info``.
    """
    saved_info_path = config.get_gene_diagnostic_info_save_path(config.trait_name)

    if not saved_info_path.exists():
        logger.info(
            "Gene diagnostic information not found. Calculating gene diagnostic information..."
        )
        return compute_gene_diagnostic_info(config)

    logger.info(f"Loading gene diagnostic information from {saved_info_path}...")
    return pd.read_csv(saved_info_path)
45
39
 
46
40
 
47
41
  def compute_gene_diagnostic_info(config: DiagnosisConfig):
48
42
  """Calculate gene diagnostic info and save it to adata."""
49
- logger.info('Loading ST data and LDSC results...')
43
+ logger.info("Loading ST data and LDSC results...")
50
44
  # adata = sc.read_h5ad(config.hdf5_with_latent_path, backed='r')
51
45
  mk_score = pd.read_feather(config.mkscore_feather_path)
52
- mk_score.set_index('HUMAN_GENE_SYM', inplace=True)
46
+ mk_score.set_index("HUMAN_GENE_SYM", inplace=True)
53
47
  mk_score = mk_score.T
54
48
  trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
55
49
 
@@ -57,33 +51,42 @@ def compute_gene_diagnostic_info(config: DiagnosisConfig):
57
51
  mk_score = mk_score.loc[trait_ldsc_result.index]
58
52
  mk_score = mk_score.loc[:, mk_score.sum(axis=0) != 0]
59
53
 
60
- logger.info('Calculating correlation between gene marker scores and trait logp-values...')
61
- corr = mk_score.corrwith(trait_ldsc_result['logp'])
62
- corr.name = 'PCC'
54
+ logger.info("Calculating correlation between gene marker scores and trait logp-values...")
55
+ corr = mk_score.corrwith(trait_ldsc_result["logp"])
56
+ corr.name = "PCC"
63
57
 
64
58
  grouped_mk_score = mk_score.groupby(adata.obs[config.annotation]).median()
65
59
  max_annotations = grouped_mk_score.idxmax()
66
60
 
67
- high_GSS_Gene_annotation_pair = pd.DataFrame({
68
- 'Gene': max_annotations.index,
69
- 'Annotation': max_annotations.values,
70
- 'Median_GSS': grouped_mk_score.max().values
71
- })
61
+ high_GSS_Gene_annotation_pair = pd.DataFrame(
62
+ {
63
+ "Gene": max_annotations.index,
64
+ "Annotation": max_annotations.values,
65
+ "Median_GSS": grouped_mk_score.max().values,
66
+ }
67
+ )
72
68
 
73
69
  # Filter based on median GSS score
74
- high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair[high_GSS_Gene_annotation_pair['Median_GSS'] >= 1.0]
75
- high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair.merge(corr, left_on='Gene', right_index=True)
70
+ high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair[
71
+ high_GSS_Gene_annotation_pair["Median_GSS"] >= 1.0
72
+ ]
73
+ high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair.merge(
74
+ corr, left_on="Gene", right_index=True
75
+ )
76
76
 
77
77
  # Prepare the final gene diagnostic info dataframe
78
- gene_diagnostic_info_cols = ['Gene', 'Annotation', 'Median_GSS', 'PCC']
79
- gene_diagnostic_info = high_GSS_Gene_annotation_pair[gene_diagnostic_info_cols].drop_duplicates().dropna(
80
- subset=['Gene'])
81
- gene_diagnostic_info.sort_values('PCC', ascending=False, inplace=True)
78
+ gene_diagnostic_info_cols = ["Gene", "Annotation", "Median_GSS", "PCC"]
79
+ gene_diagnostic_info = (
80
+ high_GSS_Gene_annotation_pair[gene_diagnostic_info_cols]
81
+ .drop_duplicates()
82
+ .dropna(subset=["Gene"])
83
+ )
84
+ gene_diagnostic_info.sort_values("PCC", ascending=False, inplace=True)
82
85
 
83
86
  # Save gene diagnostic info to a file
84
87
  gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
85
88
  gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
86
- logger.info(f'Gene diagnostic information saved to {gene_diagnostic_info_save_path}.')
89
+ logger.info(f"Gene diagnostic information saved to {gene_diagnostic_info_save_path}.")
87
90
 
88
91
  # TODO: A new script is needed to save the gene diagnostic info to adata.var and trait_ldsc_result to adata.obs when running multiple traits
89
92
  # # Save to adata.var with the trait_name prefix
@@ -101,114 +104,180 @@ def compute_gene_diagnostic_info(config: DiagnosisConfig):
101
104
  return gene_diagnostic_info.reset_index()
102
105
 
103
106
 
104
def load_gwas_data(config: DiagnosisConfig):
    """Read the GWAS summary statistics and attach P-values.

    ``config.sumstats_file`` is read as a gzip-compressed, tab-separated
    table and passed through ``convert_z_to_p``.
    """
    logger.info("Loading and processing GWAS data...")
    sumstats = pd.read_csv(config.sumstats_file, compression="gzip", sep="\t")
    sumstats_with_p = convert_z_to_p(sumstats)
    return sumstats_with_p
109
112
 
110
113
 
111
def load_snp_gene_pairs(config: DiagnosisConfig):
    """Load the SNP-gene pair tables of chromosomes 1-22 and stack them.

    One feather file per chromosome is read from the ``SNP_gene_pair``
    directory under ``config.ldscore_save_dir``.
    """
    ldscore_save_dir = Path(config.ldscore_save_dir)

    per_chromosome_tables = []
    for chrom in range(1, 23):
        feather_path = ldscore_save_dir / f"SNP_gene_pair/SNP_gene_pair_chr{chrom}.feather"
        per_chromosome_tables.append(pd.read_feather(feather_path))

    return pd.concat(per_chromosome_tables)
118
123
 
119
124
 
120
125
def filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER):
    """Choose which SNPs to plot.

    When more than ``SUBSAMPLE_SNP_NUMBER`` SNPs have ``P < 1e-5``, all of
    those SNPs are returned. Otherwise the first ``SUBSAMPLE_SNP_NUMBER``
    rows of the input are returned; the log message and the parameter name
    indicate the caller passes the frame sorted by ascending ``P``.

    Returns:
        The ``SNP`` column of the selected rows.
    """
    sorted_gwas = gwas_data_with_gene_annotation_sort
    is_suggestive = sorted_gwas["P"] < 1e-5
    n_suggestive = is_suggestive.sum()

    if n_suggestive <= SUBSAMPLE_SNP_NUMBER:
        snps2plot = sorted_gwas.head(SUBSAMPLE_SNP_NUMBER).SNP
        logger.info(
            f"To reduce the number of SNPs to plot, only {SUBSAMPLE_SNP_NUMBER} SNPs with the smallest P-values are plotted."
        )
        return snps2plot

    snps2plot = sorted_gwas[is_suggestive].SNP
    logger.info(
        f"To reduce the number of SNPs to plot, only {snps2plot.shape[0]} SNPs with P < 1e-5 are plotted."
    )
    return snps2plot
134
142
 
135
143
 
136
144
def generate_manhattan_plot(config: DiagnosisConfig):
    """Build the diagnostic Manhattan plot and write it as HTML.

    GWAS SNPs are joined to genes and to the gene diagnostic table. SNPs
    whose gene has no annotation are dropped, the remainder is subsampled
    with ``filter_snps``, and the figure is saved to the path returned by
    ``config.get_manhattan_html_plot_path``.
    """
    gwas_data = load_gwas_data(config)
    snp_gene_pair = load_snp_gene_pairs(config)

    # Attach the gene of every SNP
    gwas_data_with_gene = snp_gene_pair.merge(gwas_data, on="SNP", how="inner")
    gwas_data_with_gene = gwas_data_with_gene.rename(columns={"gene_name": "GENE"})

    # Attach annotation / PCC per gene and keep only annotated SNPs
    gene_diagnostic_info = load_gene_diagnostic_info(config)
    annotated = gwas_data_with_gene.merge(
        gene_diagnostic_info, left_on="GENE", right_on="Gene", how="left"
    )
    annotated = annotated[annotated["Annotation"].notna()]

    snps2plot = filter_snps(annotated.sort_values("P"), SUBSAMPLE_SNP_NUMBER=100_000)
    gwas_data_to_plot = annotated[annotated["SNP"].isin(snps2plot)].reset_index(drop=True)

    # Hover text shown for every point
    pcc_text = gwas_data_to_plot["PCC"].round(2).astype(str)
    annotation_text = gwas_data_to_plot["Annotation"].astype(str)
    gwas_data_to_plot["Annotation_text"] = (
        "PCC: " + pcc_text + "<br>" + "Annotation: " + annotation_text
    )

    # User-selected genes take precedence over the top rows of the diagnostic table
    highlight_genes = (
        config.selected_genes
        or gene_diagnostic_info.Gene.iloc[: config.top_corr_genes].tolist()
    )

    fig = ManhattanPlot(
        dataframe=gwas_data_to_plot,
        title="gsMap Diagnosis Manhattan Plot",
        point_size=3,
        highlight_gene_list=highlight_genes,
        suggestiveline_value=-np.log10(1e-5),
        annotation="Annotation_text",
    )

    save_manhattan_plot_path = config.get_manhattan_html_plot_path(config.trait_name)
    fig.write_html(save_manhattan_plot_path)
    logger.info(f"Diagnostic Manhattan Plot saved to {save_manhattan_plot_path}.")
168
187
 
169
188
 
170
189
def generate_GSS_distribution(config: DiagnosisConfig):
    """Generate GSS and expression distribution plots for the chosen genes.

    Uses the module-level ``adata`` that ``run_Diagnosis`` loads. The genes
    plotted are ``config.selected_genes`` when given, otherwise the first
    ``config.top_corr_genes`` genes of the gene diagnostic table.
    """
    mk_score = pd.read_feather(config.mkscore_feather_path).set_index("HUMAN_GENE_SYM").T

    plot_genes = config.selected_genes
    if not plot_genes:
        diagnostic_info = load_gene_diagnostic_info(config)
        plot_genes = diagnostic_info.Gene.iloc[: config.top_corr_genes].tolist()

    if config.selected_genes is None:
        logger.info(
            f"Generating GSS & Expression distribution plot for top {config.top_corr_genes} correlated genes..."
        )
    else:
        logger.info(
            f"Generating GSS & Expression distribution plot for selected genes: {plot_genes}..."
        )

    # Figure geometry: user-provided, or estimated from the spatial coordinates
    if not config.customize_fig:
        (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(
            adata.obsm["spatial"]
        )
    else:
        pixel_width = config.fig_width
        pixel_height = config.fig_height
        point_size = config.point_size
    sub_fig_save_dir = config.get_GSS_plot_dir(config.trait_name)

    # Record which genes were plotted
    config.get_GSS_plot_select_gene_file(config.trait_name).write_text("\n".join(plot_genes))

    for selected_gene in plot_genes:
        gene_values = adata[:, selected_gene].X.toarray().flatten()
        expression_series = pd.Series(gene_values, index=adata.obs.index, name="Expression")
        # Cap values above the 99.99th percentile
        threshold = np.quantile(expression_series, 0.9999)
        expression_series[expression_series > threshold] = threshold
        generate_and_save_plots(
            adata,
            mk_score,
            expression_series,
            selected_gene,
            point_size,
            pixel_width,
            pixel_height,
            sub_fig_save_dir,
            config.sample_name,
            config.annotation,
        )
241
+
242
+
243
+ def generate_and_save_plots(
244
+ adata,
245
+ mk_score,
246
+ expression_series,
247
+ selected_gene,
248
+ point_size,
249
+ pixel_width,
250
+ pixel_height,
251
+ sub_fig_save_dir,
252
+ sample_name,
253
+ annotation,
254
+ ):
201
255
  """Generate and save the plots."""
202
256
  select_gene_expression_with_space_coord = load_st_coord(adata, expression_series, annotation)
203
- sub_fig_1 = draw_scatter(select_gene_expression_with_space_coord, title=f'{selected_gene} (Expression)',
204
- annotation='annotation', color_by='Expression', point_size=point_size, width=pixel_width,
205
- height=pixel_height)
206
- save_plot(sub_fig_1, sub_fig_save_dir, sample_name, selected_gene, 'Expression')
257
+ sub_fig_1 = draw_scatter(
258
+ select_gene_expression_with_space_coord,
259
+ title=f"{selected_gene} (Expression)",
260
+ annotation="annotation",
261
+ color_by="Expression",
262
+ point_size=point_size,
263
+ width=pixel_width,
264
+ height=pixel_height,
265
+ )
266
+ save_plot(sub_fig_1, sub_fig_save_dir, sample_name, selected_gene, "Expression")
207
267
 
208
- select_gene_GSS_with_space_coord = load_st_coord(adata, mk_score[selected_gene].rename('GSS'), annotation)
209
- sub_fig_2 = draw_scatter(select_gene_GSS_with_space_coord, title=f'{selected_gene} (GSS)', annotation='annotation',
210
- color_by='GSS', point_size=point_size, width=pixel_width, height=pixel_height)
211
- save_plot(sub_fig_2, sub_fig_save_dir, sample_name, selected_gene, 'GSS')
268
+ select_gene_GSS_with_space_coord = load_st_coord(
269
+ adata, mk_score[selected_gene].rename("GSS"), annotation
270
+ )
271
+ sub_fig_2 = draw_scatter(
272
+ select_gene_GSS_with_space_coord,
273
+ title=f"{selected_gene} (GSS)",
274
+ annotation="annotation",
275
+ color_by="GSS",
276
+ point_size=point_size,
277
+ width=pixel_width,
278
+ height=pixel_height,
279
+ )
280
+ save_plot(sub_fig_2, sub_fig_save_dir, sample_name, selected_gene, "GSS")
212
281
 
213
282
  # combined_fig = make_subplots(rows=1, cols=2,
214
283
  # subplot_titles=(f'{selected_gene} (Expression)', f'{selected_gene} (GSS)'))
@@ -218,57 +287,66 @@ def generate_and_save_plots(adata, mk_score, expression_series, selected_gene, p
218
287
  # combined_fig.add_trace(trace, row=1, col=2)
219
288
  #
220
289
 
290
+
221
291
def save_plot(sub_fig, sub_fig_save_dir, sample_name, selected_gene, plot_type):
    """Save the plot as a PNG image with the legend hidden.

    The file is written to
    ``sub_fig_save_dir / "{sample_name}_{selected_gene}_{plot_type}_Distribution.png"``.

    Args:
        sub_fig: Figure exposing ``update_layout`` and ``write_image``.
        sub_fig_save_dir: Directory (path object) that receives the image.
        sample_name: Sample name used in the file name.
        selected_gene: Gene name used in the file name.
        plot_type: Plot kind used in the file name (e.g. "Expression", "GSS").
    """
    # Build the PNG path directly. Deriving it from an ".html" path with
    # ``str.replace`` would also rewrite any other ".html" in the directory
    # or sample name.
    save_sub_fig_path = (
        sub_fig_save_dir / f"{sample_name}_{selected_gene}_{plot_type}_Distribution.png"
    )
    sub_fig.update_layout(showlegend=False)
    sub_fig.write_image(str(save_sub_fig_path))
227
299
 
228
300
 
229
301
def generate_gsMap_plot(config: DiagnosisConfig):
    """Create the gsMap scatter plot and save it as HTML, PNG and CSV.

    Uses the module-level ``adata`` that ``run_Diagnosis`` loads. The PNG
    and CSV paths are the HTML path with its suffix replaced.
    """
    logger.info("Creating gsMap plot...")

    trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
    space_coord_concat = load_st_coord(adata, trait_ldsc_result, annotation=config.annotation)

    # Figure geometry: user-provided, or estimated from the spatial coordinates
    if not config.customize_fig:
        (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(
            adata.obsm["spatial"]
        )
    else:
        pixel_width = config.fig_width
        pixel_height = config.fig_height
        point_size = config.point_size

    fig = draw_scatter(
        space_coord_concat,
        title=f"{config.trait_name} (gsMap)",
        point_size=point_size,
        width=pixel_width,
        height=pixel_height,
        annotation=config.annotation,
    )

    output_dir = config.get_gsMap_plot_save_dir(config.trait_name)
    html_path = config.get_gsMap_html_plot_save_path(config.trait_name)
    png_path = html_path.with_suffix(".png")
    csv_path = html_path.with_suffix(".csv")

    fig.write_html(html_path)
    fig.write_image(png_path)
    space_coord_concat.to_csv(csv_path)

    logger.info(f"gsMap plot created and saved in {output_dir}.")
258
337
 
259
338
 
260
339
def run_Diagnosis(config: DiagnosisConfig):
    """Main function to run the diagnostic plot generation.

    Loads the h5ad file into the module-level ``adata`` used by the plotting
    functions. When ``adata.uns`` has no ``"log1p"`` entry and the maximum of
    ``adata.X`` exceeds 14, the data are normalised to a total of 1e4 and
    log1p-transformed. The plots selected by ``config.plot_type`` are then
    generated in the order gsMap, manhattan, GSS.
    """
    global adata
    adata = sc.read_h5ad(config.hdf5_with_latent_path)

    needs_log_normalisation = "log1p" not in adata.uns.keys() and adata.X.max() > 14
    if needs_log_normalisation:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

    # (plot type, generator) pairs in execution order; "all" runs every one
    plot_generators = (
        ("gsMap", generate_gsMap_plot),
        ("manhattan", generate_manhattan_plot),
        ("GSS", generate_GSS_distribution),
    )
    for plot_name, generate_plot in plot_generators:
        if config.plot_type in [plot_name, "all"]:
            generate_plot(config)
274
-