gsMap 1.71.2__py3-none-any.whl → 1.72.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN/adjacency_matrix.py +25 -27
- gsMap/GNN/model.py +9 -7
- gsMap/GNN/train.py +8 -11
- gsMap/__init__.py +3 -3
- gsMap/__main__.py +3 -2
- gsMap/cauchy_combination_test.py +75 -72
- gsMap/config.py +822 -316
- gsMap/create_slice_mean.py +154 -0
- gsMap/diagnosis.py +179 -101
- gsMap/find_latent_representation.py +28 -26
- gsMap/format_sumstats.py +233 -201
- gsMap/generate_ldscore.py +353 -209
- gsMap/latent_to_gene.py +92 -60
- gsMap/main.py +23 -14
- gsMap/report.py +39 -25
- gsMap/run_all_mode.py +86 -46
- gsMap/setup.py +1 -1
- gsMap/spatial_ldsc_multiple_sumstats.py +154 -80
- gsMap/utils/generate_r2_matrix.py +173 -140
- gsMap/utils/jackknife.py +84 -80
- gsMap/utils/manhattan_plot.py +180 -207
- gsMap/utils/regression_read.py +105 -122
- gsMap/visualize.py +82 -64
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/METADATA +21 -6
- gsmap-1.72.3.dist-info/RECORD +31 -0
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/WHEEL +1 -1
- gsMap/utils/make_annotations.py +0 -518
- gsmap-1.71.2.dist-info/RECORD +0 -31
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/LICENSE +0 -0
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,154 @@
|
|
1
|
+
import logging
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
import anndata
|
5
|
+
import numpy as np
|
6
|
+
import pandas as pd
|
7
|
+
import scanpy as sc
|
8
|
+
import zarr
|
9
|
+
from scipy.stats import rankdata
|
10
|
+
from tqdm import tqdm
|
11
|
+
|
12
|
+
from gsMap.config import CreateSliceMeanConfig
|
13
|
+
|
14
|
+
# %% Helper functions
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
def get_common_genes(h5ad_files, config: CreateSliceMeanConfig):
    """
    Get the sorted list of genes shared by every h5ad file.

    Parameters
    ----------
    h5ad_files : iterable of str or Path
        h5ad files to scan. Gene names are made unique within each file
        (``var_names_make_unique``) before the files are intersected.
    config : CreateSliceMeanConfig
        When ``config.species`` is set, the common genes are further restricted
        to those in the first column of ``config.homolog_file`` (tab-separated).
        Only the first two columns of that file are used.

    Returns
    -------
    list
        Sorted common gene names.

    Raises
    ------
    ValueError
        If no h5ad files are given, or the homolog file has fewer than two columns.
    """
    common_genes = None
    for file in tqdm(h5ad_files, desc="Finding common genes"):
        adata = sc.read_h5ad(file)
        adata.var_names_make_unique()
        if common_genes is None:
            common_genes = adata.var_names
        else:
            common_genes = common_genes.intersection(adata.var_names)

    if common_genes is None:
        raise ValueError("No h5ad files were provided to search for common genes.")

    if config.species is not None:
        homologs = pd.read_csv(config.homolog_file, sep="\t")
        if homologs.shape[1] < 2:
            raise ValueError(
                "Homologs file must have at least two columns: one for the species and one for the human gene symbol."
            )
        # Keep only the first two columns: assigning two column names to a wider
        # table would otherwise fail with a length-mismatch error.
        homologs = homologs.iloc[:, :2].copy()
        homologs.columns = [config.species, "HUMAN_GENE_SYM"]
        homologs.set_index(config.species, inplace=True)
        common_genes = np.intersect1d(common_genes, homologs.index)

    # Sort so the gene order (and hence the stored array rows) is deterministic.
    return sorted(common_genes)
|
44
|
+
|
45
|
+
|
46
|
+
def calculate_one_slice_mean(
    sample_name, file_path: Path, common_genes, zarr_group_path, data_layer
):
    """
    Compute per-gene geometric-mean expression ranks and detection counts for one
    slice and store them in a Zarr group.

    Each spot's expression values are ranked across genes with
    ``rankdata(method="average")``. The per-gene geometric mean of these ranks
    over all spots is then computed with the log trick
    ``exp(mean(log(rank)))``, which avoids overflow.

    Parameters
    ----------
    sample_name : str
        Name of the array written into the Zarr group.
    file_path : Path
        h5ad file of the slice.
    common_genes : list
        Genes to keep. Their order defines the row order of the stored array.
    zarr_group_path : str or Path
        Location of the Zarr group, opened in append mode.
    data_layer : str
        Key of ``adata.layers`` to use, or ``"X"`` to use ``adata.X`` as-is.

    Side effects
    ------------
    Writes a float32 array named ``sample_name`` with one row per common gene.
    Column 0 is the geometric mean of ranks. Column 1 is the number of spots
    where the gene is non-zero. The array's attribute ``spot_number`` is set to
    the number of spots.

    Raises
    ------
    ValueError
        If ``data_layer`` is neither a layer of the file nor ``"X"``.
    """
    gmean_zarr_group = zarr.open(zarr_group_path, mode="a")
    adata = anndata.read_h5ad(file_path)
    # get_common_genes makes var names unique before intersecting, so do the same
    # here. Otherwise duplicated gene names would select extra columns or be missing.
    adata.var_names_make_unique()

    if data_layer in adata.layers.keys():
        adata.X = adata.layers[data_layer]
    elif data_layer != "X":
        raise ValueError(f"Data layer {data_layer} not found in {file_path}")

    adata = adata[:, common_genes].copy()
    n_cells = adata.shape[0]
    log_ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float32)
    # adata.X may be a sparse matrix or a dense array depending on how the file was saved.
    is_sparse = hasattr(adata.X, "toarray")
    # Compute log of ranks to avoid overflow when computing geometric mean
    for i in tqdm(range(n_cells), desc=f"Computing log ranks for {sample_name}"):
        row = adata.X[i, :]
        data = (row.toarray() if is_sparse else np.asarray(row)).flatten()
        ranks = rankdata(data, method="average")
        # Ranks start at 1, so log(ranks) is always finite (no offset needed).
        log_ranks[i, :] = np.log(ranks)

    # Calculate geometric mean via log trick: exp(mean(log(values)))
    gmean = (np.exp(np.mean(log_ranks, axis=0))).reshape(-1, 1)

    # Count the spots in which each gene is detected (non-zero). This is a count,
    # not yet a fraction.
    adata_X_bool = adata.X.astype(bool)
    frac = (np.asarray(adata_X_bool.sum(axis=0)).flatten()).reshape(-1, 1)

    # Save to zarr group
    gmean_frac = np.concatenate([gmean, frac], axis=1)
    s1_zarr = gmean_zarr_group.array(sample_name, data=gmean_frac, chunks=None, dtype="f4")
    s1_zarr.attrs["spot_number"] = adata.shape[0]
|
83
|
+
|
84
|
+
|
85
|
+
def merge_zarr_means(zarr_group_path, output_file, common_genes):
    """
    Merge all per-slice Zarr arrays into spot-weighted statistics and save them to Parquet.

    Each array holds, per gene, a geometric mean of ranks (column 0) and a count
    of expressing spots (column 1), and has a ``spot_number`` attribute. The
    merged ``G_Mean`` is ``exp(sum_i n_i * log(gmean_i) / N)``, i.e. the
    spot-weighted geometric mean. It is computed by summing weighted logs rather
    than multiplying values. The merged ``frac`` is the total expressing-spot
    count divided by ``N``, the total number of spots.

    Parameters
    ----------
    zarr_group_path : str or Path
        Zarr group containing one array per slice.
    output_file : str or Path
        Destination Parquet file.
    common_genes : list
        Gene names, one per array row, used as the output index.

    Returns
    -------
    pandas.DataFrame
        Index ``gene``; columns ``G_Mean`` and ``frac``.

    Raises
    ------
    ValueError
        If the group holds no arrays (or zero spots in total), or an array's row
        count differs from ``len(common_genes)``.
    """
    gmean_zarr_group = zarr.open(zarr_group_path, mode="a")
    n_genes = len(common_genes)
    log_sum = None
    frac_sum = None
    total_spot_number = 0
    for key in tqdm(gmean_zarr_group.array_keys(), desc="Merging Zarr arrays"):
        s1 = gmean_zarr_group[key]
        # Load once: every s1[:] call reads the whole array from storage.
        # Accumulate in float64 to limit rounding over many slices.
        s1_array = np.asarray(s1[:], dtype=np.float64)
        if s1_array.shape[0] != n_genes:
            raise ValueError(
                f"Zarr array '{key}' has {s1_array.shape[0]} genes but {n_genes} common "
                f"genes were expected; it was probably computed for a different gene set."
            )
        n = s1.attrs["spot_number"]

        weighted_log = np.log(s1_array[:, 0]) * n
        if log_sum is None:
            log_sum = weighted_log
            frac_sum = s1_array[:, 1].copy()
        else:
            log_sum += weighted_log
            frac_sum += s1_array[:, 1]

        total_spot_number += n

    if log_sum is None or total_spot_number == 0:
        raise ValueError(f"No slice means found in {zarr_group_path}; nothing to merge.")

    # Apply the geometric mean via exponentiation of the averaged logs
    final_mean = np.exp(log_sum / total_spot_number)
    final_frac = frac_sum / total_spot_number

    # Save the final mean to a Parquet file
    final_df = pd.DataFrame({"gene": common_genes, "G_Mean": final_mean, "frac": final_frac})
    final_df.set_index("gene", inplace=True)
    final_df.to_parquet(output_file)
    return final_df
|
119
|
+
|
120
|
+
|
121
|
+
def run_create_slice_mean(config: CreateSliceMeanConfig):
    """
    Main entrypoint to create slice means.

    Works with a config built from either:
        1. An h5ad_yaml file.
        2. Direct lists of sample names and h5ad files.
    In both cases ``config.h5ad_dict`` maps each sample name to its h5ad file.

    Per-slice results are cached as arrays in a Zarr group. The group has the
    same path as ``config.slice_mean_output_file`` but with a ``.zarr`` suffix.
    A sample is not recomputed if its cached array already has one row per
    common gene. Only the gene count is checked, not the gene names.

    Returns
    -------
    pandas.DataFrame
        The merged table returned by ``merge_zarr_means``, which is also written
        to ``config.slice_mean_output_file``.
    """
    h5ad_files = list(config.h5ad_dict.values())

    # Step 1: Get common genes from the h5ad files
    common_genes = get_common_genes(h5ad_files, config)
    logger.info(f"Found {len(common_genes)} common genes across all files.")

    # Step 2: Locate the Zarr group that caches per-slice results
    zarr_group_path = config.slice_mean_output_file.with_suffix(".zarr")
    zarr_group_str = zarr_group_path.as_posix()

    # Step 3: Process each file to calculate the slice means
    for sample_name, h5ad_file in config.h5ad_dict.items():
        if zarr_group_path.exists():
            zarr_group = zarr.open(zarr_group_str, mode="a")
            if sample_name in zarr_group.array_keys():
                cached_n_genes = zarr_group[sample_name].shape[0]
                if cached_n_genes == len(common_genes):
                    logger.info(f"Skipping {sample_name}, already processed.")
                    continue
                # The cached array was computed for a different gene set (e.g. the
                # sample list changed), so it cannot be merged with the others.
                logger.warning(
                    f"Cached slice mean for {sample_name} has {cached_n_genes} genes, but "
                    f"{len(common_genes)} common genes were found; recomputing it."
                )
                del zarr_group[sample_name]

        calculate_one_slice_mean(
            sample_name, h5ad_file, common_genes, zarr_group_str, config.data_layer
        )

    # Arrays cached for samples that are not in this config are still merged below.
    extra_samples = sorted(
        set(zarr.open(zarr_group_str, mode="r").array_keys()) - set(config.h5ad_dict)
    )
    if extra_samples:
        logger.warning(
            f"{zarr_group_path} also contains slice means for samples not in the current "
            f"config, which will be included in the merge: {extra_samples}"
        )

    # Step 4: Merge the per-slice results
    output_file = config.slice_mean_output_file
    final_df = merge_zarr_means(zarr_group_str, output_file, common_genes)

    logger.info(f"Final slice mean and expression fraction saved to {output_file}")
    return final_df
|
gsMap/diagnosis.py
CHANGED
@@ -9,8 +9,7 @@ from scipy.stats import norm
|
|
9
9
|
|
10
10
|
from gsMap.config import DiagnosisConfig
|
11
11
|
from gsMap.utils.manhattan_plot import ManhattanPlot
|
12
|
-
from gsMap.visualize import draw_scatter,
|
13
|
-
|
12
|
+
from gsMap.visualize import draw_scatter, estimate_point_size_for_plot, load_ldsc, load_st_coord
|
14
13
|
|
15
14
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
16
15
|
logger = logging.getLogger(__name__)
|
@@ -18,38 +17,33 @@ logger = logging.getLogger(__name__)
|
|
18
17
|
|
19
18
|
def convert_z_to_p(gwas_data):
    """
    Add a two-sided P-value column ``P`` derived from the ``Z`` column.

    P-values are floored at 1e-300, so none is exactly zero. The DataFrame is
    modified in place and also returned.
    """
    p_floor = 1e-300
    two_sided_p = 2 * norm.sf(gwas_data["Z"].abs())
    gwas_data["P"] = two_sided_p
    gwas_data["P"] = gwas_data["P"].clip(lower=p_floor)
    return gwas_data
|
25
24
|
|
26
25
|
|
27
|
-
def
|
28
|
-
"""Load LDSC data and calculate logp."""
|
29
|
-
ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
|
30
|
-
ldsc['spot'] = ldsc['spot'].astype(str).replace('\.0', '', regex=True)
|
31
|
-
ldsc.set_index('spot', inplace=True)
|
32
|
-
ldsc['logp'] = -np.log10(ldsc['p'])
|
33
|
-
return ldsc
|
34
|
-
|
35
|
-
|
36
|
-
def load_gene_diagnostic_info(config: DiagnosisConfig):
    """
    Return the gene diagnostic table for ``config.trait_name``.

    Reads the cached CSV when it exists. Otherwise the table is computed by
    ``compute_gene_diagnostic_info``, which also writes that CSV.
    """
    save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
    if not save_path.exists():
        logger.info(
            "Gene diagnostic information not found. Calculating gene diagnostic information..."
        )
        return compute_gene_diagnostic_info(config)

    logger.info(f"Loading gene diagnostic information from {save_path}...")
    return pd.read_csv(save_path)
|
45
39
|
|
46
40
|
|
47
41
|
def compute_gene_diagnostic_info(config: DiagnosisConfig):
|
48
42
|
"""Calculate gene diagnostic info and save it to adata."""
|
49
|
-
logger.info(
|
43
|
+
logger.info("Loading ST data and LDSC results...")
|
50
44
|
# adata = sc.read_h5ad(config.hdf5_with_latent_path, backed='r')
|
51
45
|
mk_score = pd.read_feather(config.mkscore_feather_path)
|
52
|
-
mk_score.set_index(
|
46
|
+
mk_score.set_index("HUMAN_GENE_SYM", inplace=True)
|
53
47
|
mk_score = mk_score.T
|
54
48
|
trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
|
55
49
|
|
@@ -57,33 +51,42 @@ def compute_gene_diagnostic_info(config: DiagnosisConfig):
|
|
57
51
|
mk_score = mk_score.loc[trait_ldsc_result.index]
|
58
52
|
mk_score = mk_score.loc[:, mk_score.sum(axis=0) != 0]
|
59
53
|
|
60
|
-
logger.info(
|
61
|
-
corr = mk_score.corrwith(trait_ldsc_result[
|
62
|
-
corr.name =
|
54
|
+
logger.info("Calculating correlation between gene marker scores and trait logp-values...")
|
55
|
+
corr = mk_score.corrwith(trait_ldsc_result["logp"])
|
56
|
+
corr.name = "PCC"
|
63
57
|
|
64
58
|
grouped_mk_score = mk_score.groupby(adata.obs[config.annotation]).median()
|
65
59
|
max_annotations = grouped_mk_score.idxmax()
|
66
60
|
|
67
|
-
high_GSS_Gene_annotation_pair = pd.DataFrame(
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
61
|
+
high_GSS_Gene_annotation_pair = pd.DataFrame(
|
62
|
+
{
|
63
|
+
"Gene": max_annotations.index,
|
64
|
+
"Annotation": max_annotations.values,
|
65
|
+
"Median_GSS": grouped_mk_score.max().values,
|
66
|
+
}
|
67
|
+
)
|
72
68
|
|
73
69
|
# Filter based on median GSS score
|
74
|
-
high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair[
|
75
|
-
|
70
|
+
high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair[
|
71
|
+
high_GSS_Gene_annotation_pair["Median_GSS"] >= 1.0
|
72
|
+
]
|
73
|
+
high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair.merge(
|
74
|
+
corr, left_on="Gene", right_index=True
|
75
|
+
)
|
76
76
|
|
77
77
|
# Prepare the final gene diagnostic info dataframe
|
78
|
-
gene_diagnostic_info_cols = [
|
79
|
-
gene_diagnostic_info =
|
80
|
-
|
81
|
-
|
78
|
+
gene_diagnostic_info_cols = ["Gene", "Annotation", "Median_GSS", "PCC"]
|
79
|
+
gene_diagnostic_info = (
|
80
|
+
high_GSS_Gene_annotation_pair[gene_diagnostic_info_cols]
|
81
|
+
.drop_duplicates()
|
82
|
+
.dropna(subset=["Gene"])
|
83
|
+
)
|
84
|
+
gene_diagnostic_info.sort_values("PCC", ascending=False, inplace=True)
|
82
85
|
|
83
86
|
# Save gene diagnostic info to a file
|
84
87
|
gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
|
85
88
|
gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
|
86
|
-
logger.info(f
|
89
|
+
logger.info(f"Gene diagnostic information saved to {gene_diagnostic_info_save_path}.")
|
87
90
|
|
88
91
|
# TODO: A new script is needed to save the gene diagnostic info to adata.var and trait_ldsc_result to adata.obs when running multiple traits
|
89
92
|
# # Save to adata.var with the trait_name prefix
|
@@ -101,114 +104,180 @@ def compute_gene_diagnostic_info(config: DiagnosisConfig):
|
|
101
104
|
return gene_diagnostic_info.reset_index()
|
102
105
|
|
103
106
|
|
104
|
-
def load_gwas_data(config: DiagnosisConfig):
    """
    Read the gzip-compressed, tab-separated summary statistics in
    ``config.sumstats_file`` and add a ``P`` column via ``convert_z_to_p``.
    """
    logger.info("Loading and processing GWAS data...")
    sumstats = pd.read_csv(config.sumstats_file, sep="\t", compression="gzip")
    return convert_z_to_p(sumstats)
|
109
112
|
|
110
113
|
|
111
|
-
def load_snp_gene_pairs(config: DiagnosisConfig):
    """
    Concatenate the per-chromosome SNP-gene pair tables for chromosomes 1-22.

    Tables are read from
    ``<ldscore_save_dir>/SNP_gene_pair/SNP_gene_pair_chr<N>.feather``. Each
    table keeps its own row index (it is not renumbered).
    """
    pair_dir = Path(config.ldscore_save_dir) / "SNP_gene_pair"
    chromosome_files = [pair_dir / f"SNP_gene_pair_chr{chrom}.feather" for chrom in range(1, 23)]
    return pd.concat([pd.read_feather(path) for path in chromosome_files])
|
118
123
|
|
119
124
|
|
120
125
|
def filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER):
    """
    Choose which SNPs to plot.

    If more than ``SUBSAMPLE_SNP_NUMBER`` SNPs pass the suggestive threshold
    (P < 1e-5), all of them are returned. Otherwise the first
    ``SUBSAMPLE_SNP_NUMBER`` rows are returned. This assumes the input is
    sorted by ascending ``P``, as ``generate_manhattan_plot`` does before calling.

    Returns the ``SNP`` column of the selected rows.
    """
    suggestive = gwas_data_with_gene_annotation_sort["P"] < 1e-5
    n_suggestive = suggestive.sum()

    if n_suggestive <= SUBSAMPLE_SNP_NUMBER:
        snps2plot = gwas_data_with_gene_annotation_sort.head(SUBSAMPLE_SNP_NUMBER).SNP
        logger.info(
            f"To reduce the number of SNPs to plot, only {SUBSAMPLE_SNP_NUMBER} SNPs with the smallest P-values are plotted."
        )
        return snps2plot

    snps2plot = gwas_data_with_gene_annotation_sort[suggestive].SNP
    logger.info(
        f"To reduce the number of SNPs to plot, only {snps2plot.shape[0]} SNPs with P < 1e-5 are plotted."
    )
    return snps2plot
|
134
142
|
|
135
143
|
|
136
144
|
def generate_manhattan_plot(config: DiagnosisConfig):
    """
    Build the diagnostic Manhattan plot for ``config.trait_name`` and write it as HTML.

    SNPs are mapped to genes through the SNP-gene pair tables, then annotated
    with each gene's diagnostic info (``Annotation`` and ``PCC``). Only SNPs
    whose gene has an annotation are kept, and ``filter_snps`` subsamples them.
    The highlighted genes are ``config.selected_genes``, or, when that is
    empty/None, the first ``config.top_corr_genes`` rows of the gene diagnostic
    table.
    """
    gwas = load_gwas_data(config)
    snp_gene_pair = load_snp_gene_pairs(config)
    gwas_with_gene = snp_gene_pair.merge(gwas, on="SNP", how="inner").rename(
        columns={"gene_name": "GENE"}
    )
    gene_info = load_gene_diagnostic_info(config)
    annotated = gwas_with_gene.merge(gene_info, left_on="GENE", right_on="Gene", how="left")
    annotated = annotated[annotated["Annotation"].notna()]

    by_p = annotated.sort_values("P")
    snps2plot = filter_snps(by_p, SUBSAMPLE_SNP_NUMBER=100_000)

    to_plot = annotated[annotated["SNP"].isin(snps2plot)].reset_index(drop=True)
    pcc_text = to_plot["PCC"].round(2).astype(str)
    annotation_text = to_plot["Annotation"].astype(str)
    to_plot["Annotation_text"] = "PCC: " + pcc_text + "<br>" + "Annotation: " + annotation_text

    highlight_genes = (
        config.selected_genes or gene_info.Gene.iloc[: config.top_corr_genes].tolist()
    )
    fig = ManhattanPlot(
        dataframe=to_plot,
        title="gsMap Diagnosis Manhattan Plot",
        point_size=3,
        highlight_gene_list=highlight_genes,
        suggestiveline_value=-np.log10(1e-5),
        annotation="Annotation_text",
    )

    html_path = config.get_manhattan_html_plot_path(config.trait_name)
    fig.write_html(html_path)
    logger.info(f"Diagnostic Manhattan Plot saved to {html_path}.")
|
168
187
|
|
169
188
|
|
170
189
|
def generate_GSS_distribution(config: DiagnosisConfig):
    """
    Plot the spatial distribution of expression and GSS for a set of genes.

    The genes are ``config.selected_genes`` when it is non-empty. Otherwise they
    are the first ``config.top_corr_genes`` genes of the gene diagnostic table.
    The chosen gene list is written to
    ``config.get_GSS_plot_select_gene_file``, and the plots for each gene go to
    ``config.get_GSS_plot_dir``.

    Uses the module-level ``adata`` loaded by ``run_Diagnosis``.
    """
    # logger.info('Loading ST data...')
    # adata = sc.read_h5ad(config.hdf5_with_latent_path)
    mk_score = pd.read_feather(config.mkscore_feather_path).set_index("HUMAN_GENE_SYM").T

    # Use the same truthiness test for choosing the genes and for the log message,
    # so an empty selected_genes list is reported as "top correlated genes".
    if config.selected_genes:
        plot_genes = config.selected_genes
        logger.info(
            f"Generating GSS & Expression distribution plot for selected genes: {plot_genes}..."
        )
    else:
        plot_genes = load_gene_diagnostic_info(config).Gene.iloc[: config.top_corr_genes].tolist()
        logger.info(
            f"Generating GSS & Expression distribution plot for top {config.top_corr_genes} correlated genes..."
        )

    if config.customize_fig:
        pixel_width, pixel_height, point_size = (
            config.fig_width,
            config.fig_height,
            config.point_size,
        )
    else:
        (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(
            adata.obsm["spatial"]
        )
    sub_fig_save_dir = config.get_GSS_plot_dir(config.trait_name)

    # save plot gene list
    config.get_GSS_plot_select_gene_file(config.trait_name).write_text("\n".join(plot_genes))

    for selected_gene in plot_genes:
        gene_x = adata[:, selected_gene].X
        # adata.X may be a sparse matrix or a dense array depending on how it was stored.
        gene_values = gene_x.toarray() if hasattr(gene_x, "toarray") else np.asarray(gene_x)
        expression_series = pd.Series(
            gene_values.flatten(), index=adata.obs.index, name="Expression"
        )
        # Cap values above the 99.99th percentile.
        threshold = np.quantile(expression_series, 0.9999)
        expression_series = expression_series.clip(upper=threshold)
        generate_and_save_plots(
            adata,
            mk_score,
            expression_series,
            selected_gene,
            point_size,
            pixel_width,
            pixel_height,
            sub_fig_save_dir,
            config.sample_name,
            config.annotation,
        )
|
241
|
+
|
242
|
+
|
243
|
+
def generate_and_save_plots(
|
244
|
+
adata,
|
245
|
+
mk_score,
|
246
|
+
expression_series,
|
247
|
+
selected_gene,
|
248
|
+
point_size,
|
249
|
+
pixel_width,
|
250
|
+
pixel_height,
|
251
|
+
sub_fig_save_dir,
|
252
|
+
sample_name,
|
253
|
+
annotation,
|
254
|
+
):
|
201
255
|
"""Generate and save the plots."""
|
202
256
|
select_gene_expression_with_space_coord = load_st_coord(adata, expression_series, annotation)
|
203
|
-
sub_fig_1 = draw_scatter(
|
204
|
-
|
205
|
-
|
206
|
-
|
257
|
+
sub_fig_1 = draw_scatter(
|
258
|
+
select_gene_expression_with_space_coord,
|
259
|
+
title=f"{selected_gene} (Expression)",
|
260
|
+
annotation="annotation",
|
261
|
+
color_by="Expression",
|
262
|
+
point_size=point_size,
|
263
|
+
width=pixel_width,
|
264
|
+
height=pixel_height,
|
265
|
+
)
|
266
|
+
save_plot(sub_fig_1, sub_fig_save_dir, sample_name, selected_gene, "Expression")
|
207
267
|
|
208
|
-
select_gene_GSS_with_space_coord = load_st_coord(
|
209
|
-
|
210
|
-
|
211
|
-
|
268
|
+
select_gene_GSS_with_space_coord = load_st_coord(
|
269
|
+
adata, mk_score[selected_gene].rename("GSS"), annotation
|
270
|
+
)
|
271
|
+
sub_fig_2 = draw_scatter(
|
272
|
+
select_gene_GSS_with_space_coord,
|
273
|
+
title=f"{selected_gene} (GSS)",
|
274
|
+
annotation="annotation",
|
275
|
+
color_by="GSS",
|
276
|
+
point_size=point_size,
|
277
|
+
width=pixel_width,
|
278
|
+
height=pixel_height,
|
279
|
+
)
|
280
|
+
save_plot(sub_fig_2, sub_fig_save_dir, sample_name, selected_gene, "GSS")
|
212
281
|
|
213
282
|
# combined_fig = make_subplots(rows=1, cols=2,
|
214
283
|
# subplot_titles=(f'{selected_gene} (Expression)', f'{selected_gene} (GSS)'))
|
@@ -218,57 +287,66 @@ def generate_and_save_plots(adata, mk_score, expression_series, selected_gene, p
|
|
218
287
|
# combined_fig.add_trace(trace, row=1, col=2)
|
219
288
|
#
|
220
289
|
|
290
|
+
|
221
291
|
def save_plot(sub_fig, sub_fig_save_dir, sample_name, selected_gene, plot_type):
    """
    Save a distribution figure as a PNG image, with the legend hidden.

    The file is written to
    ``<sub_fig_save_dir>/<sample_name>_<selected_gene>_<plot_type>_Distribution.png``.
    HTML export is currently disabled.

    Parameters
    ----------
    sub_fig : figure object
        Must provide ``update_layout`` and ``write_image`` (e.g. a plotly Figure).
    sub_fig_save_dir : Path
        Output directory.
    sample_name, selected_gene, plot_type : str
        Used to build the file name.
    """
    save_sub_fig_path = (
        sub_fig_save_dir / f"{sample_name}_{selected_gene}_{plot_type}_Distribution.html"
    )
    # HTML export disabled: sub_fig.write_html(str(save_sub_fig_path))
    sub_fig.update_layout(showlegend=False)
    # Swap only the file extension. str(...).replace(".html", ".png") would also
    # rewrite any ".html" that appears in the directory part of the path.
    sub_fig.write_image(str(save_sub_fig_path.with_suffix(".png")))
|
227
299
|
|
228
300
|
|
229
301
|
def generate_gsMap_plot(config: DiagnosisConfig):
    """
    Draw the spatial gsMap plot of the LDSC results for ``config.trait_name``.

    Writes the figure as HTML and PNG, plus the plotted table as CSV. All three
    share the stem of ``config.get_gsMap_html_plot_save_path``. Uses the
    module-level ``adata`` set by ``run_Diagnosis``.
    """
    logger.info("Creating gsMap plot...")

    ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
    plot_table = load_st_coord(adata, ldsc_result, annotation=config.annotation)

    if not config.customize_fig:
        (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(
            adata.obsm["spatial"]
        )
    else:
        pixel_width = config.fig_width
        pixel_height = config.fig_height
        point_size = config.point_size

    fig = draw_scatter(
        plot_table,
        title=f"{config.trait_name} (gsMap)",
        point_size=point_size,
        width=pixel_width,
        height=pixel_height,
        annotation=config.annotation,
    )

    save_dir = config.get_gsMap_plot_save_dir(config.trait_name)
    html_path = config.get_gsMap_html_plot_save_path(config.trait_name)

    fig.write_html(html_path)
    fig.write_image(html_path.with_suffix(".png"))
    plot_table.to_csv(html_path.with_suffix(".csv"))

    logger.info(f"gsMap plot created and saved in {save_dir}.")
|
258
337
|
|
259
338
|
|
260
339
|
def run_Diagnosis(config: DiagnosisConfig):
    """
    Entry point for generating the diagnostic plots.

    Loads ``config.hdf5_with_latent_path`` into the module-level ``adata``,
    which the plotting functions read. If ``adata.uns`` has no ``log1p`` entry
    and the maximum of X exceeds 14, the data is normalized
    (``normalize_total``, target 1e4) and log1p-transformed. This is presumably
    a heuristic for raw counts.

    The plots are then produced in this order, each gated by
    ``config.plot_type``: "gsMap", "manhattan", "GSS". The value "all" enables
    every plot.
    """
    global adata
    adata = sc.read_h5ad(config.hdf5_with_latent_path)

    needs_normalization = "log1p" not in adata.uns.keys() and adata.X.max() > 14
    if needs_normalization:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

    plot_steps = (
        ("gsMap", generate_gsMap_plot),
        ("manhattan", generate_manhattan_plot),
        ("GSS", generate_GSS_distribution),
    )
    for plot_name, plot_func in plot_steps:
        if config.plot_type in (plot_name, "all"):
            plot_func(config)