gsMap 1.71.1__py3-none-any.whl → 1.71.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN/__init__.py +0 -0
- gsMap/GNN/adjacency_matrix.py +75 -75
- gsMap/GNN/model.py +90 -90
- gsMap/GNN/train.py +0 -0
- gsMap/__init__.py +5 -5
- gsMap/__main__.py +2 -2
- gsMap/cauchy_combination_test.py +141 -141
- gsMap/config.py +806 -805
- gsMap/diagnosis.py +274 -273
- gsMap/find_latent_representation.py +139 -133
- gsMap/format_sumstats.py +407 -407
- gsMap/generate_ldscore.py +618 -618
- gsMap/latent_to_gene.py +252 -234
- gsMap/main.py +31 -31
- gsMap/report.py +160 -160
- gsMap/run_all_mode.py +194 -194
- gsMap/setup.py +0 -0
- gsMap/spatial_ldsc_multiple_sumstats.py +360 -380
- gsMap/templates/report_template.html +198 -198
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +735 -735
- gsMap/utils/jackknife.py +514 -514
- gsMap/utils/make_annotations.py +518 -518
- gsMap/utils/manhattan_plot.py +639 -639
- gsMap/utils/regression_read.py +294 -294
- gsMap/visualize.py +198 -198
- {gsmap-1.71.1.dist-info → gsmap-1.71.2.dist-info}/LICENSE +21 -21
- {gsmap-1.71.1.dist-info → gsmap-1.71.2.dist-info}/METADATA +3 -3
- gsmap-1.71.2.dist-info/RECORD +31 -0
- {gsmap-1.71.1.dist-info → gsmap-1.71.2.dist-info}/WHEEL +1 -1
- gsmap-1.71.1.dist-info/RECORD +0 -31
- {gsmap-1.71.1.dist-info → gsmap-1.71.2.dist-info}/entry_points.txt +0 -0
@@ -1,380 +1,360 @@
|
|
1
|
-
import gc
|
2
|
-
import logging
|
3
|
-
import os
|
4
|
-
from collections import defaultdict
|
5
|
-
from
|
6
|
-
|
7
|
-
|
8
|
-
import
|
9
|
-
import
|
10
|
-
import
|
11
|
-
|
12
|
-
from
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
from gsMap.
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
def _coef_new(jknife):
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
intercept
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
sumstats =
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
sumstats
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
#
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
#
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
f'chunk
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
out_all =
|
359
|
-
|
360
|
-
|
361
|
-
else:
|
362
|
-
out_file_name = out_dir / f'{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz'
|
363
|
-
out_all['spot'] = out_all.index
|
364
|
-
out_all = out_all[['spot', 'beta', 'se', 'z', 'p']]
|
365
|
-
out_all.to_csv(out_file_name, compression='gzip', index=False)
|
366
|
-
|
367
|
-
logger.info(f'Output saved to {out_file_name} for {trait_name}')
|
368
|
-
logger.info(f'------Spatial LDSC for {sample_name} finished!')
|
369
|
-
|
370
|
-
|
371
|
-
def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config, ):
    """Read one chunk of spatial LD scores and keep only the requested SNP rows.

    The chunk is located at
    ``{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.``
    and read with ``_read_ref_ld_v2``.

    Returns:
        A ``(values, column_names)`` pair: the float32 values of the selected
        rows and the columns of the loaded table.
    """
    # Load the spatial annotations for this chunk
    name = config.sample_name
    chunk_prefix = f'{config.ldscore_save_dir}/{name}_chunk{chunk_index}/{name}.'
    chunk_table = _read_ref_ld_v2(chunk_prefix)

    # Positional row selection, then a cast that avoids copying when possible.
    selected = chunk_table.iloc[common_snp_among_all_sumstats_pos].astype(np.float32, copy=False)
    return selected.values, selected.columns
|
1
|
+
import gc
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
from collections import defaultdict
|
5
|
+
from functools import partial
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
import anndata as ad
|
9
|
+
import numpy as np
|
10
|
+
import pandas as pd
|
11
|
+
import zarr
|
12
|
+
from scipy.stats import norm
|
13
|
+
from tqdm.contrib.concurrent import thread_map
|
14
|
+
|
15
|
+
import gsMap.utils.jackknife as jk
|
16
|
+
from gsMap.config import SpatialLDSCConfig
|
17
|
+
from gsMap.utils.regression_read import _read_sumstats, _read_w_ld, _read_ref_ld_v2
|
18
|
+
|
19
|
+
logger = logging.getLogger('gsMap.spatial_ldsc')
|
20
|
+
|
21
|
+
|
22
|
+
def _coef_new(jknife, Nbar):
|
23
|
+
"""Calculate coefficients adjusted by Nbar."""
|
24
|
+
est_ = jknife.jknife_est[0, 0] / Nbar
|
25
|
+
se_ = jknife.jknife_se[0, 0] / Nbar
|
26
|
+
return est_, se_
|
27
|
+
|
28
|
+
|
29
|
+
def append_intercept(x):
    """Return ``x`` with one extra trailing column filled with ones.

    ``x`` must be two-dimensional; the result has the same number of rows and
    one more column.
    """
    ones_column = np.ones((x.shape[0], 1))
    return np.concatenate((x, ones_column), axis=1)
|
35
|
+
|
36
|
+
|
37
|
+
def filter_sumstats_by_chisq(sumstats, chisq_max):
    """Drop SNPs whose chi-squared statistic (``Z ** 2``) is not below a threshold.

    Args:
        sumstats: DataFrame with at least the columns ``Z`` and ``N``. It is
            NOT modified; the ``chisq`` column is added to a copy.
        chisq_max: Upper bound for ``chisq``. When ``None`` the bound defaults
            to ``max(0.001 * sumstats.N.max(), 80)``.

    Returns:
        A new, independent DataFrame containing a ``chisq`` column and only
        the rows with ``chisq < chisq_max``.
    """
    before_len = len(sumstats)
    if chisq_max is None:
        chisq_max = max(0.001 * sumstats.N.max(), 80)
        logger.info(f'No chi^2 threshold provided, using {chisq_max} as default')

    # ``assign`` builds a new frame, so the caller's DataFrame is left untouched.
    with_chisq = sumstats.assign(chisq=sumstats.Z ** 2)
    # Explicit copy: callers add columns to the result, which must not be a
    # view/slice of an intermediate frame (avoids SettingWithCopyWarning).
    kept = with_chisq.loc[with_chisq.chisq < chisq_max].copy()

    after_len = len(kept)
    if after_len < before_len:
        logger.info(f'Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)')
    else:
        logger.info(f'No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)')
    return kept
|
51
|
+
|
52
|
+
|
53
|
+
def aggregate(y, x, N, M, intercept=1):
    """Compute ``M * (mean(y) - intercept) / mean(x * N)``.

    Used by the weight calculation to obtain the value that is passed on as
    ``hsq`` to ``weights``.
    """
    mean_excess = np.mean(y) - intercept
    mean_scaled_x = np.mean(np.multiply(x, N))
    return M * mean_excess / mean_scaled_x
|
58
|
+
|
59
|
+
|
60
|
+
def weights(ld, w_ld, N, M, hsq, intercept=1):
    """Compute per-SNP regression weights.

    ``hsq`` is clipped to ``[0, 1]`` and both ``ld`` and ``w_ld`` are floored
    at ``1.0`` before use. The result is the element-wise product of
    ``1 / (2 * (intercept + hsq * N / M * ld) ** 2)`` and ``1 / w_ld``.
    """
    ld_floored = np.maximum(ld, 1.0)
    w_ld_floored = np.maximum(w_ld, 1.0)
    per_snp_scale = np.clip(hsq, 0.0, 1.0) * N / float(M)

    expected = intercept + np.multiply(per_snp_scale, ld_floored)
    het_component = 1.0 / (2 * np.square(expected))
    oc_component = 1.0 / w_ld_floored
    return np.multiply(het_component, oc_component)
|
71
|
+
|
72
|
+
|
73
|
+
def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
    """Return the square root of the initial regression weights.

    ``aggregate`` is evaluated on ``sumstats.chisq`` / ``sumstats.N`` and its
    result is handed to ``weights`` as ``hsq`` together with
    ``w_ld.LD_weights``.
    """
    hsq_estimate = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
    squared_weights = weights(
        x_tot_precomputed,
        w_ld.LD_weights.values,
        sumstats.N.values,
        M_tot,
        hsq_estimate,
        intercept,
    )
    return np.sqrt(squared_weights)
|
79
|
+
|
80
|
+
|
81
|
+
def jackknife_for_processmap(spot_id, spatial_annotation, ref_ld_baseline_column_sum, sumstats, baseline_annotation, w_ld_common_snp, Nbar, n_blocks):
    """Fit the weighted jackknife regression for a single spot.

    Column ``spot_id`` of ``spatial_annotation`` is combined with the baseline
    annotations into one design matrix; design and response (``sumstats.chisq``)
    are both scaled by the normalised initial weights before fitting.

    Returns:
        ``(estimate, standard_error)`` from ``_coef_new``, or
        ``(nan, nan)`` when the fit raises ``np.linalg.LinAlgError``.
    """
    spot_column = spatial_annotation[:, spot_id]
    total_ld = spot_column + ref_ld_baseline_column_sum

    raw_w = get_weight_optimized(sumstats, x_tot_precomputed=total_ld,
                                 M_tot=10000, w_ld=w_ld_common_snp, intercept=1)
    raw_w = raw_w.astype(np.float32).reshape((-1, 1))
    # Normalise so that the weights sum to one.
    normalised_w = raw_w / np.sum(raw_w)

    weighted_spot = spot_column.reshape((-1, 1)) * normalised_w
    weighted_baseline = baseline_annotation * normalised_w
    weighted_chisq = sumstats.chisq.values.reshape((-1, 1)) * normalised_w

    # The spot column goes first, so its coefficient is element [0, 0].
    design = np.concatenate((weighted_spot, weighted_baseline), axis=1)
    try:
        jknife = jk.LstsqJackknifeFast(design, weighted_chisq, n_blocks)
    except np.linalg.LinAlgError as e:
        logger.warning(f'LinAlgError: {e}')
        return np.nan, np.nan
    return _coef_new(jknife, Nbar)
|
99
|
+
|
100
|
+
|
101
|
+
def _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
    """Load one trait's summary statistics and align them to the reference SNPs.

    The file is read with ``_read_sumstats``, indexed by ``SNP``, cast to
    float32, filtered by chi-squared, and restricted to SNPs that are also in
    ``baseline_and_w_ld_common_snp``. A ``common_index_pos`` column records
    each SNP's position inside ``baseline_and_w_ld_common_snp``.

    A warning is logged when fewer than 200000 SNPs remain.
    """
    raw = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
    raw.set_index('SNP', inplace=True)
    filtered = filter_sumstats_by_chisq(raw.astype(np.float32), chisq_max)

    shared_snps = baseline_and_w_ld_common_snp.intersection(filtered.index)
    if len(shared_snps) < 200000:
        logger.warning(f'WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad.')

    aligned = filtered.loc[shared_snps]
    reference_index = pd.Index(baseline_and_w_ld_common_snp)
    aligned['common_index_pos'] = reference_index.get_indexer(aligned.index)
    return aligned
|
113
|
+
|
114
|
+
|
115
|
+
def _get_sumstats_with_common_snp_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
|
116
|
+
"""Get summary statistics with common SNPs among all traits."""
|
117
|
+
logger.info('Validating sumstats files...')
|
118
|
+
for trait_name, sumstat_file_path in sumstats_config_dict.items():
|
119
|
+
if not os.path.exists(sumstat_file_path):
|
120
|
+
raise FileNotFoundError(f'{sumstat_file_path} not found')
|
121
|
+
sumstats_cleaned_dict = {}
|
122
|
+
for trait_name, sumstat_file_path in sumstats_config_dict.items():
|
123
|
+
sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp, chisq_max)
|
124
|
+
common_snp_among_all_sumstats = None
|
125
|
+
for trait_name, sumstats in sumstats_cleaned_dict.items():
|
126
|
+
if common_snp_among_all_sumstats is None:
|
127
|
+
common_snp_among_all_sumstats = sumstats.index
|
128
|
+
else:
|
129
|
+
common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(sumstats.index)
|
130
|
+
for trait_name, sumstats in sumstats_cleaned_dict.items():
|
131
|
+
sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]
|
132
|
+
logger.info(f'Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}')
|
133
|
+
return sumstats_cleaned_dict, common_snp_among_all_sumstats
|
134
|
+
|
135
|
+
|
136
|
+
class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
    """Compute spatial LD scores on the fly from a pre-calculated SNP-gene weight matrix (quick mode).

    Attributes set by ``__init__``:
        config: The configuration object passed in.
        snp_gene_weight_matrix: Weight matrix restricted to the requested SNP
            rows and to genes shared with the marker-score table.
        mk_score_common: Marker scores restricted to the shared genes
            (rows are genes; columns are iterated in chunks as spots).
        chunk_starts: Start column offsets of each chunk, spaced by
            ``config.spots_per_chunk_quick_mode``.
    """

    def __init__(self, config: 'SpatialLDSCConfig', common_snp_among_all_sumstats_pos):
        """Load marker scores and the SNP-gene weight matrix and align them by gene.

        Args:
            config: Provides ``mkscore_feather_path``,
                ``snp_gene_weight_adata_path`` and
                ``spots_per_chunk_quick_mode``.
            common_snp_among_all_sumstats_pos: Row selector applied to the
                SNP axis of the weight AnnData.
        """
        self.config = config
        mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM')
        snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)

        # Only genes present in both tables can contribute to the product.
        common_genes = mk_score.index.intersection(snp_gene_weight_adata.var.index)
        self.snp_gene_weight_matrix = snp_gene_weight_adata[common_snp_among_all_sumstats_pos, common_genes.to_list()].X
        self.mk_score_common = mk_score.loc[common_genes]
        self.chunk_starts = list(range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode))

    def fetch_ldscore_by_chunk(self, chunk_index):
        """Return ``(ldscore_chunk, spots_name)`` for the zero-based ``chunk_index``."""
        chunk_start = self.chunk_starts[chunk_index]
        chunk_stop = chunk_start + self.config.spots_per_chunk_quick_mode
        mk_score_chunk = self.mk_score_common.iloc[:, chunk_start:chunk_stop]
        ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(mk_score_chunk, drop_dummy_na=False)
        # Same column slice as the chunk itself, so names and scores stay aligned.
        spots_name = mk_score_chunk.columns
        return ldscore_chunk, spots_name

    def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(self, mk_score_chunk, drop_dummy_na=True):
        """Multiply the SNP-gene weight matrix by a chunk of marker scores.

        When ``drop_dummy_na`` is true the last column of the weight matrix is
        excluded from the product.
        """
        if drop_dummy_na:
            return self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
        return self.snp_gene_weight_matrix @ mk_score_chunk
|
164
|
+
|
165
|
+
|
166
|
+
def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config):
    """Read one feather-format LD score chunk and keep the requested SNP rows.

    The chunk prefix is
    ``{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.``.

    Returns:
        ``(values, column_names)``: float32 values of the selected rows and
        the column labels of the loaded table.
    """
    name = config.sample_name
    chunk_prefix = f'{config.ldscore_save_dir}/{name}_chunk{chunk_index}/{name}.'
    selected_rows = _read_ref_ld_v2(chunk_prefix).iloc[common_snp_among_all_sumstats_pos]
    selected_rows = selected_rows.astype(np.float32, copy=False)
    return selected_rows.values, selected_rows.columns
|
175
|
+
|
176
|
+
|
177
|
+
def run_spatial_ldsc(config: SpatialLDSCConfig):
    """Run spatial LDSC for every trait in ``config.sumstats_config_dict``.

    Steps, in order:
      1. Load regression weights and baseline LD scores and intersect their SNPs.
      2. Load and clean all sumstats, keeping only SNPs shared by every trait.
      3. Optionally append additional baseline annotations.
      4. For each LD score chunk and each trait, run the per-spot jackknife
         regression in a thread pool and collect ``beta``, ``se``, ``z``, ``p``.
      5. Write the per-trait results with ``save_results``.

    Raises:
        ValueError: If the positions of the shared SNPs inside the baseline
            index are not monotonically increasing.
        FileNotFoundError: If the zarr store is required but missing.
    """
    logger.info(f'------Running Spatial LDSC for {config.sample_name}...')
    n_blocks = config.n_blocks
    sample_name = config.sample_name

    # Load regression weights
    w_ld = _read_w_ld(config.w_file)
    w_ld.set_index('SNP', inplace=True)

    ld_file_baseline = f'{config.ldscore_save_dir}/baseline/baseline.'
    ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
    baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)

    sumstats_cleaned_dict, common_snp_among_all_sumstats = _get_sumstats_with_common_snp_from_sumstats_dict(
        config.sumstats_config_dict, baseline_and_w_ld_common_snp, chisq_max=config.chisq_max)
    common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(common_snp_among_all_sumstats)

    # The positions are later used for positional row selection of the spatial
    # LD scores, so they must follow the baseline row order.
    if not pd.Series(common_snp_among_all_sumstats_pos).is_monotonic_increasing:
        raise ValueError('common_snp_among_all_sumstats_pos is not monotonic increasing')

    if len(common_snp_among_all_sumstats) < 200000:
        logger.warning(
            f'!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files.')

    ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
    w_ld = w_ld.loc[common_snp_among_all_sumstats]

    # Load additional baseline annotations if needed
    if config.use_additional_baseline_annotation:
        logger.info('Using additional baseline annotations')
        ld_file_baseline_additional = f'{config.ldscore_save_dir}/additional_baseline/baseline.'
        ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
        ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
        ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
        del ref_ld_baseline_additional

    # Initialize s_ldsc once if quick_mode
    s_ldsc = None
    if config.ldscore_save_format == 'quick_mode':
        s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, common_snp_among_all_sumstats_pos)
        total_chunk_number_found = len(s_ldsc.chunk_starts)
        logger.info(f'Split data into {total_chunk_number_found} chunks')
    else:
        total_chunk_number_found = determine_total_chunks(config)

    start_chunk, end_chunk = determine_chunk_range(config, total_chunk_number_found)
    running_chunk_number = end_chunk - start_chunk + 1

    # Load zarr file if needed
    zarr_file, spots_name = None, None
    if config.ldscore_save_format == 'zarr':
        zarr_path = Path(config.ldscore_save_dir) / f'{config.sample_name}.ldscore.zarr'
        if not zarr_path.exists():
            raise FileNotFoundError(f'{zarr_path} not found, which is required for zarr format')
        zarr_file = zarr.open(str(zarr_path))
        spots_name = zarr_file.attrs['spot_names']

    # Neither value depends on the chunk or the trait, so compute them once
    # instead of once per (chunk, trait) iteration.
    ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
    w_ld_common_snp = w_ld.astype(np.float32, copy=False)

    output_dict = defaultdict(list)
    for chunk_index in range(start_chunk, end_chunk + 1):
        ref_ld_spatial, spatial_annotation_cnames = load_ldscore_chunk(
            chunk_index,
            common_snp_among_all_sumstats_pos,
            config,
            zarr_file,
            spots_name,
            s_ldsc  # Pass s_ldsc to the function
        )

        for trait_name, sumstats in sumstats_cleaned_dict.items():
            spatial_annotation = ref_ld_spatial.astype(np.float32, copy=False)
            baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)

            # Scale the baseline by each SNP's N relative to the trait mean,
            # then add the intercept column.
            baseline_annotation = baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
            baseline_annotation = append_intercept(baseline_annotation)

            Nbar = sumstats.N.mean()
            chunk_size = spatial_annotation.shape[1]

            jackknife_func = partial(
                jackknife_for_processmap,
                spatial_annotation=spatial_annotation,
                ref_ld_baseline_column_sum=ref_ld_baseline_column_sum,
                sumstats=sumstats,
                baseline_annotation=baseline_annotation,
                w_ld_common_snp=w_ld_common_snp,
                Nbar=Nbar,
                n_blocks=n_blocks
            )

            out_chunk = thread_map(
                jackknife_func,
                range(chunk_size),
                max_workers=config.num_processes,
                chunksize=10,
                desc=f'Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}',
            )

            out_chunk = pd.DataFrame.from_records(out_chunk, columns=['beta', 'se'], index=spatial_annotation_cnames)
            # Spots whose regression failed come back as (nan, nan) and are dropped.
            nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
            if len(nan_spots) > 0:
                logger.info(f'Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed.')
            out_chunk = out_chunk.dropna()
            out_chunk['z'] = out_chunk.beta / out_chunk.se
            # One-sided p-value from the upper tail of the normal distribution.
            out_chunk['p'] = norm.sf(out_chunk['z'])
            output_dict[trait_name].append(out_chunk)

            del spatial_annotation, baseline_annotation
            gc.collect()

    save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk)
    logger.info(f'------Spatial LDSC for {sample_name} finished!')
|
293
|
+
|
294
|
+
|
295
|
+
def determine_total_chunks(config):
    """Return how many LD score chunks are available for ``config``.

    In ``quick_mode`` the count is the number of chunk starts of a freshly
    built ``S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix``; otherwise
    it is the number of entries in ``config.ldscore_save_dir`` whose name
    contains ``'chunk'``.
    """
    if config.ldscore_save_format != 'quick_mode':
        entries = os.listdir(config.ldscore_save_dir)
        total_chunk_number_found = len([entry for entry in entries if 'chunk' in entry])
        logger.info(f'Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}')
        return total_chunk_number_found

    quick_mode_loader = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, [])
    total_chunk_number_found = len(quick_mode_loader.chunk_starts)
    logger.info(f'Split data into {total_chunk_number_found} chunks')
    return total_chunk_number_found
|
306
|
+
|
307
|
+
|
308
|
+
def determine_chunk_range(config, total_chunk_number_found):
    """Return the inclusive, one-based ``(start_chunk, end_chunk)`` to process.

    Precedence:
      1. ``config.all_chunk`` set: ``(1, config.all_chunk)``.
      2. ``config.chunk_range`` set: that range, after validation.
      3. Otherwise: ``(1, total_chunk_number_found)``.

    Raises:
        ValueError: If ``config.chunk_range`` lies outside
            ``[1, total_chunk_number_found]`` or its start is greater than its
            end (which would otherwise silently process nothing).
    """
    if config.all_chunk is not None:
        all_chunk = config.all_chunk
        logger.info(f'Using {all_chunk} chunked files by provided argument')
        return 1, all_chunk

    if config.chunk_range is None:
        return 1, total_chunk_number_found

    start_chunk, end_chunk = config.chunk_range
    if not (1 <= start_chunk <= total_chunk_number_found) or not (1 <= end_chunk <= total_chunk_number_found):
        raise ValueError('Chunk range out of bound. It should be in [1, all_chunk]')
    if start_chunk > end_chunk:
        raise ValueError(f'Invalid chunk range: start ({start_chunk}) is greater than end ({end_chunk})')
    logger.info(f'Chunk range provided, using chunked files from {start_chunk} to {end_chunk}')
    return start_chunk, end_chunk
|
323
|
+
|
324
|
+
|
325
|
+
def load_ldscore_chunk(chunk_index, common_snp_among_all_sumstats_pos, config, zarr_file=None, spots_name=None, s_ldsc=None):
    """Load the spatial LD scores of one chunk according to ``config.ldscore_save_format``.

    Args:
        chunk_index: One-based chunk number.
        common_snp_among_all_sumstats_pos: Row positions of the SNPs to keep.
        config: Provides ``ldscore_save_format`` (``'feather'``, ``'zarr'`` or
            ``'quick_mode'``).
        zarr_file: Opened zarr array; required for ``'zarr'``.
        spots_name: Sequence of all spot names; required for ``'zarr'``.
        s_ldsc: Pre-initialised quick-mode loader; required for ``'quick_mode'``.

    Returns:
        ``(ref_ld_spatial, spatial_annotation_cnames)`` for the chunk.

    Raises:
        ValueError: If the format is unknown or a format-specific argument is
            missing.
    """
    save_format = config.ldscore_save_format

    if save_format == 'feather':
        return load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config)

    if save_format == 'zarr':
        # Mirror the quick_mode check below instead of failing on ``None.blocks``.
        if zarr_file is None or spots_name is None:
            raise ValueError("zarr_file and spots_name must be provided in zarr format")
        spots_per_block = zarr_file.chunks[1]
        # Chunks are one-based for callers but blocks are zero-based.
        ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
        ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
        start_spot = (chunk_index - 1) * spots_per_block
        spatial_annotation_cnames = spots_name[start_spot:start_spot + spots_per_block]
        return ref_ld_spatial, spatial_annotation_cnames

    if save_format == 'quick_mode':
        # Use the pre-initialized s_ldsc
        if s_ldsc is None:
            raise ValueError("s_ldsc must be provided in quick_mode")
        return s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)

    raise ValueError(f'Invalid ld score save format: {config.ldscore_save_format}')
|
342
|
+
|
343
|
+
|
344
|
+
def save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk):
    """Write one gzip-compressed CSV per trait into ``config.ldsc_save_dir``.

    Args:
        output_dict: Mapping of trait name to a list of per-chunk DataFrames
            with columns ``beta``, ``se``, ``z``, ``p`` indexed by spot.
        config: Provides ``ldsc_save_dir``, ``sample_name`` and, optionally,
            ``all_chunk`` / ``chunk_range``.
        running_chunk_number: Number of chunks that were processed.
        start_chunk: First processed chunk (one-based, inclusive).
        end_chunk: Last processed chunk (one-based, inclusive).

    File naming:
        A run restricted by ``config.chunk_range`` (and not overridden by
        ``config.all_chunk``) is written to
        ``{sample}_{trait}_chunk{start}-{end}.csv.gz`` so that it cannot be
        mistaken for, or overwrite, the full result
        ``{sample}_{trait}.csv.gz``. The previous check compared
        ``running_chunk_number`` with ``end_chunk - start_chunk + 1``, which is
        how the caller computes ``running_chunk_number`` in the first place,
        so the chunk-suffixed name was never used. That comparison is kept as
        an additional trigger for callers that pass a different count.
    """
    out_dir = Path(config.ldsc_save_dir)
    sample_name = config.sample_name

    chunk_range_requested = (
        getattr(config, 'all_chunk', None) is None
        and getattr(config, 'chunk_range', None) is not None
    )
    is_partial_run = chunk_range_requested or running_chunk_number != end_chunk - start_chunk + 1

    for trait_name, out_chunk_list in output_dict.items():
        out_all = pd.concat(out_chunk_list, axis=0)
        if is_partial_run:
            out_file_name = out_dir / f'{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz'
        else:
            out_file_name = out_dir / f'{sample_name}_{trait_name}.csv.gz'
        out_all['spot'] = out_all.index
        # Copy so the assignment below targets an independent frame rather
        # than a column subset (avoids SettingWithCopyWarning).
        out_all = out_all[['spot', 'beta', 'se', 'z', 'p']].copy()

        # clip the p-values
        out_all['p'] = out_all['p'].clip(1e-300, 1)
        out_all.to_csv(out_file_name, compression='gzip', index=False)
        logger.info(f'Output saved to {out_file_name} for {trait_name}')
|