gsMap 1.62__py3-none-any.whl → 1.64__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN_VAE/adjacency_matrix.py +1 -1
- gsMap/GNN_VAE/model.py +5 -5
- gsMap/GNN_VAE/train.py +1 -1
- gsMap/__init__.py +1 -1
- gsMap/cauchy_combination_test.py +14 -36
- gsMap/config.py +473 -404
- gsMap/diagnosis.py +273 -0
- gsMap/find_latent_representation.py +22 -86
- gsMap/format_sumstats.py +79 -82
- gsMap/generate_ldscore.py +145 -78
- gsMap/latent_to_gene.py +65 -104
- gsMap/main.py +1 -9
- gsMap/report.py +160 -0
- gsMap/run_all_mode.py +195 -0
- gsMap/spatial_ldsc_multiple_sumstats.py +188 -113
- gsMap/templates/report_template.html +198 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/{generate_r2_matrix.py → utils/generate_r2_matrix.py} +2 -10
- gsMap/{make_annotations.py → utils/make_annotations.py} +1 -43
- gsMap/utils/manhattan_plot.py +639 -0
- gsMap/{regression_read.py → utils/regression_read.py} +1 -1
- gsMap/visualize.py +100 -55
- {gsmap-1.62.dist-info → gsmap-1.64.dist-info}/METADATA +21 -46
- gsmap-1.64.dist-info/RECORD +30 -0
- gsmap-1.62.dist-info/RECORD +0 -24
- /gsMap/{jackknife.py → utils/jackknife.py} +0 -0
- {gsmap-1.62.dist-info → gsmap-1.64.dist-info}/LICENSE +0 -0
- {gsmap-1.62.dist-info → gsmap-1.64.dist-info}/WHEEL +0 -0
- {gsmap-1.62.dist-info → gsmap-1.64.dist-info}/entry_points.txt +0 -0
gsMap/diagnosis.py
ADDED
@@ -0,0 +1,273 @@
|
|
1
|
+
import logging
|
2
|
+
import warnings
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import pandas as pd
|
7
|
+
import scanpy as sc
|
8
|
+
from scipy.stats import norm
|
9
|
+
|
10
|
+
from gsMap.config import DiagnosisConfig
|
11
|
+
from gsMap.utils.manhattan_plot import ManhattanPlot
|
12
|
+
from gsMap.visualize import draw_scatter, load_st_coord, estimate_point_size_for_plot
|
13
|
+
|
14
|
+
|
15
|
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
def convert_z_to_p(gwas_data):
    """Add a two-sided P-value column ``P`` derived from the ``Z`` column.

    The input frame is modified in place and also returned. P-values are
    floored at 1e-300 so that none of them is exactly zero.
    """
    p_floor = 1e-300
    # Two-sided tail probability of the standard normal for |Z|.
    two_sided = 2 * norm.sf(gwas_data['Z'].abs())
    gwas_data['P'] = np.clip(two_sided, p_floor, None)
    return gwas_data
|
25
|
+
|
26
|
+
|
27
|
+
def load_ldsc(ldsc_input_file):
    """Load a gzipped LDSC result CSV and add a ``logp`` column.

    Parameters
    ----------
    ldsc_input_file
        Path to a gzip-compressed CSV containing at least the columns
        ``spot`` and ``p``.

    Returns
    -------
    pandas.DataFrame
        The table indexed by ``spot`` (as strings), with an extra column
        ``logp`` equal to ``-log10(p)``.
    """
    ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
    # Spot ids that were parsed as floats come back as e.g. "12.0"; strip only
    # a *trailing* ".0" so ids that merely contain ".0" (e.g. "s1.05") are
    # left untouched.
    ldsc['spot'] = ldsc['spot'].astype(str).str.replace(r'\.0$', '', regex=True)
    ldsc.set_index('spot', inplace=True)
    ldsc['logp'] = -np.log10(ldsc['p'])
    return ldsc
|
34
|
+
|
35
|
+
|
36
|
+
def load_gene_diagnostic_info(config:DiagnosisConfig):
    """Return the gene diagnostic table for ``config.trait_name``.

    The table is read from the CSV at the path returned by
    ``config.get_gene_diagnostic_info_save_path`` when that file exists;
    otherwise it is produced by ``compute_gene_diagnostic_info``.
    """
    cached_csv = config.get_gene_diagnostic_info_save_path(config.trait_name)

    if not cached_csv.exists():
        logger.info('Gene diagnostic information not found. Calculating gene diagnostic information...')
        return compute_gene_diagnostic_info(config)

    logger.info(f'Loading gene diagnostic information from {cached_csv}...')
    return pd.read_csv(cached_csv)
|
45
|
+
|
46
|
+
|
47
|
+
def compute_gene_diagnostic_info(config: DiagnosisConfig):
    """Compute per-gene diagnostic statistics, persist them, and return them.

    For every gene the function records the annotation group with the highest
    median marker score (``Annotation`` / ``Median_GSS``) and the correlation
    of its marker score with the trait's ``logp`` across spots (``PCC``).
    Genes whose best median score is below 1.0 are dropped.

    Side effects: the table is written to the CSV path given by
    ``config.get_gene_diagnostic_info_save_path``; three ``adata.var`` columns
    and one ``adata.obs`` column are added, and ``adata`` is written to
    ``config.hdf5_with_latent_path``.

    NOTE: this function reads the module-level ``adata`` that ``run_Diagnosis``
    assigns; it is not loaded here.

    Returns the table with columns ``Gene``, ``Annotation``, ``Median_GSS``
    and ``PCC``, sorted by ``PCC`` in descending order.
    """
    trait = config.trait_name

    logger.info('Loading ST data and LDSC results...')
    # Feather file is gene x spot; transpose to spot x gene.
    marker_scores = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM').T
    ldsc_result = load_ldsc(config.get_ldsc_result_file(trait))

    # Keep only the spots present in the LDSC result (same order), then drop
    # genes whose scores sum to zero over those spots.
    marker_scores = marker_scores.loc[ldsc_result.index]
    nonzero_genes = marker_scores.sum(axis=0) != 0
    marker_scores = marker_scores.loc[:, nonzero_genes]

    logger.info('Calculating correlation between gene marker scores and trait logp-values...')
    pcc = marker_scores.corrwith(ldsc_result['logp']).rename('PCC')

    # Median score of each gene within each annotation group.
    median_by_annotation = marker_scores.groupby(adata.obs[config.annotation]).median()
    best_annotation = median_by_annotation.idxmax()

    gene_annotation = pd.DataFrame({
        'Gene': best_annotation.index,
        'Annotation': best_annotation.values,
        'Median_GSS': median_by_annotation.max().values,
    })

    # Filter on the median GSS score, then attach the correlation per gene.
    passes_threshold = gene_annotation['Median_GSS'] >= 1.0
    gene_annotation = gene_annotation[passes_threshold].merge(pcc, left_on='Gene', right_index=True)

    # Final table: fixed column order, de-duplicated, sorted by correlation.
    columns = ['Gene', 'Annotation', 'Median_GSS', 'PCC']
    diagnostic_table = gene_annotation[columns].drop_duplicates().dropna(subset=['Gene'])
    diagnostic_table = diagnostic_table.sort_values('PCC', ascending=False)

    save_path = config.get_gene_diagnostic_info_save_path(trait)
    diagnostic_table.to_csv(save_path, index=False)
    logger.info(f'Gene diagnostic information saved to {save_path}.')

    # Index by gene so the assignments below align with adata.var's index.
    logger.info('Saving gene diagnostic info to adata.var...')
    diagnostic_table = diagnostic_table.set_index('Gene')
    for column in ('Annotation', 'Median_GSS', 'PCC'):
        adata.var[f'{trait}_{column}'] = diagnostic_table[column]

    logger.info(f'Saving trait LDSC results to adata.obs as gsMap_{trait}_p_value...')
    adata.obs[f'gsMap_{trait}_p_value'] = ldsc_result['p']
    adata.write(config.hdf5_with_latent_path)

    return diagnostic_table.reset_index()
|
101
|
+
|
102
|
+
|
103
|
+
def load_gwas_data(config:DiagnosisConfig):
    """Read the gzipped, tab-separated summary statistics and add P-values.

    The file at ``config.sumstats_file`` is loaded and passed through
    ``convert_z_to_p``, whose result is returned.
    """
    logger.info('Loading and processing GWAS data...')
    sumstats = pd.read_csv(config.sumstats_file, sep='\t', compression='gzip')
    return convert_z_to_p(sumstats)
|
108
|
+
|
109
|
+
|
110
|
+
def load_snp_gene_pairs(config:DiagnosisConfig):
    """Concatenate the per-chromosome SNP-gene pair tables for chr1..chr22.

    Each table is read from
    ``<config.ldscore_save_dir>/SNP_gene_pair/SNP_gene_pair_chr<N>.feather``.
    """
    base_dir = Path(config.ldscore_save_dir)
    per_chromosome = []
    for chrom in range(1, 23):
        per_chromosome.append(
            pd.read_feather(base_dir / f'SNP_gene_pair/SNP_gene_pair_chr{chrom}.feather')
        )
    return pd.concat(per_chromosome)
|
117
|
+
|
118
|
+
|
119
|
+
def filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER):
    """Choose which SNP ids to plot.

    When more than ``SUBSAMPLE_SNP_NUMBER`` rows have ``P < 1e-5``, all of
    those rows are selected. Otherwise the first ``SUBSAMPLE_SNP_NUMBER`` rows
    of the frame are selected (the caller passes a frame sorted by ``P``).

    Returns the ``SNP`` column of the selected rows.
    """
    below_suggestive = gwas_data_with_gene_annotation_sort['P'] < 1e-5

    if below_suggestive.sum() <= SUBSAMPLE_SNP_NUMBER:
        leading_snps = gwas_data_with_gene_annotation_sort.head(SUBSAMPLE_SNP_NUMBER)['SNP']
        logger.info(
            f'To reduce the number of SNPs to plot, only {SUBSAMPLE_SNP_NUMBER} SNPs with the smallest P-values are plotted.')
        return leading_snps

    suggestive_snps = gwas_data_with_gene_annotation_sort[below_suggestive]['SNP']
    logger.info(f'To reduce the number of SNPs to plot, only {suggestive_snps.shape[0]} SNPs with P < 1e-5 are plotted.')
    return suggestive_snps
|
133
|
+
|
134
|
+
|
135
|
+
def generate_manhattan_plot(config: DiagnosisConfig):
    """Build the diagnostic Manhattan plot for the trait and save it as HTML.

    GWAS SNPs are joined to genes and to the gene diagnostic table, SNPs
    without an annotation are dropped, the remainder is subsampled with
    ``filter_snps``, and the figure is written to the path returned by
    ``config.get_manhattan_html_plot_path``.
    """
    # The return value is unused here; the call is kept because any side
    # effects of get_report_dir are not visible from this module.
    report_save_dir = config.get_report_dir(config.trait_name)

    gwas = load_gwas_data(config)
    snp_to_gene = load_snp_gene_pairs(config)
    gwas_with_gene = snp_to_gene.merge(gwas, on='SNP', how='inner').rename(columns={'gene_name': 'GENE'})

    gene_info = load_gene_diagnostic_info(config)
    annotated = gwas_with_gene.merge(gene_info, left_on='GENE', right_on='Gene', how='left')
    # Keep only SNPs whose gene has an entry in the diagnostic table.
    annotated = annotated[annotated['Annotation'].notna()]

    sorted_by_p = annotated.sort_values('P')
    snps_to_plot = filter_snps(sorted_by_p, SUBSAMPLE_SNP_NUMBER=100_000)

    plot_frame = annotated[annotated['SNP'].isin(snps_to_plot)].reset_index(drop=True)
    pcc_text = plot_frame['PCC'].round(2).astype(str)
    annotation_text = plot_frame['Annotation'].astype(str)
    plot_frame['Annotation_text'] = 'PCC: ' + pcc_text + '<br>' + 'Annotation: ' + annotation_text

    # User-selected genes take precedence; otherwise use the leading rows of
    # the diagnostic table.
    genes_to_highlight = config.selected_genes or gene_info.Gene.iloc[:config.top_corr_genes].tolist()

    fig = ManhattanPlot(
        dataframe=plot_frame,
        title='gsMap Diagnosis Manhattan Plot',
        point_size=3,
        highlight_gene_list=genes_to_highlight,
        suggestiveline_value=-np.log10(1e-5),
        annotation='Annotation_text',
    )

    html_path = config.get_manhattan_html_plot_path(config.trait_name)
    fig.write_html(html_path)
    logger.info(f'Diagnostic Manhattan Plot saved to {html_path}.')
|
167
|
+
|
168
|
+
|
169
|
+
def generate_GSS_distribution(config: DiagnosisConfig):
    """Generate expression and GSS distribution plots for a set of genes.

    The genes are ``config.selected_genes`` when that is non-empty, otherwise
    the first ``config.top_corr_genes`` genes of the gene diagnostic table.
    The chosen gene list is written to the file returned by
    ``config.get_GSS_plot_select_gene_file`` and one pair of plots per gene is
    produced through ``generate_and_save_plots``.

    NOTE: this function reads the module-level ``adata`` that ``run_Diagnosis``
    assigns; it is not loaded here.
    """
    mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM').T

    plot_genes = config.selected_genes or load_gene_diagnostic_info(config).Gene.iloc[:config.top_corr_genes].tolist()
    # Use the same truthiness test as the selection above so that an empty
    # selection is reported as the top-correlated fallback it actually is.
    if config.selected_genes:
        logger.info(f'Generating GSS & Expression distribution plot for selected genes: {plot_genes}...')
    else:
        logger.info(f'Generating GSS & Expression distribution plot for top {config.top_corr_genes} correlated genes...')

    if config.customize_fig:
        pixel_width, pixel_height, point_size = config.fig_width, config.fig_height, config.point_size
    else:
        (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(adata.obsm['spatial'])
    sub_fig_save_dir = config.get_GSS_plot_dir(config.trait_name)

    # Save the list of plotted genes, one per line.
    config.get_GSS_plot_select_gene_file(config.trait_name).write_text('\n'.join(plot_genes))

    for selected_gene in plot_genes:
        gene_values = adata[:, selected_gene].X
        # adata.X may be a sparse matrix or a dense array; only sparse
        # matrices provide toarray().
        if hasattr(gene_values, 'toarray'):
            gene_values = gene_values.toarray()
        expression_series = pd.Series(np.asarray(gene_values).flatten(), index=adata.obs.index, name='Expression')

        # Cap values above the 99.99th percentile at that percentile.
        threshold = np.quantile(expression_series, 0.9999)
        expression_series[expression_series > threshold] = threshold

        generate_and_save_plots(adata, mk_score, expression_series, selected_gene, point_size, pixel_width,
                                pixel_height, sub_fig_save_dir, config.sample_name, config.annotation)
|
196
|
+
|
197
|
+
|
198
|
+
def generate_and_save_plots(adata, mk_score, expression_series, selected_gene, point_size, pixel_width, pixel_height,
                            sub_fig_save_dir, sample_name, annotation):
    """Draw and save the expression and GSS scatter plots for one gene.

    The expression plot is produced first from ``expression_series``; the GSS
    plot is produced second from the ``selected_gene`` column of ``mk_score``.
    Both are saved through ``save_plot``.
    """

    def _draw_and_save(values, label):
        # Attach coordinates, draw the scatter coloured by `label`, save it.
        with_coords = load_st_coord(adata, values, annotation)
        figure = draw_scatter(with_coords, title=f'{selected_gene} ({label})',
                              annotation='annotation', color_by=label, point_size=point_size,
                              width=pixel_width, height=pixel_height)
        save_plot(figure, sub_fig_save_dir, sample_name, selected_gene, label)

    _draw_and_save(expression_series, 'Expression')
    _draw_and_save(mk_score[selected_gene].rename('GSS'), 'GSS')
|
211
|
+
|
212
|
+
# combined_fig = make_subplots(rows=1, cols=2,
|
213
|
+
# subplot_titles=(f'{selected_gene} (Expression)', f'{selected_gene} (GSS)'))
|
214
|
+
# for trace in sub_fig_1.data:
|
215
|
+
# combined_fig.add_trace(trace, row=1, col=1)
|
216
|
+
# for trace in sub_fig_2.data:
|
217
|
+
# combined_fig.add_trace(trace, row=1, col=2)
|
218
|
+
#
|
219
|
+
|
220
|
+
def save_plot(sub_fig, sub_fig_save_dir, sample_name, selected_gene, plot_type):
    """Save the plot as a PNG image with its legend hidden.

    The image is written to
    ``<sub_fig_save_dir>/<sample_name>_<selected_gene>_<plot_type>_Distribution.png``.
    ``sub_fig_save_dir`` must support the ``/`` path operator.
    """
    # Build the .png path directly: replacing '.html' in the full path string
    # would also rewrite a '.html' occurring in a directory, sample or gene name.
    png_path = sub_fig_save_dir / f'{sample_name}_{selected_gene}_{plot_type}_Distribution.png'
    sub_fig.update_layout(showlegend=False)
    sub_fig.write_image(str(png_path))
|
226
|
+
|
227
|
+
|
228
|
+
def generate_gsMap_plot(config: DiagnosisConfig):
    """Draw the gsMap scatter plot for the trait and save it.

    The figure is written as HTML and PNG, and the plotted table as CSV; the
    PNG and CSV paths are the HTML path with its suffix replaced.

    NOTE: this function reads the module-level ``adata`` that ``run_Diagnosis``
    assigns; it is not loaded here.
    """
    logger.info('Creating gsMap plot...')
    trait = config.trait_name

    ldsc_result = load_ldsc(config.get_ldsc_result_file(trait))
    plot_table = load_st_coord(adata, ldsc_result, annotation=config.annotation)

    # Figure geometry: user-supplied values, or estimated from the coordinates.
    if not config.customize_fig:
        (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(adata.obsm['spatial'])
    else:
        pixel_width, pixel_height, point_size = config.fig_width, config.fig_height, config.point_size

    fig = draw_scatter(
        plot_table,
        title=f'{trait} (gsMap)',
        point_size=point_size,
        width=pixel_width,
        height=pixel_height,
        annotation=config.annotation,
    )

    output_dir = config.get_gsMap_plot_save_dir(trait)
    html_path = config.get_gsMap_html_plot_save_path(trait)

    fig.write_html(html_path)
    fig.write_image(html_path.with_suffix('.png'))
    plot_table.to_csv(html_path.with_suffix('.csv'))

    logger.info(f'gsMap plot created and saved in {output_dir}.')
|
257
|
+
|
258
|
+
|
259
|
+
def run_Diagnosis(config: DiagnosisConfig):
    """Entry point: load the ST data and generate the requested plots.

    ``config.plot_type`` selects ``'manhattan'``, ``'GSS'``, ``'gsMap'`` or
    ``'all'``; the plots are generated in that order.
    """
    # The plotting functions in this module read `adata` as a module global.
    global adata
    adata = sc.read_h5ad(config.hdf5_with_latent_path)

    # Normalise and log-transform only when no 'log1p' entry is recorded and
    # the maximum value exceeds 14 (presumably a raw-count heuristic -- confirm).
    needs_normalisation = 'log1p' not in adata.uns.keys() and adata.X.max() > 14
    if needs_normalisation:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

    plot_generators = (
        ('manhattan', generate_manhattan_plot),
        ('GSS', generate_GSS_distribution),
        ('gsMap', generate_gsMap_plot),
    )
    for plot_name, generate in plot_generators:
        if config.plot_type in (plot_name, 'all'):
            generate(config)
|
273
|
+
|
@@ -1,9 +1,5 @@
|
|
1
|
-
import argparse
|
2
1
|
import logging
|
3
|
-
import pprint
|
4
2
|
import random
|
5
|
-
import time
|
6
|
-
from pathlib import Path
|
7
3
|
|
8
4
|
import numpy as np
|
9
5
|
import pandas as pd
|
@@ -13,17 +9,9 @@ from sklearn import preprocessing
|
|
13
9
|
|
14
10
|
from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix
|
15
11
|
from gsMap.GNN_VAE.train import Model_Train
|
16
|
-
from gsMap.config import
|
17
|
-
|
18
|
-
# seed all
|
12
|
+
from gsMap.config import FindLatentRepresentationsConfig
|
19
13
|
|
20
14
|
logger = logging.getLogger(__name__)
|
21
|
-
logger.setLevel(logging.DEBUG)
|
22
|
-
handler = logging.StreamHandler()
|
23
|
-
handler.setFormatter(logging.Formatter(
|
24
|
-
'[{asctime}] {levelname:8s} {filename} {message}', style='{'))
|
25
|
-
logger.addHandler(handler)
|
26
|
-
|
27
15
|
|
28
16
|
def set_seed(seed_value):
|
29
17
|
"""
|
@@ -33,30 +21,30 @@ def set_seed(seed_value):
|
|
33
21
|
np.random.seed(seed_value) # Set the seed for NumPy
|
34
22
|
random.seed(seed_value) # Set the seed for Python random module
|
35
23
|
if torch.cuda.is_available():
|
36
|
-
|
24
|
+
logger.info('Running use GPU')
|
37
25
|
torch.cuda.manual_seed(seed_value) # Set seed for all CUDA devices
|
38
26
|
torch.cuda.manual_seed_all(seed_value) # Set seed for all CUDA devices
|
39
27
|
else:
|
40
|
-
|
28
|
+
logger.info('Running use CPU')
|
41
29
|
|
42
|
-
set_seed(2024)
|
43
30
|
|
44
31
|
# The class for finding latent representations
|
45
32
|
class Latent_Representation_Finder:
|
46
33
|
|
47
|
-
def __init__(self, adata,
|
34
|
+
def __init__(self, adata, args:FindLatentRepresentationsConfig):
|
48
35
|
self.adata = adata.copy()
|
49
|
-
self.Params =
|
36
|
+
self.Params = args
|
50
37
|
|
51
38
|
# Standard process
|
52
|
-
if self.Params.
|
53
|
-
self.adata.X = self.adata.layers[self.Params.
|
39
|
+
if self.Params.data_layer == 'count' or self.Params.data_layer == 'counts':
|
40
|
+
self.adata.X = self.adata.layers[self.Params.data_layer]
|
54
41
|
sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3", n_top_genes=self.Params.feat_cell)
|
55
42
|
sc.pp.normalize_total(self.adata, target_sum=1e4)
|
56
43
|
sc.pp.log1p(self.adata)
|
57
44
|
sc.pp.scale(self.adata)
|
58
45
|
else:
|
59
|
-
self.
|
46
|
+
if self.Params.data_layer != 'X':
|
47
|
+
self.adata.X = self.adata.layers[self.Params.data_layer]
|
60
48
|
sc.pp.highly_variable_genes(self.adata, n_top_genes=self.Params.feat_cell)
|
61
49
|
|
62
50
|
def Run_GNN_VAE(self, label, verbose='whole ST data'):
|
@@ -66,7 +54,7 @@ class Latent_Representation_Finder:
|
|
66
54
|
|
67
55
|
# Process the feature matrix
|
68
56
|
node_X = self.adata[:, self.adata.var.highly_variable].X
|
69
|
-
|
57
|
+
logger.info(f'The shape of feature matrix is {node_X.shape}.')
|
70
58
|
if self.Params.input_pca:
|
71
59
|
node_X = sc.pp.pca(node_X, n_comps=self.Params.n_comps)
|
72
60
|
|
@@ -75,7 +63,7 @@ class Latent_Representation_Finder:
|
|
75
63
|
self.Params.feat_cell = node_X.shape[1]
|
76
64
|
|
77
65
|
# Run GNN-VAE
|
78
|
-
|
66
|
+
logger.info(f'------Finding latent representations for {verbose}...')
|
79
67
|
gvae = Model_Train(node_X, graph_dict, self.Params, label)
|
80
68
|
gvae.run_train()
|
81
69
|
|
@@ -87,20 +75,21 @@ class Latent_Representation_Finder:
|
|
87
75
|
|
88
76
|
|
89
77
|
def run_find_latent_representation(args:FindLatentRepresentationsConfig):
|
78
|
+
set_seed(2024)
|
90
79
|
num_features = args.feat_cell
|
91
|
-
args.
|
92
|
-
args.output_dir.mkdir(parents=True, exist_ok=True,mode=0o755)
|
80
|
+
args.hdf5_with_latent_path.parent.mkdir(parents=True, exist_ok=True,mode=0o755)
|
93
81
|
# Load the ST data
|
94
|
-
|
82
|
+
logger.info(f'------Loading ST data of {args.sample_name}...')
|
95
83
|
adata = sc.read_h5ad(f'{args.input_hdf5_path}')
|
96
84
|
adata.var_names_make_unique()
|
97
|
-
|
85
|
+
adata.X = adata.layers[args.data_layer] if args.data_layer in adata.layers.keys() else adata.X
|
86
|
+
logger.info('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))
|
98
87
|
# Load the cell type annotation
|
99
88
|
if not args.annotation is None:
|
100
89
|
# remove cells without enough annotations
|
101
90
|
adata = adata[~pd.isnull(adata.obs[args.annotation]), :]
|
102
91
|
num = adata.obs[args.annotation].value_counts()
|
103
|
-
adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list())
|
92
|
+
adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list())]
|
104
93
|
|
105
94
|
le = preprocessing.LabelEncoder()
|
106
95
|
le.fit(adata.obs[args.annotation])
|
@@ -113,7 +102,7 @@ def run_find_latent_representation(args:FindLatentRepresentationsConfig):
|
|
113
102
|
latent_GVAE = latent_rep.Run_GNN_VAE(label)
|
114
103
|
latent_PCA = latent_rep.Run_PCA()
|
115
104
|
# Add latent representations to the spe data
|
116
|
-
|
105
|
+
logger.info(f'------Adding latent representations...')
|
117
106
|
adata.obsm["latent_GVAE"] = latent_GVAE
|
118
107
|
adata.obsm["latent_PCA"] = latent_PCA
|
119
108
|
# Run umap based on latent representations
|
@@ -124,13 +113,13 @@ def run_find_latent_representation(args:FindLatentRepresentationsConfig):
|
|
124
113
|
|
125
114
|
# Find the latent representations hierarchically (optionally)
|
126
115
|
if not args.annotation is None and args.hierarchically:
|
127
|
-
|
116
|
+
logger.info(f'------Finding latent representations hierarchically...')
|
128
117
|
PCA_all = pd.DataFrame()
|
129
118
|
GVAE_all = pd.DataFrame()
|
130
119
|
|
131
120
|
for ct in adata.obs[args.annotation].unique():
|
132
121
|
adata_part = adata[adata.obs[args.annotation] == ct, :]
|
133
|
-
|
122
|
+
logger.info(adata_part.shape)
|
134
123
|
|
135
124
|
# Find latent representations for the selected ct
|
136
125
|
latent_rep = Latent_Representation_Finder(adata_part, args)
|
@@ -151,59 +140,6 @@ def run_find_latent_representation(args:FindLatentRepresentationsConfig):
|
|
151
140
|
|
152
141
|
adata.obsm["latent_GVAE_hierarchy"] = np.array(GVAE_all.loc[adata.obs_names,])
|
153
142
|
adata.obsm["latent_PCA_hierarchy"] = np.array(PCA_all.loc[adata.obs_names,])
|
154
|
-
|
155
|
-
adata.write(args.
|
156
|
-
|
157
|
-
|
158
|
-
if __name__ == '__main__':
|
159
|
-
parser = argparse.ArgumentParser(description="This script is designed to find latent representations in spatial transcriptomics data using a Graph Neural Network Variational Autoencoder (GNN-VAE). It processes input data, constructs a neighboring graph, and runs GNN-VAE to output latent representations.")
|
160
|
-
add_find_latent_representations_args(parser)
|
161
|
-
TEST=True
|
162
|
-
if TEST:
|
163
|
-
test_dir = '/storage/yangjianLab/chenwenhao/projects/202312_gsMap/data/gsMap_test/Nature_Neuroscience_2021'
|
164
|
-
name = 'Cortex_151507'
|
143
|
+
logger.info(f'------Saving ST data...')
|
144
|
+
adata.write(args.hdf5_with_latent_path)
|
165
145
|
|
166
|
-
|
167
|
-
args = parser.parse_args(
|
168
|
-
[
|
169
|
-
'--input_hdf5_path','/storage/yangjianLab/songliyang/SpatialData/Data/Brain/Human/Nature_Neuroscience_2021/processed/h5ad/Cortex_151507.h5ad',
|
170
|
-
'--output_hdf5_path',f'{test_dir}/{name}/hdf5/{name}_add_latent.h5ad',
|
171
|
-
'--sample_name', name,
|
172
|
-
'--annotation','layer_guess',
|
173
|
-
'--type','count',
|
174
|
-
]
|
175
|
-
|
176
|
-
)
|
177
|
-
|
178
|
-
else:
|
179
|
-
args = parser.parse_args()
|
180
|
-
config=FindLatentRepresentationsConfig(**{'annotation': 'SubClass',
|
181
|
-
'convergence_threshold': 0.0001,
|
182
|
-
'epochs': 300,
|
183
|
-
'feat_cell': 3000,
|
184
|
-
'feat_hidden1': 256,
|
185
|
-
'feat_hidden2': 128,
|
186
|
-
'gcn_decay': 0.01,
|
187
|
-
'gcn_hidden1': 64,
|
188
|
-
'gcn_hidden2': 30,
|
189
|
-
'gcn_lr': 0.001,
|
190
|
-
'hierarchically': False,
|
191
|
-
'input_hdf5_path': '/storage/yangjianLab/songliyang/SpatialData/Data/Brain/macaque/Cell/processed/h5ad/T862_macaque3.h5ad',
|
192
|
-
'label_w': 1.0,
|
193
|
-
'n_comps': 300,
|
194
|
-
'n_neighbors': 11,
|
195
|
-
'nheads': 3,
|
196
|
-
'output_hdf5_path': 'T862_macaque3/find_latent_representations/T862_macaque3_add_latent.h5ad',
|
197
|
-
'p_drop': 0.1,
|
198
|
-
'rec_w': 1.0,
|
199
|
-
'sample_name': 'T862_macaque3',
|
200
|
-
'type': 'SCT',
|
201
|
-
'var': False,
|
202
|
-
'weighted_adj': False})
|
203
|
-
# config=FindLatentRepresentationsConfig(**vars(args))
|
204
|
-
start_time = time.time()
|
205
|
-
logger.info(f'Find latent representations for {config.sample_name}...')
|
206
|
-
pprint.pprint(config)
|
207
|
-
run_find_latent_representation(config)
|
208
|
-
end_time = time.time()
|
209
|
-
logger.info(f'Find latent representations for {config.sample_name} finished. Time spent: {(end_time - start_time) / 60:.2f} min.')
|