gsMap-1.71.1-py3-none-any.whl → gsMap-1.71.2-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
@@ -1,380 +1,360 @@
1
- import gc
2
- import logging
3
- import os
4
- from collections import defaultdict
5
- from pathlib import Path
6
-
7
- import anndata as ad
8
- import numpy as np
9
- import pandas as pd
10
- import zarr
11
- from scipy.stats import norm
12
- from tqdm.contrib.concurrent import thread_map
13
-
14
- import gsMap.utils.jackknife as jk
15
- from gsMap.config import SpatialLDSCConfig
16
- from gsMap.utils.regression_read import _read_sumstats, _read_w_ld, _read_ref_ld_v2
17
-
18
- logger = logging.getLogger('gsMap.spatial_ldsc')
19
-
20
-
21
- # %%
22
- def _coef_new(jknife):
23
- est_ = jknife.jknife_est[0, 0] / Nbar
24
- se_ = jknife.jknife_se[0, 0] / Nbar
25
- return est_, se_
26
-
27
-
28
- def append_intercept(x):
29
- n_row = x.shape[0]
30
- intercept = np.ones((n_row, 1))
31
- x_new = np.concatenate((x, intercept), axis=1)
32
- return x_new
33
-
34
-
35
- def filter_sumstats_by_chisq(sumstats, chisq_max):
36
- before_len = len(sumstats)
37
- if chisq_max is None:
38
- chisq_max = max(0.001 * sumstats.N.max(), 80)
39
- logger.info(f'No chi^2 threshold provided, using {chisq_max} as default')
40
- sumstats['chisq'] = sumstats.Z ** 2
41
- sumstats = sumstats[sumstats.chisq < chisq_max]
42
- after_len = len(sumstats)
43
- if after_len < before_len:
44
- logger.info(f'Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)')
45
- else:
46
- logger.info(f'No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)')
47
- return sumstats
48
-
49
-
50
- def aggregate(y, x, N, M, intercept=1):
51
- num = M * (np.mean(y) - intercept)
52
- denom = np.mean(np.multiply(x, N))
53
- return num / denom
54
-
55
-
56
- def weights(ld, w_ld, N, M, hsq, intercept=1):
57
- M = float(M)
58
- hsq = np.clip(hsq, 0.0, 1.0)
59
- ld = np.maximum(ld, 1.0)
60
- w_ld = np.maximum(w_ld, 1.0)
61
- c = hsq * N / M
62
- het_w = 1.0 / (2 * np.square(intercept + np.multiply(c, ld)))
63
- oc_w = 1.0 / w_ld
64
- w = np.multiply(het_w, oc_w)
65
- return w
66
-
67
-
68
- def jackknife_for_processmap(spot_id):
69
- # calculate the initial weight for each spot
70
- spot_spatial_annotation = spatial_annotation[:, spot_id]
71
- spot_x_tot_precomputed = spot_spatial_annotation + ref_ld_baseline_column_sum
72
- initial_w = (
73
- get_weight_optimized(sumstats, x_tot_precomputed=spot_x_tot_precomputed,
74
- M_tot=10000, w_ld=w_ld_common_snp, intercept=1)
75
- .astype(np.float32)
76
- .reshape((-1, 1)))
77
-
78
- # apply the weight to baseline annotation, spatial annotation and CHISQ
79
- initial_w_scaled = initial_w / np.sum(initial_w)
80
- baseline_annotation_spot = baseline_annotation * initial_w_scaled
81
- spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
82
- CHISQ = sumstats.chisq.values.reshape((-1, 1))
83
- y = CHISQ * initial_w_scaled
84
-
85
- # run the jackknife
86
- x_focal = np.concatenate((spatial_annotation_spot,
87
- baseline_annotation_spot), axis=1)
88
- try:
89
- jknife = jk.LstsqJackknifeFast(x_focal, y, n_blocks)
90
- # LinAlgError
91
- except np.linalg.LinAlgError as e:
92
- logger.warning(f'LinAlgError: {e}')
93
- return np.nan, np.nan
94
- return _coef_new(jknife)
95
-
96
-
97
- # Updated function
98
- def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
99
- tot_agg = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
100
- initial_w = weights(x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept)
101
- initial_w = np.sqrt(initial_w)
102
- return initial_w
103
-
104
-
105
- def _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
106
- # Load the gwas summary statistics
107
- sumstats = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
108
- sumstats.set_index('SNP', inplace=True)
109
- sumstats = sumstats.astype(np.float32)
110
- sumstats = filter_sumstats_by_chisq(sumstats, chisq_max)
111
-
112
- # NB: The intersection order is essential for keeping the same order of SNPs by its BP location
113
- common_snp = baseline_and_w_ld_common_snp.intersection(sumstats.index)
114
- if len(common_snp) < 200000:
115
- logger.warning(f'WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad.')
116
-
117
- sumstats = sumstats.loc[common_snp]
118
-
119
- # get the common index position of baseline_and_w_ld_common_snp for quick access
120
- sumstats['common_index_pos'] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(sumstats.index)
121
- return sumstats
122
-
123
-
124
- def _get_sumstats_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index,
125
- chisq_max=None):
126
- # first validate if all sumstats file exists
127
- logger.info('Validating sumstats files...')
128
- for trait_name, sumstat_file_path in sumstats_config_dict.items():
129
- if not os.path.exists(sumstat_file_path):
130
- raise FileNotFoundError(f'{sumstat_file_path} not found')
131
- # then load all sumstats
132
- sumstats_cleaned_dict = {}
133
- for trait_name, sumstat_file_path in sumstats_config_dict.items():
134
- sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path,
135
- baseline_and_w_ld_common_snp, chisq_max)
136
- logger.info('cleaned sumstats loaded')
137
- return sumstats_cleaned_dict
138
-
139
-
140
- class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
141
- def __init__(self, config: SpatialLDSCConfig, common_snp_among_all_sumstats_pos):
142
- self.config = config
143
- mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM')
144
- mk_score_genes = mk_score.index
145
-
146
- snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
147
- common_genes = mk_score_genes.intersection(snp_gene_weight_adata.var.index)
148
- common_snps = snp_gene_weight_adata.obs.index
149
- # self.snp_gene_weight_adata = snp_gene_weight_adata[common_snp_among_all_sumstats:, common_genes.to_list()]
150
- self.snp_gene_weight_matrix = snp_gene_weight_adata[common_snp_among_all_sumstats_pos, common_genes.to_list()].X
151
- self.mk_score_common = mk_score.loc[common_genes]
152
-
153
- # calculate the chunk number
154
- self.chunk_starts = list(range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode))
155
-
156
- def fetch_ldscore_by_chunk(self, chunk_index):
157
- chunk_start = self.chunk_starts[chunk_index]
158
- mk_score_chunk = self.mk_score_common.iloc[:,
159
- chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
160
- ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
161
- mk_score_chunk,
162
- drop_dummy_na=False,
163
- )
164
-
165
- spots_name = self.mk_score_common.columns[chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
166
- return ldscore_chunk, spots_name
167
-
168
- def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(self,
169
- mk_score_chunk,
170
- drop_dummy_na=True,
171
- ):
172
-
173
- if drop_dummy_na:
174
- ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
175
- else:
176
- ldscore_chr_chunk = self.snp_gene_weight_matrix @ mk_score_chunk
177
-
178
- return ldscore_chr_chunk
179
-
180
-
181
- def _get_sumstats_with_common_snp_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index,
182
- chisq_max=None):
183
- # first validate if all sumstats file exists
184
- logger.info('Validating sumstats files...')
185
- for trait_name, sumstat_file_path in sumstats_config_dict.items():
186
- if not os.path.exists(sumstat_file_path):
187
- raise FileNotFoundError(f'{sumstat_file_path} not found')
188
- # then load all sumstats
189
- sumstats_cleaned_dict = {}
190
- for trait_name, sumstat_file_path in sumstats_config_dict.items():
191
- sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path,
192
- baseline_and_w_ld_common_snp, chisq_max)
193
- # get the common snps among all sumstats
194
- common_snp_among_all_sumstats = None
195
- for trait_name, sumstats in sumstats_cleaned_dict.items():
196
- if common_snp_among_all_sumstats is None:
197
- common_snp_among_all_sumstats = sumstats.index
198
- else:
199
- common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(sumstats.index)
200
-
201
- # filter the common snps among all sumstats
202
- for trait_name, sumstats in sumstats_cleaned_dict.items():
203
- sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]
204
-
205
- logger.info(f'Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}')
206
- return sumstats_cleaned_dict, common_snp_among_all_sumstats
207
-
208
-
209
- def run_spatial_ldsc(config: SpatialLDSCConfig):
210
- global spatial_annotation, baseline_annotation, n_blocks, Nbar, sumstats, ref_ld_baseline_column_sum, w_ld_common_snp
211
- # config
212
- n_blocks = config.n_blocks
213
- sample_name = config.sample_name
214
-
215
- print(f'------Running Spatial LDSC for {sample_name}...')
216
- # Load the regression weights
217
- w_ld = _read_w_ld(config.w_file)
218
- w_ld_cname = w_ld.columns[1]
219
- w_ld.set_index('SNP', inplace=True)
220
-
221
- ld_file_baseline = f'{config.ldscore_save_dir}/baseline/baseline.'
222
-
223
- ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
224
- # n_annot_baseline = len(ref_ld_baseline.columns)
225
- # M_annot_baseline = _read_M_v2(ld_file_baseline, n_annot_baseline, config.not_M_5_50)
226
-
227
- # common snp between baseline and w_ld
228
- baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)
229
- baseline_and_w_ld_common_snp_pos = pd.Index(ref_ld_baseline.index).get_indexer(baseline_and_w_ld_common_snp)
230
-
231
- # Clean the sumstats
232
- sumstats_cleaned_dict, common_snp_among_all_sumstats = _get_sumstats_with_common_snp_from_sumstats_dict(
233
- config.sumstats_config_dict, baseline_and_w_ld_common_snp,
234
- chisq_max=config.chisq_max)
235
- common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(common_snp_among_all_sumstats)
236
-
237
- # insure the order is monotonic
238
- assert pd.Series(
239
- common_snp_among_all_sumstats_pos).is_monotonic_increasing, 'common_snp_among_all_sumstats_pos is not monotonic increasing'
240
-
241
- if len(common_snp_among_all_sumstats) < 200000:
242
- logger.warning(
243
- f'!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files.')
244
-
245
- ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
246
- w_ld = w_ld.loc[common_snp_among_all_sumstats]
247
-
248
- # load additional baseline annotations
249
- if config.use_additional_baseline_annotation:
250
- print('Using additional baseline annotations')
251
- ld_file_baseline_additional = f'{config.ldscore_save_dir}/additional_baseline/baseline.'
252
- ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
253
- n_annot_baseline_additional = len(ref_ld_baseline_additional.columns)
254
- logger.info(f'{len(ref_ld_baseline_additional.columns)} additional baseline annotations loaded')
255
- # M_annot_baseline_additional = _read_M_v2(ld_file_baseline_additional, n_annot_baseline_additional,
256
- # config.not_M_5_50)
257
- ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
258
- ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
259
- del ref_ld_baseline_additional
260
-
261
- # Detect available chunk files
262
- if config.ldscore_save_format == 'quick_mode':
263
- s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, common_snp_among_all_sumstats_pos)
264
- total_chunk_number_found = len(s_ldsc.chunk_starts)
265
- print(f'Split data into {total_chunk_number_found} chunks')
266
- else:
267
- all_file = os.listdir(config.ldscore_save_dir)
268
- total_chunk_number_found = sum('chunk' in name for name in all_file)
269
- print(f'Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}')
270
-
271
- if config.all_chunk is None:
272
- if config.chunk_range is not None:
273
- assert config.chunk_range[0] >= 1 and config.chunk_range[
274
- 1] <= total_chunk_number_found, 'Chunk range out of bound. It should be in [1, all_chunk]'
275
- print(
276
- f'chunk range provided, using chunked files from {config.chunk_range[0]} to {config.chunk_range[1]}')
277
- start_chunk, end_chunk = config.chunk_range
278
- else:
279
- start_chunk, end_chunk = 1, total_chunk_number_found
280
- else:
281
- all_chunk = config.all_chunk
282
- print(f'using {all_chunk} chunked files by provided argument')
283
- print(f'\t')
284
- print(f'Input {all_chunk} chunked files')
285
- start_chunk, end_chunk = 1, all_chunk
286
-
287
- running_chunk_number = end_chunk - start_chunk + 1
288
-
289
- # Process each chunk
290
- output_dict = defaultdict(list)
291
- zarr_path = Path(config.ldscore_save_dir) / f'{config.sample_name}.ldscore.zarr'
292
- if config.ldscore_save_format == 'zarr':
293
- assert zarr_path.exists(), f'{zarr_path} not found, which is required for zarr format'
294
- zarr_file = zarr.open(str(zarr_path))
295
- spots_name = zarr_file.attrs['spot_names']
296
-
297
- for chunk_index in range(start_chunk, end_chunk + 1):
298
- if config.ldscore_save_format == 'feather':
299
- ref_ld_spatial, spatial_annotation_cnames = load_ldscore_chunk_from_feather(chunk_index,
300
- common_snp_among_all_sumstats_pos,
301
- config,
302
- )
303
- elif config.ldscore_save_format == 'zarr':
304
- ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
305
- start_spot = (chunk_index - 1) * zarr_file.chunks[1]
306
- ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
307
- spatial_annotation_cnames = spots_name[start_spot:start_spot + zarr_file.chunks[1]]
308
- elif config.ldscore_save_format == 'quick_mode':
309
- ref_ld_spatial, spatial_annotation_cnames = s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)
310
- else:
311
- raise ValueError(f'Invalid ld score save format: {config.ldscore_save_format}')
312
-
313
- # get the x_tot_precomputed matrix by adding baseline and spatial annotation
314
- ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
315
- # x_tot_precomputed = ref_ld_spatial + ref_ld_baseline_column_sum
316
-
317
- for trait_name, sumstats in sumstats_cleaned_dict.items():
318
-
319
- spatial_annotation = ref_ld_spatial.astype(np.float32, copy=False)
320
- baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)
321
- w_ld_common_snp = w_ld.astype(np.float32, copy=False)
322
-
323
- # weight the baseline annotation by N
324
- baseline_annotation = baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
325
- # append intercept
326
- baseline_annotation = append_intercept(baseline_annotation)
327
-
328
- # Run the jackknife
329
- Nbar = sumstats.N.mean()
330
- chunk_size = spatial_annotation.shape[1]
331
- out_chunk = thread_map(jackknife_for_processmap, range(chunk_size),
332
- max_workers=config.num_processes,
333
- chunksize=10,
334
- desc=f'Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}',
335
- )
336
-
337
- # cache the results
338
- out_chunk = pd.DataFrame.from_records(out_chunk,
339
- columns=['beta', 'se', ],
340
- index=spatial_annotation_cnames)
341
- # get the spots with nan
342
- nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
343
- if len(nan_spots) > 0:
344
- logger.info(f'Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed.')
345
- # drop the nan
346
- out_chunk = out_chunk.dropna()
347
-
348
- out_chunk['z'] = out_chunk.beta / out_chunk.se
349
- out_chunk['p'] = norm.sf(out_chunk['z'])
350
- output_dict[trait_name].append(out_chunk)
351
-
352
- del ref_ld_spatial, spatial_annotation, baseline_annotation, w_ld_common_snp
353
- gc.collect()
354
-
355
- # Save the results
356
- out_dir = config.ldsc_save_dir
357
- for trait_name, out_chunk_list in output_dict.items():
358
- out_all = pd.concat(out_chunk_list, axis=0)
359
- if running_chunk_number == total_chunk_number_found:
360
- out_file_name = out_dir / f'{sample_name}_{trait_name}.csv.gz'
361
- else:
362
- out_file_name = out_dir / f'{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz'
363
- out_all['spot'] = out_all.index
364
- out_all = out_all[['spot', 'beta', 'se', 'z', 'p']]
365
- out_all.to_csv(out_file_name, compression='gzip', index=False)
366
-
367
- logger.info(f'Output saved to {out_file_name} for {trait_name}')
368
- logger.info(f'------Spatial LDSC for {sample_name} finished!')
369
-
370
-
371
- def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config, ):
372
- # Load the spatial annotations for this chunk
373
- sample_name = config.sample_name
374
- ld_file_spatial = f'{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.'
375
- ref_ld_spatial = _read_ref_ld_v2(ld_file_spatial)
376
- ref_ld_spatial = ref_ld_spatial.iloc[common_snp_among_all_sumstats_pos]
377
- ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
378
-
379
- spatial_annotation_cnames = ref_ld_spatial.columns
380
- return ref_ld_spatial.values, spatial_annotation_cnames
1
+ import gc
2
+ import logging
3
+ import os
4
+ from collections import defaultdict
5
+ from functools import partial
6
+ from pathlib import Path
7
+
8
+ import anndata as ad
9
+ import numpy as np
10
+ import pandas as pd
11
+ import zarr
12
+ from scipy.stats import norm
13
+ from tqdm.contrib.concurrent import thread_map
14
+
15
+ import gsMap.utils.jackknife as jk
16
+ from gsMap.config import SpatialLDSCConfig
17
+ from gsMap.utils.regression_read import _read_sumstats, _read_w_ld, _read_ref_ld_v2
18
+
19
+ logger = logging.getLogger('gsMap.spatial_ldsc')
20
+
21
+
22
+ def _coef_new(jknife, Nbar):
23
+ """Calculate coefficients adjusted by Nbar."""
24
+ est_ = jknife.jknife_est[0, 0] / Nbar
25
+ se_ = jknife.jknife_se[0, 0] / Nbar
26
+ return est_, se_
27
+
28
+
29
+ def append_intercept(x):
30
+ """Append an intercept term to the design matrix."""
31
+ n_row = x.shape[0]
32
+ intercept = np.ones((n_row, 1))
33
+ x_new = np.concatenate((x, intercept), axis=1)
34
+ return x_new
35
+
36
+
37
+ def filter_sumstats_by_chisq(sumstats, chisq_max):
38
+ """Filter summary statistics based on chi-squared threshold."""
39
+ before_len = len(sumstats)
40
+ if chisq_max is None:
41
+ chisq_max = max(0.001 * sumstats.N.max(), 80)
42
+ logger.info(f'No chi^2 threshold provided, using {chisq_max} as default')
43
+ sumstats['chisq'] = sumstats.Z ** 2
44
+ sumstats = sumstats[sumstats.chisq < chisq_max]
45
+ after_len = len(sumstats)
46
+ if after_len < before_len:
47
+ logger.info(f'Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)')
48
+ else:
49
+ logger.info(f'No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)')
50
+ return sumstats
51
+
52
+
53
+ def aggregate(y, x, N, M, intercept=1):
54
+ """Aggregate function used in weight calculation."""
55
+ num = M * (np.mean(y) - intercept)
56
+ denom = np.mean(np.multiply(x, N))
57
+ return num / denom
58
+
59
+
60
+ def weights(ld, w_ld, N, M, hsq, intercept=1):
61
+ """Calculate weights for regression."""
62
+ M = float(M)
63
+ hsq = np.clip(hsq, 0.0, 1.0)
64
+ ld = np.maximum(ld, 1.0)
65
+ w_ld = np.maximum(w_ld, 1.0)
66
+ c = hsq * N / M
67
+ het_w = 1.0 / (2 * np.square(intercept + np.multiply(c, ld)))
68
+ oc_w = 1.0 / w_ld
69
+ w = np.multiply(het_w, oc_w)
70
+ return w
71
+
72
+
73
+ def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
74
+ """Optimized function to calculate initial weights."""
75
+ tot_agg = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
76
+ initial_w = weights(x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept)
77
+ initial_w = np.sqrt(initial_w)
78
+ return initial_w
79
+
80
+
81
+ def jackknife_for_processmap(spot_id, spatial_annotation, ref_ld_baseline_column_sum, sumstats, baseline_annotation, w_ld_common_snp, Nbar, n_blocks):
82
+ """Perform jackknife resampling for a given spot."""
83
+ spot_spatial_annotation = spatial_annotation[:, spot_id]
84
+ spot_x_tot_precomputed = spot_spatial_annotation + ref_ld_baseline_column_sum
85
+ initial_w = get_weight_optimized(sumstats, x_tot_precomputed=spot_x_tot_precomputed,
86
+ M_tot=10000, w_ld=w_ld_common_snp, intercept=1).astype(np.float32).reshape((-1, 1))
87
+ initial_w_scaled = initial_w / np.sum(initial_w)
88
+ baseline_annotation_spot = baseline_annotation * initial_w_scaled
89
+ spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
90
+ CHISQ = sumstats.chisq.values.reshape((-1, 1))
91
+ y = CHISQ * initial_w_scaled
92
+ x_focal = np.concatenate((spatial_annotation_spot, baseline_annotation_spot), axis=1)
93
+ try:
94
+ jknife = jk.LstsqJackknifeFast(x_focal, y, n_blocks)
95
+ except np.linalg.LinAlgError as e:
96
+ logger.warning(f'LinAlgError: {e}')
97
+ return np.nan, np.nan
98
+ return _coef_new(jknife, Nbar)
99
+
100
+
101
+ def _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
102
+ """Preprocess summary statistics."""
103
+ sumstats = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
104
+ sumstats.set_index('SNP', inplace=True)
105
+ sumstats = sumstats.astype(np.float32)
106
+ sumstats = filter_sumstats_by_chisq(sumstats, chisq_max)
107
+ common_snp = baseline_and_w_ld_common_snp.intersection(sumstats.index)
108
+ if len(common_snp) < 200000:
109
+ logger.warning(f'WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad.')
110
+ sumstats = sumstats.loc[common_snp]
111
+ sumstats['common_index_pos'] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(sumstats.index)
112
+ return sumstats
113
+
114
+
115
+ def _get_sumstats_with_common_snp_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
116
+ """Get summary statistics with common SNPs among all traits."""
117
+ logger.info('Validating sumstats files...')
118
+ for trait_name, sumstat_file_path in sumstats_config_dict.items():
119
+ if not os.path.exists(sumstat_file_path):
120
+ raise FileNotFoundError(f'{sumstat_file_path} not found')
121
+ sumstats_cleaned_dict = {}
122
+ for trait_name, sumstat_file_path in sumstats_config_dict.items():
123
+ sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp, chisq_max)
124
+ common_snp_among_all_sumstats = None
125
+ for trait_name, sumstats in sumstats_cleaned_dict.items():
126
+ if common_snp_among_all_sumstats is None:
127
+ common_snp_among_all_sumstats = sumstats.index
128
+ else:
129
+ common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(sumstats.index)
130
+ for trait_name, sumstats in sumstats_cleaned_dict.items():
131
+ sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]
132
+ logger.info(f'Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}')
133
+ return sumstats_cleaned_dict, common_snp_among_all_sumstats
134
+
135
+
136
+ class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
137
+ """Class to handle pre-calculated SNP-Gene weight matrix for quick mode."""
138
+ def __init__(self, config: SpatialLDSCConfig, common_snp_among_all_sumstats_pos):
139
+ self.config = config
140
+ mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM')
141
+ mk_score_genes = mk_score.index
142
+ snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
143
+ common_genes = mk_score_genes.intersection(snp_gene_weight_adata.var.index)
144
+ common_snps = snp_gene_weight_adata.obs.index
145
+ self.snp_gene_weight_matrix = snp_gene_weight_adata[common_snp_among_all_sumstats_pos, common_genes.to_list()].X
146
+ self.mk_score_common = mk_score.loc[common_genes]
147
+ self.chunk_starts = list(range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode))
148
+
149
+ def fetch_ldscore_by_chunk(self, chunk_index):
150
+ """Fetch LD score by chunk."""
151
+ chunk_start = self.chunk_starts[chunk_index]
152
+ mk_score_chunk = self.mk_score_common.iloc[:, chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
153
+ ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(mk_score_chunk, drop_dummy_na=False)
154
+ spots_name = self.mk_score_common.columns[chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
155
+ return ldscore_chunk, spots_name
156
+
157
+ def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(self, mk_score_chunk, drop_dummy_na=True):
158
+ """Calculate LD score using SNP-Gene weight matrix by chunk."""
159
+ if drop_dummy_na:
160
+ ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
161
+ else:
162
+ ldscore_chr_chunk = self.snp_gene_weight_matrix @ mk_score_chunk
163
+ return ldscore_chr_chunk
164
+
165
+
166
+ def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config):
167
+ """Load LD score chunk from feather format."""
168
+ sample_name = config.sample_name
169
+ ld_file_spatial = f'{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.'
170
+ ref_ld_spatial = _read_ref_ld_v2(ld_file_spatial)
171
+ ref_ld_spatial = ref_ld_spatial.iloc[common_snp_among_all_sumstats_pos]
172
+ ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
173
+ spatial_annotation_cnames = ref_ld_spatial.columns
174
+ return ref_ld_spatial.values, spatial_annotation_cnames
175
+
176
+
177
+ def run_spatial_ldsc(config: SpatialLDSCConfig):
178
+ """Run spatial LDSC analysis."""
179
+ logger.info(f'------Running Spatial LDSC for {config.sample_name}...')
180
+ n_blocks = config.n_blocks
181
+ sample_name = config.sample_name
182
+
183
+ # Load regression weights
184
+ w_ld = _read_w_ld(config.w_file)
185
+ w_ld_cname = w_ld.columns[1]
186
+ w_ld.set_index('SNP', inplace=True)
187
+
188
+ ld_file_baseline = f'{config.ldscore_save_dir}/baseline/baseline.'
189
+ ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
190
+ baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)
191
+ baseline_and_w_ld_common_snp_pos = pd.Index(ref_ld_baseline.index).get_indexer(baseline_and_w_ld_common_snp)
192
+
193
+ sumstats_cleaned_dict, common_snp_among_all_sumstats = _get_sumstats_with_common_snp_from_sumstats_dict(
194
+ config.sumstats_config_dict, baseline_and_w_ld_common_snp, chisq_max=config.chisq_max)
195
+ common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(common_snp_among_all_sumstats)
196
+
197
+ if not pd.Series(common_snp_among_all_sumstats_pos).is_monotonic_increasing:
198
+ raise ValueError('common_snp_among_all_sumstats_pos is not monotonic increasing')
199
+
200
+ if len(common_snp_among_all_sumstats) < 200000:
201
+ logger.warning(
202
+ f'!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files.')
203
+
204
+ ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
205
+ w_ld = w_ld.loc[common_snp_among_all_sumstats]
206
+
207
+ # Load additional baseline annotations if needed
208
+ if config.use_additional_baseline_annotation:
209
+ logger.info('Using additional baseline annotations')
210
+ ld_file_baseline_additional = f'{config.ldscore_save_dir}/additional_baseline/baseline.'
211
+ ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
212
+ ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
213
+ ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
214
+ del ref_ld_baseline_additional
215
+
216
+ # Initialize s_ldsc once if quick_mode
217
+ s_ldsc = None
218
+ if config.ldscore_save_format == 'quick_mode':
219
+ s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, common_snp_among_all_sumstats_pos)
220
+ total_chunk_number_found = len(s_ldsc.chunk_starts)
221
+ logger.info(f'Split data into {total_chunk_number_found} chunks')
222
+ else:
223
+ total_chunk_number_found = determine_total_chunks(config)
224
+
225
+ start_chunk, end_chunk = determine_chunk_range(config, total_chunk_number_found)
226
+ running_chunk_number = end_chunk - start_chunk + 1
227
+
228
+ # Load zarr file if needed
229
+ zarr_file, spots_name = None, None
230
+ if config.ldscore_save_format == 'zarr':
231
+ zarr_path = Path(config.ldscore_save_dir) / f'{config.sample_name}.ldscore.zarr'
232
+ if not zarr_path.exists():
233
+ raise FileNotFoundError(f'{zarr_path} not found, which is required for zarr format')
234
+ zarr_file = zarr.open(str(zarr_path))
235
+ spots_name = zarr_file.attrs['spot_names']
236
+
237
+ output_dict = defaultdict(list)
238
+ for chunk_index in range(start_chunk, end_chunk + 1):
239
+ ref_ld_spatial, spatial_annotation_cnames = load_ldscore_chunk(
240
+ chunk_index,
241
+ common_snp_among_all_sumstats_pos,
242
+ config,
243
+ zarr_file,
244
+ spots_name,
245
+ s_ldsc # Pass s_ldsc to the function
246
+ )
247
+ ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
248
+
249
+ for trait_name, sumstats in sumstats_cleaned_dict.items():
250
+ spatial_annotation = ref_ld_spatial.astype(np.float32, copy=False)
251
+ baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)
252
+ w_ld_common_snp = w_ld.astype(np.float32, copy=False)
253
+
254
+ baseline_annotation = baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
255
+ baseline_annotation = append_intercept(baseline_annotation)
256
+
257
+ Nbar = sumstats.N.mean()
258
+ chunk_size = spatial_annotation.shape[1]
259
+
260
+ jackknife_func = partial(
261
+ jackknife_for_processmap,
262
+ spatial_annotation=spatial_annotation,
263
+ ref_ld_baseline_column_sum=ref_ld_baseline_column_sum,
264
+ sumstats=sumstats,
265
+ baseline_annotation=baseline_annotation,
266
+ w_ld_common_snp=w_ld_common_snp,
267
+ Nbar=Nbar,
268
+ n_blocks=n_blocks
269
+ )
270
+
271
+ out_chunk = thread_map(
272
+ jackknife_func,
273
+ range(chunk_size),
274
+ max_workers=config.num_processes,
275
+ chunksize=10,
276
+ desc=f'Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}',
277
+ )
278
+
279
+ out_chunk = pd.DataFrame.from_records(out_chunk, columns=['beta', 'se'], index=spatial_annotation_cnames)
280
+ nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
281
+ if len(nan_spots) > 0:
282
+ logger.info(f'Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed.')
283
+ out_chunk = out_chunk.dropna()
284
+ out_chunk['z'] = out_chunk.beta / out_chunk.se
285
+ out_chunk['p'] = norm.sf(out_chunk['z'])
286
+ output_dict[trait_name].append(out_chunk)
287
+
288
+ del spatial_annotation, baseline_annotation, w_ld_common_snp
289
+ gc.collect()
290
+
291
+ save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk)
292
+ logger.info(f'------Spatial LDSC for {sample_name} finished!')
293
+
294
+
295
+ def determine_total_chunks(config):
296
+ """Determine total number of chunks based on the ldscore save format."""
297
+ if config.ldscore_save_format == 'quick_mode':
298
+ s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, [])
299
+ total_chunk_number_found = len(s_ldsc.chunk_starts)
300
+ logger.info(f'Split data into {total_chunk_number_found} chunks')
301
+ else:
302
+ all_file = os.listdir(config.ldscore_save_dir)
303
+ total_chunk_number_found = sum('chunk' in name for name in all_file)
304
+ logger.info(f'Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}')
305
+ return total_chunk_number_found
306
+
307
+
308
+ def determine_chunk_range(config, total_chunk_number_found):
309
+ """Determine the range of chunks to process."""
310
+ if config.all_chunk is None:
311
+ if config.chunk_range is not None:
312
+ if not (1 <= config.chunk_range[0] <= total_chunk_number_found) or not (1 <= config.chunk_range[1] <= total_chunk_number_found):
313
+ raise ValueError('Chunk range out of bound. It should be in [1, all_chunk]')
314
+ start_chunk, end_chunk = config.chunk_range
315
+ logger.info(f'Chunk range provided, using chunked files from {start_chunk} to {end_chunk}')
316
+ else:
317
+ start_chunk, end_chunk = 1, total_chunk_number_found
318
+ else:
319
+ all_chunk = config.all_chunk
320
+ logger.info(f'Using {all_chunk} chunked files by provided argument')
321
+ start_chunk, end_chunk = 1, all_chunk
322
+ return start_chunk, end_chunk
323
+
324
+
325
+ def load_ldscore_chunk(chunk_index, common_snp_among_all_sumstats_pos, config, zarr_file=None, spots_name=None, s_ldsc=None):
326
+ """Load LD score chunk based on save format."""
327
+ if config.ldscore_save_format == 'feather':
328
+ return load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config)
329
+ elif config.ldscore_save_format == 'zarr':
330
+ ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
331
+ start_spot = (chunk_index - 1) * zarr_file.chunks[1]
332
+ ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
333
+ spatial_annotation_cnames = spots_name[start_spot:start_spot + zarr_file.chunks[1]]
334
+ return ref_ld_spatial, spatial_annotation_cnames
335
+ elif config.ldscore_save_format == 'quick_mode':
336
+ # Use the pre-initialized s_ldsc
337
+ if s_ldsc is None:
338
+ raise ValueError("s_ldsc must be provided in quick_mode")
339
+ return s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)
340
+ else:
341
+ raise ValueError(f'Invalid ld score save format: {config.ldscore_save_format}')
342
+
343
+
344
+ def save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk):
345
+ """Save the results to the specified directory."""
346
+ out_dir = config.ldsc_save_dir
347
+ for trait_name, out_chunk_list in output_dict.items():
348
+ out_all = pd.concat(out_chunk_list, axis=0)
349
+ sample_name = config.sample_name
350
+ if running_chunk_number == end_chunk - start_chunk + 1:
351
+ out_file_name = out_dir / f'{sample_name}_{trait_name}.csv.gz'
352
+ else:
353
+ out_file_name = out_dir / f'{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz'
354
+ out_all['spot'] = out_all.index
355
+ out_all = out_all[['spot', 'beta', 'se', 'z', 'p']]
356
+
357
+ # clip the p-values
358
+ out_all['p'] = out_all['p'].clip(1e-300, 1)
359
+ out_all.to_csv(out_file_name, compression='gzip', index=False)
360
+ logger.info(f'Output saved to {out_file_name} for {trait_name}')
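
The main structural change in 1.71.2 is the removal of the module-level globals (`spatial_annotation`, `baseline_annotation`, `Nbar`, `n_blocks`, etc.) that `jackknife_for_processmap` previously read implicitly: the new code binds the shared arrays explicitly with `functools.partial` before handing the worker to `thread_map`. Below is a minimal, self-contained sketch of that pattern using a simplified toy worker and hypothetical random inputs, not the actual gsMap LD-score objects or jackknife estimator.

```python
from functools import partial

import numpy as np
from tqdm.contrib.concurrent import thread_map


def per_spot_regression(spot_id, spatial_annotation, baseline_annotation, chisq):
    """Toy stand-in for jackknife_for_processmap: regress chi^2 values on one
    spatial-annotation column plus the shared baseline annotations."""
    x = np.concatenate(
        (spatial_annotation[:, spot_id].reshape(-1, 1), baseline_annotation), axis=1
    )
    beta, *_ = np.linalg.lstsq(x, chisq, rcond=None)
    # Return the coefficient of the spot-specific column and the first baseline column.
    return float(beta[0]), float(beta[1])


# Hypothetical toy inputs standing in for the SNP x spot / SNP x annotation matrices.
rng = np.random.default_rng(0)
spatial = rng.random((1000, 50), dtype=np.float32)    # SNPs x spots
baseline = rng.random((1000, 3)).astype(np.float32)   # SNPs x baseline annotations
chisq = rng.chisquare(1, size=(1000, 1)).astype(np.float32)

# Bind the shared, read-only arrays once; only spot_id varies per task.
worker = partial(
    per_spot_regression,
    spatial_annotation=spatial,
    baseline_annotation=baseline,
    chisq=chisq,
)

results = thread_map(worker, range(spatial.shape[1]), max_workers=4, chunksize=10)
```

Binding the arrays through `partial` keeps each worker call dependent only on its `spot_id`, so the threaded map no longer relies on the implicit module-level state that the 1.71.1 version mutated inside `run_spatial_ldsc`.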