gsMap-1.67-py3-none-any.whl → gsMap-1.70-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,380 +1,380 @@
- import gc
- import logging
- import os
- from collections import defaultdict
- from pathlib import Path
-
- import anndata as ad
- import numpy as np
- import pandas as pd
- import zarr
- from scipy.stats import norm
- from tqdm.contrib.concurrent import thread_map
-
- import gsMap.utils.jackknife as jk
- from gsMap.config import SpatialLDSCConfig
- from gsMap.utils.regression_read import _read_sumstats, _read_w_ld, _read_ref_ld_v2
-
- logger = logging.getLogger('gsMap.spatial_ldsc')
-
-
- # %%
- def _coef_new(jknife):
-     est_ = jknife.jknife_est[0, 0] / Nbar
-     se_ = jknife.jknife_se[0, 0] / Nbar
-     return est_, se_
-
-
- def append_intercept(x):
-     n_row = x.shape[0]
-     intercept = np.ones((n_row, 1))
-     x_new = np.concatenate((x, intercept), axis=1)
-     return x_new
-
-
- def filter_sumstats_by_chisq(sumstats, chisq_max):
-     before_len = len(sumstats)
-     if chisq_max is None:
-         chisq_max = max(0.001 * sumstats.N.max(), 80)
-         logger.info(f'No chi^2 threshold provided, using {chisq_max} as default')
-     sumstats['chisq'] = sumstats.Z ** 2
-     sumstats = sumstats[sumstats.chisq < chisq_max]
-     after_len = len(sumstats)
-     if after_len < before_len:
-         logger.info(f'Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)')
-     else:
-         logger.info(f'No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)')
-     return sumstats
-
-
- def aggregate(y, x, N, M, intercept=1):
-     num = M * (np.mean(y) - intercept)
-     denom = np.mean(np.multiply(x, N))
-     return num / denom
-
-
- def weights(ld, w_ld, N, M, hsq, intercept=1):
-     M = float(M)
-     hsq = np.clip(hsq, 0.0, 1.0)
-     ld = np.maximum(ld, 1.0)
-     w_ld = np.maximum(w_ld, 1.0)
-     c = hsq * N / M
-     het_w = 1.0 / (2 * np.square(intercept + np.multiply(c, ld)))
-     oc_w = 1.0 / w_ld
-     w = np.multiply(het_w, oc_w)
-     return w
-
-
- def jackknife_for_processmap(spot_id):
-     # calculate the initial weight for each spot
-     spot_spatial_annotation = spatial_annotation[:, spot_id]
-     spot_x_tot_precomputed = spot_spatial_annotation + ref_ld_baseline_column_sum
-     initial_w = (
-         get_weight_optimized(sumstats, x_tot_precomputed=spot_x_tot_precomputed,
-                              M_tot=10000, w_ld=w_ld_common_snp, intercept=1)
-         .astype(np.float32)
-         .reshape((-1, 1)))
-
-     # apply the weight to baseline annotation, spatial annotation and CHISQ
-     initial_w_scaled = initial_w / np.sum(initial_w)
-     baseline_annotation_spot = baseline_annotation * initial_w_scaled
-     spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
-     CHISQ = sumstats.chisq.values.reshape((-1, 1))
-     y = CHISQ * initial_w_scaled
-
-     # run the jackknife
-     x_focal = np.concatenate((spatial_annotation_spot,
-                               baseline_annotation_spot), axis=1)
-     try:
-         jknife = jk.LstsqJackknifeFast(x_focal, y, n_blocks)
-     # LinAlgError
-     except np.linalg.LinAlgError as e:
-         logger.warning(f'LinAlgError: {e}')
-         return np.nan, np.nan
-     return _coef_new(jknife)
-
-
- # Updated function
- def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
-     tot_agg = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
-     initial_w = weights(x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept)
-     initial_w = np.sqrt(initial_w)
-     return initial_w
-
-
- def _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
-     # Load the gwas summary statistics
-     sumstats = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
-     sumstats.set_index('SNP', inplace=True)
-     sumstats = sumstats.astype(np.float32)
-     sumstats = filter_sumstats_by_chisq(sumstats, chisq_max)
-
-     # NB: The intersection order is essential for keeping the same order of SNPs by its BP location
-     common_snp = baseline_and_w_ld_common_snp.intersection(sumstats.index)
-     if len(common_snp) < 200000:
-         logger.warning(f'WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad.')
-
-     sumstats = sumstats.loc[common_snp]
-
-     # get the common index position of baseline_and_w_ld_common_snp for quick access
-     sumstats['common_index_pos'] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(sumstats.index)
-     return sumstats
-
-
- def _get_sumstats_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index,
-                                      chisq_max=None):
-     # first validate if all sumstats file exists
-     logger.info('Validating sumstats files...')
-     for trait_name, sumstat_file_path in sumstats_config_dict.items():
-         if not os.path.exists(sumstat_file_path):
-             raise FileNotFoundError(f'{sumstat_file_path} not found')
-     # then load all sumstats
-     sumstats_cleaned_dict = {}
-     for trait_name, sumstat_file_path in sumstats_config_dict.items():
-         sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path,
-                                                                  baseline_and_w_ld_common_snp, chisq_max)
-     logger.info('cleaned sumstats loaded')
-     return sumstats_cleaned_dict
-
-
- class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
-     def __init__(self, config: SpatialLDSCConfig, common_snp_among_all_sumstats_pos):
-         self.config = config
-         mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM')
-         mk_score_genes = mk_score.index
-
-         snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
-         common_genes = mk_score_genes.intersection(snp_gene_weight_adata.var.index)
-         common_snps = snp_gene_weight_adata.obs.index
-         # self.snp_gene_weight_adata = snp_gene_weight_adata[common_snp_among_all_sumstats:, common_genes.to_list()]
-         self.snp_gene_weight_matrix = snp_gene_weight_adata[common_snp_among_all_sumstats_pos, common_genes.to_list()].X
-         self.mk_score_common = mk_score.loc[common_genes]
-
-         # calculate the chunk number
-         self.chunk_starts = list(range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode))
-
-     def fetch_ldscore_by_chunk(self, chunk_index):
-         chunk_start = self.chunk_starts[chunk_index]
-         mk_score_chunk = self.mk_score_common.iloc[:,
-                          chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
-         ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
-             mk_score_chunk,
-             drop_dummy_na=False,
-         )
-
-         spots_name = self.mk_score_common.columns[chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
-         return ldscore_chunk, spots_name
-
-     def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(self,
-                                                               mk_score_chunk,
-                                                               drop_dummy_na=True,
-                                                               ):
-
-         if drop_dummy_na:
-             ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
-         else:
-             ldscore_chr_chunk = self.snp_gene_weight_matrix @ mk_score_chunk
-
-         return ldscore_chr_chunk
-
-
- def _get_sumstats_with_common_snp_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index,
-                                                      chisq_max=None):
-     # first validate if all sumstats file exists
-     logger.info('Validating sumstats files...')
-     for trait_name, sumstat_file_path in sumstats_config_dict.items():
-         if not os.path.exists(sumstat_file_path):
-             raise FileNotFoundError(f'{sumstat_file_path} not found')
-     # then load all sumstats
-     sumstats_cleaned_dict = {}
-     for trait_name, sumstat_file_path in sumstats_config_dict.items():
-         sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path,
-                                                                  baseline_and_w_ld_common_snp, chisq_max)
-     # get the common snps among all sumstats
-     common_snp_among_all_sumstats = None
-     for trait_name, sumstats in sumstats_cleaned_dict.items():
-         if common_snp_among_all_sumstats is None:
-             common_snp_among_all_sumstats = sumstats.index
-         else:
-             common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(sumstats.index)
-
-     # filter the common snps among all sumstats
-     for trait_name, sumstats in sumstats_cleaned_dict.items():
-         sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]
-
-     logger.info(f'Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}')
-     return sumstats_cleaned_dict, common_snp_among_all_sumstats
-
-
- def run_spatial_ldsc(config: SpatialLDSCConfig):
-     global spatial_annotation, baseline_annotation, n_blocks, Nbar, sumstats, ref_ld_baseline_column_sum, w_ld_common_snp
-     # config
-     n_blocks = config.n_blocks
-     sample_name = config.sample_name
-
-     print(f'------Running Spatial LDSC for {sample_name}...')
-     # Load the regression weights
-     w_ld = _read_w_ld(config.w_file)
-     w_ld_cname = w_ld.columns[1]
-     w_ld.set_index('SNP', inplace=True)
-
-     ld_file_baseline = f'{config.ldscore_save_dir}/baseline/baseline.'
-
-     ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
-     # n_annot_baseline = len(ref_ld_baseline.columns)
-     # M_annot_baseline = _read_M_v2(ld_file_baseline, n_annot_baseline, config.not_M_5_50)
-
-     # common snp between baseline and w_ld
-     baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)
-     baseline_and_w_ld_common_snp_pos = pd.Index(ref_ld_baseline.index).get_indexer(baseline_and_w_ld_common_snp)
-
-     # Clean the sumstats
-     sumstats_cleaned_dict, common_snp_among_all_sumstats = _get_sumstats_with_common_snp_from_sumstats_dict(
-         config.sumstats_config_dict, baseline_and_w_ld_common_snp,
-         chisq_max=config.chisq_max)
-     common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(common_snp_among_all_sumstats)
-
-     # insure the order is monotonic
-     assert pd.Series(
-         common_snp_among_all_sumstats_pos).is_monotonic_increasing, 'common_snp_among_all_sumstats_pos is not monotonic increasing'
-
-     if len(common_snp_among_all_sumstats) < 200000:
-         logger.warning(
-             f'!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files.')
-
-     ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
-     w_ld = w_ld.loc[common_snp_among_all_sumstats]
-
-     # load additional baseline annotations
-     if config.use_additional_baseline_annotation:
-         print('Using additional baseline annotations')
-         ld_file_baseline_additional = f'{config.ldscore_save_dir}/additional_baseline/baseline.'
-         ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
-         n_annot_baseline_additional = len(ref_ld_baseline_additional.columns)
-         logger.info(f'{len(ref_ld_baseline_additional.columns)} additional baseline annotations loaded')
-         # M_annot_baseline_additional = _read_M_v2(ld_file_baseline_additional, n_annot_baseline_additional,
-         #                                          config.not_M_5_50)
-         ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
-         ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
-         del ref_ld_baseline_additional
-
-     # Detect available chunk files
-     if config.ldscore_save_format == 'quick_mode':
-         s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, common_snp_among_all_sumstats_pos)
-         total_chunk_number_found = len(s_ldsc.chunk_starts)
-         print(f'Split data into {total_chunk_number_found} chunks')
-     else:
-         all_file = os.listdir(config.ldscore_save_dir)
-         total_chunk_number_found = sum('chunk' in name for name in all_file)
-         print(f'Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}')
-
-     if config.all_chunk is None:
-         if config.chunk_range is not None:
-             assert config.chunk_range[0] >= 1 and config.chunk_range[
-                 1] <= total_chunk_number_found, 'Chunk range out of bound. It should be in [1, all_chunk]'
-             print(
-                 f'chunk range provided, using chunked files from {config.chunk_range[0]} to {config.chunk_range[1]}')
-             start_chunk, end_chunk = config.chunk_range
-         else:
-             start_chunk, end_chunk = 1, total_chunk_number_found
-     else:
-         all_chunk = config.all_chunk
-         print(f'using {all_chunk} chunked files by provided argument')
-         print(f'\t')
-         print(f'Input {all_chunk} chunked files')
-         start_chunk, end_chunk = 1, all_chunk
-
-     running_chunk_number = end_chunk - start_chunk + 1
-
-     # Process each chunk
-     output_dict = defaultdict(list)
-     zarr_path = Path(config.ldscore_save_dir) / f'{config.sample_name}.ldscore.zarr'
-     if config.ldscore_save_format == 'zarr':
-         assert zarr_path.exists(), f'{zarr_path} not found, which is required for zarr format'
-         zarr_file = zarr.open(str(zarr_path))
-         spots_name = zarr_file.attrs['spot_names']
-
-     for chunk_index in range(start_chunk, end_chunk + 1):
-         if config.ldscore_save_format == 'feather':
-             ref_ld_spatial, spatial_annotation_cnames = load_ldscore_chunk_from_feather(chunk_index,
-                                                                                         common_snp_among_all_sumstats_pos,
-                                                                                         config,
-                                                                                         )
-         elif config.ldscore_save_format == 'zarr':
-             ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
-             start_spot = (chunk_index - 1) * zarr_file.chunks[1]
-             ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
-             spatial_annotation_cnames = spots_name[start_spot:start_spot + zarr_file.chunks[1]]
-         elif config.ldscore_save_format == 'quick_mode':
-             ref_ld_spatial, spatial_annotation_cnames = s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)
-         else:
-             raise ValueError(f'Invalid ld score save format: {config.ldscore_save_format}')
-
-         # get the x_tot_precomputed matrix by adding baseline and spatial annotation
-         ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
-         # x_tot_precomputed = ref_ld_spatial + ref_ld_baseline_column_sum
-
-         for trait_name, sumstats in sumstats_cleaned_dict.items():
-
-             spatial_annotation = ref_ld_spatial.astype(np.float32, copy=False)
-             baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)
-             w_ld_common_snp = w_ld.astype(np.float32, copy=False)
-
-             # weight the baseline annotation by N
-             baseline_annotation = baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
-             # append intercept
-             baseline_annotation = append_intercept(baseline_annotation)
-
-             # Run the jackknife
-             Nbar = sumstats.N.mean()
-             chunk_size = spatial_annotation.shape[1]
-             out_chunk = thread_map(jackknife_for_processmap, range(chunk_size),
-                                    max_workers=config.num_processes,
-                                    chunksize=10,
-                                    desc=f'Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}',
-                                    )
-
-             # cache the results
-             out_chunk = pd.DataFrame.from_records(out_chunk,
-                                                   columns=['beta', 'se', ],
-                                                   index=spatial_annotation_cnames)
-             # get the spots with nan
-             nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
-             if len(nan_spots) > 0:
-                 logger.info(f'Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed.')
-             # drop the nan
-             out_chunk = out_chunk.dropna()
-
-             out_chunk['z'] = out_chunk.beta / out_chunk.se
-             out_chunk['p'] = norm.sf(out_chunk['z'])
-             output_dict[trait_name].append(out_chunk)
-
-         del ref_ld_spatial, spatial_annotation, baseline_annotation, w_ld_common_snp
-         gc.collect()
-
-     # Save the results
-     out_dir = config.ldsc_save_dir
-     for trait_name, out_chunk_list in output_dict.items():
-         out_all = pd.concat(out_chunk_list, axis=0)
-         if running_chunk_number == total_chunk_number_found:
-             out_file_name = out_dir / f'{sample_name}_{trait_name}.csv.gz'
-         else:
-             out_file_name = out_dir / f'{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz'
-         out_all['spot'] = out_all.index
-         out_all = out_all[['spot', 'beta', 'se', 'z', 'p']]
-         out_all.to_csv(out_file_name, compression='gzip', index=False)
-
-         logger.info(f'Output saved to {out_file_name} for {trait_name}')
-     logger.info(f'------Spatial LDSC for {sample_name} finished!')
-
-
- def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config, ):
-     # Load the spatial annotations for this chunk
-     sample_name = config.sample_name
-     ld_file_spatial = f'{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.'
-     ref_ld_spatial = _read_ref_ld_v2(ld_file_spatial)
-     ref_ld_spatial = ref_ld_spatial.iloc[common_snp_among_all_sumstats_pos]
-     ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
-
-     spatial_annotation_cnames = ref_ld_spatial.columns
-     return ref_ld_spatial.values, spatial_annotation_cnames
+ import gc
+ import logging
+ import os
+ from collections import defaultdict
+ from pathlib import Path
+
+ import anndata as ad
+ import numpy as np
+ import pandas as pd
+ import zarr
+ from scipy.stats import norm
+ from tqdm.contrib.concurrent import thread_map
+
+ import gsMap.utils.jackknife as jk
+ from gsMap.config import SpatialLDSCConfig
+ from gsMap.utils.regression_read import _read_sumstats, _read_w_ld, _read_ref_ld_v2
+
+ logger = logging.getLogger('gsMap.spatial_ldsc')
+
+
+ # %%
+ def _coef_new(jknife):
+     est_ = jknife.jknife_est[0, 0] / Nbar
+     se_ = jknife.jknife_se[0, 0] / Nbar
+     return est_, se_
+
+
+ def append_intercept(x):
+     n_row = x.shape[0]
+     intercept = np.ones((n_row, 1))
+     x_new = np.concatenate((x, intercept), axis=1)
+     return x_new
+
+
+ def filter_sumstats_by_chisq(sumstats, chisq_max):
+     before_len = len(sumstats)
+     if chisq_max is None:
+         chisq_max = max(0.001 * sumstats.N.max(), 80)
+         logger.info(f'No chi^2 threshold provided, using {chisq_max} as default')
+     sumstats['chisq'] = sumstats.Z ** 2
+     sumstats = sumstats[sumstats.chisq < chisq_max]
+     after_len = len(sumstats)
+     if after_len < before_len:
+         logger.info(f'Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)')
+     else:
+         logger.info(f'No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)')
+     return sumstats
+
+
+ def aggregate(y, x, N, M, intercept=1):
+     num = M * (np.mean(y) - intercept)
+     denom = np.mean(np.multiply(x, N))
+     return num / denom
+
+
+ def weights(ld, w_ld, N, M, hsq, intercept=1):
+     M = float(M)
+     hsq = np.clip(hsq, 0.0, 1.0)
+     ld = np.maximum(ld, 1.0)
+     w_ld = np.maximum(w_ld, 1.0)
+     c = hsq * N / M
+     het_w = 1.0 / (2 * np.square(intercept + np.multiply(c, ld)))
+     oc_w = 1.0 / w_ld
+     w = np.multiply(het_w, oc_w)
+     return w
+
+
+ def jackknife_for_processmap(spot_id):
+     # calculate the initial weight for each spot
+     spot_spatial_annotation = spatial_annotation[:, spot_id]
+     spot_x_tot_precomputed = spot_spatial_annotation + ref_ld_baseline_column_sum
+     initial_w = (
+         get_weight_optimized(sumstats, x_tot_precomputed=spot_x_tot_precomputed,
+                              M_tot=10000, w_ld=w_ld_common_snp, intercept=1)
+         .astype(np.float32)
+         .reshape((-1, 1)))
+
+     # apply the weight to baseline annotation, spatial annotation and CHISQ
+     initial_w_scaled = initial_w / np.sum(initial_w)
+     baseline_annotation_spot = baseline_annotation * initial_w_scaled
+     spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
+     CHISQ = sumstats.chisq.values.reshape((-1, 1))
+     y = CHISQ * initial_w_scaled
+
+     # run the jackknife
+     x_focal = np.concatenate((spatial_annotation_spot,
+                               baseline_annotation_spot), axis=1)
+     try:
+         jknife = jk.LstsqJackknifeFast(x_focal, y, n_blocks)
+     # LinAlgError
+     except np.linalg.LinAlgError as e:
+         logger.warning(f'LinAlgError: {e}')
+         return np.nan, np.nan
+     return _coef_new(jknife)
+
+
+ # Updated function
+ def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
+     tot_agg = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
+     initial_w = weights(x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept)
+     initial_w = np.sqrt(initial_w)
+     return initial_w
+
+
+ def _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
+     # Load the gwas summary statistics
+     sumstats = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
+     sumstats.set_index('SNP', inplace=True)
+     sumstats = sumstats.astype(np.float32)
+     sumstats = filter_sumstats_by_chisq(sumstats, chisq_max)
+
+     # NB: The intersection order is essential for keeping the same order of SNPs by its BP location
+     common_snp = baseline_and_w_ld_common_snp.intersection(sumstats.index)
+     if len(common_snp) < 200000:
+         logger.warning(f'WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad.')
+
+     sumstats = sumstats.loc[common_snp]
+
+     # get the common index position of baseline_and_w_ld_common_snp for quick access
+     sumstats['common_index_pos'] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(sumstats.index)
+     return sumstats
+
+
+ def _get_sumstats_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index,
+                                      chisq_max=None):
+     # first validate if all sumstats file exists
+     logger.info('Validating sumstats files...')
+     for trait_name, sumstat_file_path in sumstats_config_dict.items():
+         if not os.path.exists(sumstat_file_path):
+             raise FileNotFoundError(f'{sumstat_file_path} not found')
+     # then load all sumstats
+     sumstats_cleaned_dict = {}
+     for trait_name, sumstat_file_path in sumstats_config_dict.items():
+         sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path,
+                                                                  baseline_and_w_ld_common_snp, chisq_max)
+     logger.info('cleaned sumstats loaded')
+     return sumstats_cleaned_dict
+
+
+ class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
+     def __init__(self, config: SpatialLDSCConfig, common_snp_among_all_sumstats_pos):
+         self.config = config
+         mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM')
+         mk_score_genes = mk_score.index
+
+         snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
+         common_genes = mk_score_genes.intersection(snp_gene_weight_adata.var.index)
+         common_snps = snp_gene_weight_adata.obs.index
+         # self.snp_gene_weight_adata = snp_gene_weight_adata[common_snp_among_all_sumstats:, common_genes.to_list()]
+         self.snp_gene_weight_matrix = snp_gene_weight_adata[common_snp_among_all_sumstats_pos, common_genes.to_list()].X
+         self.mk_score_common = mk_score.loc[common_genes]
+
+         # calculate the chunk number
+         self.chunk_starts = list(range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode))
+
+     def fetch_ldscore_by_chunk(self, chunk_index):
+         chunk_start = self.chunk_starts[chunk_index]
+         mk_score_chunk = self.mk_score_common.iloc[:,
+                          chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
+         ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
+             mk_score_chunk,
+             drop_dummy_na=False,
+         )
+
+         spots_name = self.mk_score_common.columns[chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
+         return ldscore_chunk, spots_name
+
+     def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(self,
+                                                               mk_score_chunk,
+                                                               drop_dummy_na=True,
+                                                               ):
+
+         if drop_dummy_na:
+             ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
+         else:
+             ldscore_chr_chunk = self.snp_gene_weight_matrix @ mk_score_chunk
+
+         return ldscore_chr_chunk
+
+
+ def _get_sumstats_with_common_snp_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index,
+                                                      chisq_max=None):
+     # first validate if all sumstats file exists
+     logger.info('Validating sumstats files...')
+     for trait_name, sumstat_file_path in sumstats_config_dict.items():
+         if not os.path.exists(sumstat_file_path):
+             raise FileNotFoundError(f'{sumstat_file_path} not found')
+     # then load all sumstats
+     sumstats_cleaned_dict = {}
+     for trait_name, sumstat_file_path in sumstats_config_dict.items():
+         sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path,
+                                                                  baseline_and_w_ld_common_snp, chisq_max)
+     # get the common snps among all sumstats
+     common_snp_among_all_sumstats = None
+     for trait_name, sumstats in sumstats_cleaned_dict.items():
+         if common_snp_among_all_sumstats is None:
+             common_snp_among_all_sumstats = sumstats.index
+         else:
+             common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(sumstats.index)
+
+     # filter the common snps among all sumstats
+     for trait_name, sumstats in sumstats_cleaned_dict.items():
+         sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]
+
+     logger.info(f'Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}')
+     return sumstats_cleaned_dict, common_snp_among_all_sumstats
+
+
+ def run_spatial_ldsc(config: SpatialLDSCConfig):
+     global spatial_annotation, baseline_annotation, n_blocks, Nbar, sumstats, ref_ld_baseline_column_sum, w_ld_common_snp
+     # config
+     n_blocks = config.n_blocks
+     sample_name = config.sample_name
+
+     print(f'------Running Spatial LDSC for {sample_name}...')
+     # Load the regression weights
+     w_ld = _read_w_ld(config.w_file)
+     w_ld_cname = w_ld.columns[1]
+     w_ld.set_index('SNP', inplace=True)
+
+     ld_file_baseline = f'{config.ldscore_save_dir}/baseline/baseline.'
+
+     ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
+     # n_annot_baseline = len(ref_ld_baseline.columns)
+     # M_annot_baseline = _read_M_v2(ld_file_baseline, n_annot_baseline, config.not_M_5_50)
+
+     # common snp between baseline and w_ld
+     baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)
+     baseline_and_w_ld_common_snp_pos = pd.Index(ref_ld_baseline.index).get_indexer(baseline_and_w_ld_common_snp)
+
+     # Clean the sumstats
+     sumstats_cleaned_dict, common_snp_among_all_sumstats = _get_sumstats_with_common_snp_from_sumstats_dict(
+         config.sumstats_config_dict, baseline_and_w_ld_common_snp,
+         chisq_max=config.chisq_max)
+     common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(common_snp_among_all_sumstats)
+
+     # insure the order is monotonic
+     assert pd.Series(
+         common_snp_among_all_sumstats_pos).is_monotonic_increasing, 'common_snp_among_all_sumstats_pos is not monotonic increasing'
+
+     if len(common_snp_among_all_sumstats) < 200000:
+         logger.warning(
+             f'!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files.')
+
+     ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
+     w_ld = w_ld.loc[common_snp_among_all_sumstats]
+
+     # load additional baseline annotations
+     if config.use_additional_baseline_annotation:
+         print('Using additional baseline annotations')
+         ld_file_baseline_additional = f'{config.ldscore_save_dir}/additional_baseline/baseline.'
+         ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
+         n_annot_baseline_additional = len(ref_ld_baseline_additional.columns)
+         logger.info(f'{len(ref_ld_baseline_additional.columns)} additional baseline annotations loaded')
+         # M_annot_baseline_additional = _read_M_v2(ld_file_baseline_additional, n_annot_baseline_additional,
+         #                                          config.not_M_5_50)
+         ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
+         ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
+         del ref_ld_baseline_additional
+
+     # Detect available chunk files
+     if config.ldscore_save_format == 'quick_mode':
+         s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, common_snp_among_all_sumstats_pos)
+         total_chunk_number_found = len(s_ldsc.chunk_starts)
+         print(f'Split data into {total_chunk_number_found} chunks')
+     else:
+         all_file = os.listdir(config.ldscore_save_dir)
+         total_chunk_number_found = sum('chunk' in name for name in all_file)
+         print(f'Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}')
+
+     if config.all_chunk is None:
+         if config.chunk_range is not None:
+             assert config.chunk_range[0] >= 1 and config.chunk_range[
+                 1] <= total_chunk_number_found, 'Chunk range out of bound. It should be in [1, all_chunk]'
+             print(
+                 f'chunk range provided, using chunked files from {config.chunk_range[0]} to {config.chunk_range[1]}')
+             start_chunk, end_chunk = config.chunk_range
+         else:
+             start_chunk, end_chunk = 1, total_chunk_number_found
+     else:
+         all_chunk = config.all_chunk
+         print(f'using {all_chunk} chunked files by provided argument')
+         print(f'\t')
+         print(f'Input {all_chunk} chunked files')
+         start_chunk, end_chunk = 1, all_chunk
+
+     running_chunk_number = end_chunk - start_chunk + 1
+
+     # Process each chunk
+     output_dict = defaultdict(list)
+     zarr_path = Path(config.ldscore_save_dir) / f'{config.sample_name}.ldscore.zarr'
+     if config.ldscore_save_format == 'zarr':
+         assert zarr_path.exists(), f'{zarr_path} not found, which is required for zarr format'
+         zarr_file = zarr.open(str(zarr_path))
+         spots_name = zarr_file.attrs['spot_names']
+
+     for chunk_index in range(start_chunk, end_chunk + 1):
+         if config.ldscore_save_format == 'feather':
+             ref_ld_spatial, spatial_annotation_cnames = load_ldscore_chunk_from_feather(chunk_index,
+                                                                                         common_snp_among_all_sumstats_pos,
+                                                                                         config,
+                                                                                         )
+         elif config.ldscore_save_format == 'zarr':
+             ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
+             start_spot = (chunk_index - 1) * zarr_file.chunks[1]
+             ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
+             spatial_annotation_cnames = spots_name[start_spot:start_spot + zarr_file.chunks[1]]
+         elif config.ldscore_save_format == 'quick_mode':
+             ref_ld_spatial, spatial_annotation_cnames = s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)
+         else:
+             raise ValueError(f'Invalid ld score save format: {config.ldscore_save_format}')
+
+         # get the x_tot_precomputed matrix by adding baseline and spatial annotation
+         ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
+         # x_tot_precomputed = ref_ld_spatial + ref_ld_baseline_column_sum
+
+         for trait_name, sumstats in sumstats_cleaned_dict.items():
+
+             spatial_annotation = ref_ld_spatial.astype(np.float32, copy=False)
+             baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)
+             w_ld_common_snp = w_ld.astype(np.float32, copy=False)
+
+             # weight the baseline annotation by N
+             baseline_annotation = baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
+             # append intercept
+             baseline_annotation = append_intercept(baseline_annotation)
+
+             # Run the jackknife
+             Nbar = sumstats.N.mean()
+             chunk_size = spatial_annotation.shape[1]
+             out_chunk = thread_map(jackknife_for_processmap, range(chunk_size),
+                                    max_workers=config.num_processes,
+                                    chunksize=10,
+                                    desc=f'Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}',
+                                    )
+
+             # cache the results
+             out_chunk = pd.DataFrame.from_records(out_chunk,
+                                                   columns=['beta', 'se', ],
+                                                   index=spatial_annotation_cnames)
+             # get the spots with nan
+             nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
+             if len(nan_spots) > 0:
+                 logger.info(f'Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed.')
+             # drop the nan
+             out_chunk = out_chunk.dropna()
+
+             out_chunk['z'] = out_chunk.beta / out_chunk.se
+             out_chunk['p'] = norm.sf(out_chunk['z'])
+             output_dict[trait_name].append(out_chunk)
+
+         del ref_ld_spatial, spatial_annotation, baseline_annotation, w_ld_common_snp
+         gc.collect()
+
+     # Save the results
+     out_dir = config.ldsc_save_dir
+     for trait_name, out_chunk_list in output_dict.items():
+         out_all = pd.concat(out_chunk_list, axis=0)
+         if running_chunk_number == total_chunk_number_found:
+             out_file_name = out_dir / f'{sample_name}_{trait_name}.csv.gz'
+         else:
+             out_file_name = out_dir / f'{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz'
+         out_all['spot'] = out_all.index
+         out_all = out_all[['spot', 'beta', 'se', 'z', 'p']]
+         out_all.to_csv(out_file_name, compression='gzip', index=False)
+
+         logger.info(f'Output saved to {out_file_name} for {trait_name}')
+     logger.info(f'------Spatial LDSC for {sample_name} finished!')
+
+
+ def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config, ):
+     # Load the spatial annotations for this chunk
+     sample_name = config.sample_name
+     ld_file_spatial = f'{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.'
+     ref_ld_spatial = _read_ref_ld_v2(ld_file_spatial)
+     ref_ld_spatial = ref_ld_spatial.iloc[common_snp_among_all_sumstats_pos]
+     ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
+
+     spatial_annotation_cnames = ref_ld_spatial.columns
+     return ref_ld_spatial.values, spatial_annotation_cnames