gsMap-1.70-py3-none-any.whl → gsMap-1.71.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN/__init__.py +0 -0
- gsMap/GNN/adjacency_matrix.py +75 -75
- gsMap/GNN/model.py +90 -89
- gsMap/GNN/train.py +0 -0
- gsMap/__init__.py +5 -5
- gsMap/__main__.py +2 -2
- gsMap/cauchy_combination_test.py +141 -141
- gsMap/config.py +805 -805
- gsMap/diagnosis.py +273 -273
- gsMap/find_latent_representation.py +133 -133
- gsMap/format_sumstats.py +407 -407
- gsMap/generate_ldscore.py +618 -618
- gsMap/latent_to_gene.py +234 -234
- gsMap/main.py +31 -31
- gsMap/report.py +160 -160
- gsMap/run_all_mode.py +194 -194
- gsMap/setup.py +0 -0
- gsMap/spatial_ldsc_multiple_sumstats.py +380 -380
- gsMap/templates/report_template.html +198 -198
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +735 -735
- gsMap/utils/jackknife.py +514 -514
- gsMap/utils/make_annotations.py +518 -518
- gsMap/utils/manhattan_plot.py +639 -639
- gsMap/utils/regression_read.py +294 -294
- gsMap/visualize.py +198 -198
- {gsmap-1.70.dist-info → gsmap-1.71.1.dist-info}/LICENSE +21 -21
- {gsmap-1.70.dist-info → gsmap-1.71.1.dist-info}/METADATA +2 -2
- gsmap-1.71.1.dist-info/RECORD +31 -0
- gsmap-1.70.dist-info/RECORD +0 -31
- {gsmap-1.70.dist-info → gsmap-1.71.1.dist-info}/WHEEL +0 -0
- {gsmap-1.70.dist-info → gsmap-1.71.1.dist-info}/entry_points.txt +0 -0
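
Since wheels are plain zip archives, a diff like the one below can be reproduced locally with only the standard library. This is a minimal sketch, not part of the release; it assumes both wheels have already been fetched (e.g. `pip download gsMap==1.70 --no-deps` and `pip download gsMap==1.71.1 --no-deps`), and the local filenames are assumptions:

```python
import difflib
import zipfile

OLD_WHEEL = 'gsMap-1.70-py3-none-any.whl'    # assumed local filenames
NEW_WHEEL = 'gsMap-1.71.1-py3-none-any.whl'
NAME = 'gsMap/spatial_ldsc_multiple_sumstats.py'

# Wheels are zip files; package sources sit at the archive root.
with zipfile.ZipFile(OLD_WHEEL) as old, zipfile.ZipFile(NEW_WHEEL) as new:
    old_src = old.read(NAME).decode('utf-8').splitlines(keepends=True)
    new_src = new.read(NAME).decode('utf-8').splitlines(keepends=True)

# unified_diff yields nothing when the two sides are byte-identical.
print(''.join(difflib.unified_diff(old_src, new_src, 'a/' + NAME, 'b/' + NAME)))
```

The same approach extends to any file in the RECORD list above.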
gsMap/spatial_ldsc_multiple_sumstats.py
@@ -1,380 +1,380 @@

The viewer marks all 380 lines of this file as removed and re-added, but the old (-) and new (+) sides render identically, so whatever changed is not visible at this rendering. The shared 380-line file body appears once below.

```python
import gc
import logging
import os
from collections import defaultdict
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import zarr
from scipy.stats import norm
from tqdm.contrib.concurrent import thread_map

import gsMap.utils.jackknife as jk
from gsMap.config import SpatialLDSCConfig
from gsMap.utils.regression_read import _read_sumstats, _read_w_ld, _read_ref_ld_v2

logger = logging.getLogger('gsMap.spatial_ldsc')


# %%
def _coef_new(jknife):
    est_ = jknife.jknife_est[0, 0] / Nbar
    se_ = jknife.jknife_se[0, 0] / Nbar
    return est_, se_


def append_intercept(x):
    n_row = x.shape[0]
    intercept = np.ones((n_row, 1))
    x_new = np.concatenate((x, intercept), axis=1)
    return x_new


def filter_sumstats_by_chisq(sumstats, chisq_max):
    before_len = len(sumstats)
    if chisq_max is None:
        chisq_max = max(0.001 * sumstats.N.max(), 80)
        logger.info(f'No chi^2 threshold provided, using {chisq_max} as default')
    sumstats['chisq'] = sumstats.Z ** 2
    sumstats = sumstats[sumstats.chisq < chisq_max]
    after_len = len(sumstats)
    if after_len < before_len:
        logger.info(f'Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)')
    else:
        logger.info(f'No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)')
    return sumstats


def aggregate(y, x, N, M, intercept=1):
    num = M * (np.mean(y) - intercept)
    denom = np.mean(np.multiply(x, N))
    return num / denom


def weights(ld, w_ld, N, M, hsq, intercept=1):
    M = float(M)
    hsq = np.clip(hsq, 0.0, 1.0)
    ld = np.maximum(ld, 1.0)
    w_ld = np.maximum(w_ld, 1.0)
    c = hsq * N / M
    het_w = 1.0 / (2 * np.square(intercept + np.multiply(c, ld)))
    oc_w = 1.0 / w_ld
    w = np.multiply(het_w, oc_w)
    return w


def jackknife_for_processmap(spot_id):
    # calculate the initial weight for each spot
    spot_spatial_annotation = spatial_annotation[:, spot_id]
    spot_x_tot_precomputed = spot_spatial_annotation + ref_ld_baseline_column_sum
    initial_w = (
        get_weight_optimized(sumstats, x_tot_precomputed=spot_x_tot_precomputed,
                             M_tot=10000, w_ld=w_ld_common_snp, intercept=1)
        .astype(np.float32)
        .reshape((-1, 1)))

    # apply the weight to baseline annotation, spatial annotation and CHISQ
    initial_w_scaled = initial_w / np.sum(initial_w)
    baseline_annotation_spot = baseline_annotation * initial_w_scaled
    spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
    CHISQ = sumstats.chisq.values.reshape((-1, 1))
    y = CHISQ * initial_w_scaled

    # run the jackknife
    x_focal = np.concatenate((spatial_annotation_spot,
                              baseline_annotation_spot), axis=1)
    try:
        jknife = jk.LstsqJackknifeFast(x_focal, y, n_blocks)
    # LinAlgError
    except np.linalg.LinAlgError as e:
        logger.warning(f'LinAlgError: {e}')
        return np.nan, np.nan
    return _coef_new(jknife)


# Updated function
def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
    tot_agg = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
    initial_w = weights(x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept)
    initial_w = np.sqrt(initial_w)
    return initial_w


def _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
    # Load the gwas summary statistics
    sumstats = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
    sumstats.set_index('SNP', inplace=True)
    sumstats = sumstats.astype(np.float32)
    sumstats = filter_sumstats_by_chisq(sumstats, chisq_max)

    # NB: The intersection order is essential for keeping the same order of SNPs by its BP location
    common_snp = baseline_and_w_ld_common_snp.intersection(sumstats.index)
    if len(common_snp) < 200000:
        logger.warning(f'WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad.')

    sumstats = sumstats.loc[common_snp]

    # get the common index position of baseline_and_w_ld_common_snp for quick access
    sumstats['common_index_pos'] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(sumstats.index)
    return sumstats


def _get_sumstats_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index,
                                     chisq_max=None):
    # first validate if all sumstats file exists
    logger.info('Validating sumstats files...')
    for trait_name, sumstat_file_path in sumstats_config_dict.items():
        if not os.path.exists(sumstat_file_path):
            raise FileNotFoundError(f'{sumstat_file_path} not found')
    # then load all sumstats
    sumstats_cleaned_dict = {}
    for trait_name, sumstat_file_path in sumstats_config_dict.items():
        sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path,
                                                                 baseline_and_w_ld_common_snp, chisq_max)
    logger.info('cleaned sumstats loaded')
    return sumstats_cleaned_dict


class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
    def __init__(self, config: SpatialLDSCConfig, common_snp_among_all_sumstats_pos):
        self.config = config
        mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM')
        mk_score_genes = mk_score.index

        snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
        common_genes = mk_score_genes.intersection(snp_gene_weight_adata.var.index)
        common_snps = snp_gene_weight_adata.obs.index
        # self.snp_gene_weight_adata = snp_gene_weight_adata[common_snp_among_all_sumstats:, common_genes.to_list()]
        self.snp_gene_weight_matrix = snp_gene_weight_adata[common_snp_among_all_sumstats_pos, common_genes.to_list()].X
        self.mk_score_common = mk_score.loc[common_genes]

        # calculate the chunk number
        self.chunk_starts = list(range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode))

    def fetch_ldscore_by_chunk(self, chunk_index):
        chunk_start = self.chunk_starts[chunk_index]
        mk_score_chunk = self.mk_score_common.iloc[:,
                         chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
        ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
            mk_score_chunk,
            drop_dummy_na=False,
        )

        spots_name = self.mk_score_common.columns[chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
        return ldscore_chunk, spots_name

    def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(self,
                                                              mk_score_chunk,
                                                              drop_dummy_na=True,
                                                              ):

        if drop_dummy_na:
            ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
        else:
            ldscore_chr_chunk = self.snp_gene_weight_matrix @ mk_score_chunk

        return ldscore_chr_chunk


def _get_sumstats_with_common_snp_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index,
                                                     chisq_max=None):
    # first validate if all sumstats file exists
    logger.info('Validating sumstats files...')
    for trait_name, sumstat_file_path in sumstats_config_dict.items():
        if not os.path.exists(sumstat_file_path):
            raise FileNotFoundError(f'{sumstat_file_path} not found')
    # then load all sumstats
    sumstats_cleaned_dict = {}
    for trait_name, sumstat_file_path in sumstats_config_dict.items():
        sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path,
                                                                 baseline_and_w_ld_common_snp, chisq_max)
    # get the common snps among all sumstats
    common_snp_among_all_sumstats = None
    for trait_name, sumstats in sumstats_cleaned_dict.items():
        if common_snp_among_all_sumstats is None:
            common_snp_among_all_sumstats = sumstats.index
        else:
            common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(sumstats.index)

    # filter the common snps among all sumstats
    for trait_name, sumstats in sumstats_cleaned_dict.items():
        sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]

    logger.info(f'Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}')
    return sumstats_cleaned_dict, common_snp_among_all_sumstats


def run_spatial_ldsc(config: SpatialLDSCConfig):
    global spatial_annotation, baseline_annotation, n_blocks, Nbar, sumstats, ref_ld_baseline_column_sum, w_ld_common_snp
    # config
    n_blocks = config.n_blocks
    sample_name = config.sample_name

    print(f'------Running Spatial LDSC for {sample_name}...')
    # Load the regression weights
    w_ld = _read_w_ld(config.w_file)
    w_ld_cname = w_ld.columns[1]
    w_ld.set_index('SNP', inplace=True)

    ld_file_baseline = f'{config.ldscore_save_dir}/baseline/baseline.'

    ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
    # n_annot_baseline = len(ref_ld_baseline.columns)
    # M_annot_baseline = _read_M_v2(ld_file_baseline, n_annot_baseline, config.not_M_5_50)

    # common snp between baseline and w_ld
    baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)
    baseline_and_w_ld_common_snp_pos = pd.Index(ref_ld_baseline.index).get_indexer(baseline_and_w_ld_common_snp)

    # Clean the sumstats
    sumstats_cleaned_dict, common_snp_among_all_sumstats = _get_sumstats_with_common_snp_from_sumstats_dict(
        config.sumstats_config_dict, baseline_and_w_ld_common_snp,
        chisq_max=config.chisq_max)
    common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(common_snp_among_all_sumstats)

    # insure the order is monotonic
    assert pd.Series(
        common_snp_among_all_sumstats_pos).is_monotonic_increasing, 'common_snp_among_all_sumstats_pos is not monotonic increasing'

    if len(common_snp_among_all_sumstats) < 200000:
        logger.warning(
            f'!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files.')

    ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
    w_ld = w_ld.loc[common_snp_among_all_sumstats]

    # load additional baseline annotations
    if config.use_additional_baseline_annotation:
        print('Using additional baseline annotations')
        ld_file_baseline_additional = f'{config.ldscore_save_dir}/additional_baseline/baseline.'
        ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
        n_annot_baseline_additional = len(ref_ld_baseline_additional.columns)
        logger.info(f'{len(ref_ld_baseline_additional.columns)} additional baseline annotations loaded')
        # M_annot_baseline_additional = _read_M_v2(ld_file_baseline_additional, n_annot_baseline_additional,
        #                                          config.not_M_5_50)
        ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
        ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
        del ref_ld_baseline_additional

    # Detect available chunk files
    if config.ldscore_save_format == 'quick_mode':
        s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, common_snp_among_all_sumstats_pos)
        total_chunk_number_found = len(s_ldsc.chunk_starts)
        print(f'Split data into {total_chunk_number_found} chunks')
    else:
        all_file = os.listdir(config.ldscore_save_dir)
        total_chunk_number_found = sum('chunk' in name for name in all_file)
        print(f'Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}')

    if config.all_chunk is None:
        if config.chunk_range is not None:
            assert config.chunk_range[0] >= 1 and config.chunk_range[
                1] <= total_chunk_number_found, 'Chunk range out of bound. It should be in [1, all_chunk]'
            print(
                f'chunk range provided, using chunked files from {config.chunk_range[0]} to {config.chunk_range[1]}')
            start_chunk, end_chunk = config.chunk_range
        else:
            start_chunk, end_chunk = 1, total_chunk_number_found
    else:
        all_chunk = config.all_chunk
        print(f'using {all_chunk} chunked files by provided argument')
        print(f'\t')
        print(f'Input {all_chunk} chunked files')
        start_chunk, end_chunk = 1, all_chunk

    running_chunk_number = end_chunk - start_chunk + 1

    # Process each chunk
    output_dict = defaultdict(list)
    zarr_path = Path(config.ldscore_save_dir) / f'{config.sample_name}.ldscore.zarr'
    if config.ldscore_save_format == 'zarr':
        assert zarr_path.exists(), f'{zarr_path} not found, which is required for zarr format'
        zarr_file = zarr.open(str(zarr_path))
        spots_name = zarr_file.attrs['spot_names']

    for chunk_index in range(start_chunk, end_chunk + 1):
        if config.ldscore_save_format == 'feather':
            ref_ld_spatial, spatial_annotation_cnames = load_ldscore_chunk_from_feather(chunk_index,
                                                                                        common_snp_among_all_sumstats_pos,
                                                                                        config,
                                                                                        )
        elif config.ldscore_save_format == 'zarr':
            ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
            start_spot = (chunk_index - 1) * zarr_file.chunks[1]
            ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
            spatial_annotation_cnames = spots_name[start_spot:start_spot + zarr_file.chunks[1]]
        elif config.ldscore_save_format == 'quick_mode':
            ref_ld_spatial, spatial_annotation_cnames = s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)
        else:
            raise ValueError(f'Invalid ld score save format: {config.ldscore_save_format}')

        # get the x_tot_precomputed matrix by adding baseline and spatial annotation
        ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
        # x_tot_precomputed = ref_ld_spatial + ref_ld_baseline_column_sum

        for trait_name, sumstats in sumstats_cleaned_dict.items():

            spatial_annotation = ref_ld_spatial.astype(np.float32, copy=False)
            baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)
            w_ld_common_snp = w_ld.astype(np.float32, copy=False)

            # weight the baseline annotation by N
            baseline_annotation = baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
            # append intercept
            baseline_annotation = append_intercept(baseline_annotation)

            # Run the jackknife
            Nbar = sumstats.N.mean()
            chunk_size = spatial_annotation.shape[1]
            out_chunk = thread_map(jackknife_for_processmap, range(chunk_size),
                                   max_workers=config.num_processes,
                                   chunksize=10,
                                   desc=f'Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}',
                                   )

            # cache the results
            out_chunk = pd.DataFrame.from_records(out_chunk,
                                                  columns=['beta', 'se', ],
                                                  index=spatial_annotation_cnames)
            # get the spots with nan
            nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
            if len(nan_spots) > 0:
                logger.info(f'Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed.')
            # drop the nan
            out_chunk = out_chunk.dropna()

            out_chunk['z'] = out_chunk.beta / out_chunk.se
            out_chunk['p'] = norm.sf(out_chunk['z'])
            output_dict[trait_name].append(out_chunk)

        del ref_ld_spatial, spatial_annotation, baseline_annotation, w_ld_common_snp
        gc.collect()

    # Save the results
    out_dir = config.ldsc_save_dir
    for trait_name, out_chunk_list in output_dict.items():
        out_all = pd.concat(out_chunk_list, axis=0)
        if running_chunk_number == total_chunk_number_found:
            out_file_name = out_dir / f'{sample_name}_{trait_name}.csv.gz'
        else:
            out_file_name = out_dir / f'{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz'
        out_all['spot'] = out_all.index
        out_all = out_all[['spot', 'beta', 'se', 'z', 'p']]
        out_all.to_csv(out_file_name, compression='gzip', index=False)

        logger.info(f'Output saved to {out_file_name} for {trait_name}')
    logger.info(f'------Spatial LDSC for {sample_name} finished!')


def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config, ):
    # Load the spatial annotations for this chunk
    sample_name = config.sample_name
    ld_file_spatial = f'{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.'
    ref_ld_spatial = _read_ref_ld_v2(ld_file_spatial)
    ref_ld_spatial = ref_ld_spatial.iloc[common_snp_among_all_sumstats_pos]
    ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)

    spatial_annotation_cnames = ref_ld_spatial.columns
    return ref_ld_spatial.values, spatial_annotation_cnames
```