gsMap-1.71.2-py3-none-any.whl → gsMap-1.72.3-py3-none-any.whl
This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- gsMap/GNN/adjacency_matrix.py +25 -27
- gsMap/GNN/model.py +9 -7
- gsMap/GNN/train.py +8 -11
- gsMap/__init__.py +3 -3
- gsMap/__main__.py +3 -2
- gsMap/cauchy_combination_test.py +75 -72
- gsMap/config.py +822 -316
- gsMap/create_slice_mean.py +154 -0
- gsMap/diagnosis.py +179 -101
- gsMap/find_latent_representation.py +28 -26
- gsMap/format_sumstats.py +233 -201
- gsMap/generate_ldscore.py +353 -209
- gsMap/latent_to_gene.py +92 -60
- gsMap/main.py +23 -14
- gsMap/report.py +39 -25
- gsMap/run_all_mode.py +86 -46
- gsMap/setup.py +1 -1
- gsMap/spatial_ldsc_multiple_sumstats.py +154 -80
- gsMap/utils/generate_r2_matrix.py +173 -140
- gsMap/utils/jackknife.py +84 -80
- gsMap/utils/manhattan_plot.py +180 -207
- gsMap/utils/regression_read.py +105 -122
- gsMap/visualize.py +82 -64
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/METADATA +21 -6
- gsmap-1.72.3.dist-info/RECORD +31 -0
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/WHEEL +1 -1
- gsMap/utils/make_annotations.py +0 -518
- gsmap-1.71.2.dist-info/RECORD +0 -31
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/LICENSE +0 -0
- {gsmap-1.71.2.dist-info → gsmap-1.72.3.dist-info}/entry_points.txt +0 -0
--- a/gsMap/spatial_ldsc_multiple_sumstats.py
+++ b/gsMap/spatial_ldsc_multiple_sumstats.py
@@ -14,9 +14,9 @@ from tqdm.contrib.concurrent import thread_map
 
 import gsMap.utils.jackknife as jk
 from gsMap.config import SpatialLDSCConfig
-from gsMap.utils.regression_read import _read_sumstats, _read_w_ld
+from gsMap.utils.regression_read import _read_ref_ld_v2, _read_sumstats, _read_w_ld
 
-logger = logging.getLogger(
+logger = logging.getLogger("gsMap.spatial_ldsc")
 
 
 def _coef_new(jknife, Nbar):
@@ -39,14 +39,16 @@ def filter_sumstats_by_chisq(sumstats, chisq_max):
     before_len = len(sumstats)
     if chisq_max is None:
         chisq_max = max(0.001 * sumstats.N.max(), 80)
-        logger.info(f
-    sumstats[
+        logger.info(f"No chi^2 threshold provided, using {chisq_max} as default")
+    sumstats["chisq"] = sumstats.Z**2
     sumstats = sumstats[sumstats.chisq < chisq_max]
     after_len = len(sumstats)
     if after_len < before_len:
-        logger.info(
+        logger.info(
+            f"Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)"
+        )
     else:
-        logger.info(f
+        logger.info(f"No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)")
     return sumstats
 
 
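The hunk above restores the full chi² filter messages. For reference, a minimal, self-contained sketch of the same filtering rule; the default threshold `max(0.001 * N.max(), 80)` comes straight from the hunk, while the function name and toy `Z`/`N` values are illustrative:

```python
import pandas as pd

def filter_by_chisq(sumstats: pd.DataFrame, chisq_max=None) -> pd.DataFrame:
    # Default threshold mirrors the hunk: max(0.001 * max(N), 80).
    if chisq_max is None:
        chisq_max = max(0.001 * sumstats.N.max(), 80)
    sumstats = sumstats.copy()
    sumstats["chisq"] = sumstats.Z**2
    return sumstats[sumstats.chisq < chisq_max]

# Toy table: the Z = 10 row has chisq = 100 > 80 and is dropped.
toy = pd.DataFrame({"Z": [1.0, -2.0, 10.0], "N": [50_000, 50_000, 50_000]})
print(filter_by_chisq(toy))
```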
@@ -73,17 +75,37 @@ def weights(ld, w_ld, N, M, hsq, intercept=1):
 def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
     """Optimized function to calculate initial weights."""
     tot_agg = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
-    initial_w = weights(
+    initial_w = weights(
+        x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept
+    )
     initial_w = np.sqrt(initial_w)
     return initial_w
 
 
-def jackknife_for_processmap(
+def jackknife_for_processmap(
+    spot_id,
+    spatial_annotation,
+    ref_ld_baseline_column_sum,
+    sumstats,
+    baseline_annotation,
+    w_ld_common_snp,
+    Nbar,
+    n_blocks,
+):
     """Perform jackknife resampling for a given spot."""
     spot_spatial_annotation = spatial_annotation[:, spot_id]
     spot_x_tot_precomputed = spot_spatial_annotation + ref_ld_baseline_column_sum
-    initial_w =
-
+    initial_w = (
+        get_weight_optimized(
+            sumstats,
+            x_tot_precomputed=spot_x_tot_precomputed,
+            M_tot=10000,
+            w_ld=w_ld_common_snp,
+            intercept=1,
+        )
+        .astype(np.float32)
+        .reshape((-1, 1))
+    )
     initial_w_scaled = initial_w / np.sum(initial_w)
     baseline_annotation_spot = baseline_annotation * initial_w_scaled
     spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
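The reformatted `jackknife_for_processmap` makes the weighting pipeline easier to follow: per-SNP weights are computed, cast to `float32`, reshaped to a column vector, normalized to sum to one, and broadcast across the annotation columns. A sketch of that broadcasting step on synthetic arrays (all shapes and values are illustrative, not taken from the package):

```python
import numpy as np

rng = np.random.default_rng(0)
n_snps, n_baseline = 1_000, 5

# Stand-in for the output of get_weight_optimized(...): one weight per SNP.
initial_w = rng.uniform(0.5, 2.0, size=n_snps).astype(np.float32).reshape((-1, 1))
initial_w_scaled = initial_w / np.sum(initial_w)  # weights now sum to 1

baseline_annotation = rng.normal(size=(n_snps, n_baseline)).astype(np.float32)
spot_spatial_annotation = rng.normal(size=n_snps).astype(np.float32)

# The (n_snps, 1) weight column broadcasts over every annotation column.
baseline_annotation_spot = baseline_annotation * initial_w_scaled
spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
print(baseline_annotation_spot.shape, spatial_annotation_spot.shape)  # (1000, 5) (1000, 1)
```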
@@ -93,68 +115,93 @@ def jackknife_for_processmap(spot_id, spatial_annotation, ref_ld_baseline_column
     try:
         jknife = jk.LstsqJackknifeFast(x_focal, y, n_blocks)
     except np.linalg.LinAlgError as e:
-        logger.warning(f
+        logger.warning(f"LinAlgError: {e}")
         return np.nan, np.nan
     return _coef_new(jknife, Nbar)
 
 
-def _preprocess_sumstats(
+def _preprocess_sumstats(
+    trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None
+):
     """Preprocess summary statistics."""
     sumstats = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
-    sumstats.set_index(
+    sumstats.set_index("SNP", inplace=True)
     sumstats = sumstats.astype(np.float32)
     sumstats = filter_sumstats_by_chisq(sumstats, chisq_max)
     common_snp = baseline_and_w_ld_common_snp.intersection(sumstats.index)
     if len(common_snp) < 200000:
-        logger.warning(
+        logger.warning(
+            f"WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad."
+        )
     sumstats = sumstats.loc[common_snp]
-    sumstats[
+    sumstats["common_index_pos"] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(
+        sumstats.index
+    )
     return sumstats
 
 
-def _get_sumstats_with_common_snp_from_sumstats_dict(
+def _get_sumstats_with_common_snp_from_sumstats_dict(
+    sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None
+):
     """Get summary statistics with common SNPs among all traits."""
-    logger.info(
+    logger.info("Validating sumstats files...")
     for trait_name, sumstat_file_path in sumstats_config_dict.items():
         if not os.path.exists(sumstat_file_path):
-            raise FileNotFoundError(f
+            raise FileNotFoundError(f"{sumstat_file_path} not found")
     sumstats_cleaned_dict = {}
     for trait_name, sumstat_file_path in sumstats_config_dict.items():
-        sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(
+        sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(
+            trait_name, sumstat_file_path, baseline_and_w_ld_common_snp, chisq_max
+        )
     common_snp_among_all_sumstats = None
     for trait_name, sumstats in sumstats_cleaned_dict.items():
         if common_snp_among_all_sumstats is None:
             common_snp_among_all_sumstats = sumstats.index
         else:
-            common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(
+            common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(
+                sumstats.index
+            )
     for trait_name, sumstats in sumstats_cleaned_dict.items():
         sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]
-    logger.info(f
+    logger.info(f"Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}")
     return sumstats_cleaned_dict, common_snp_among_all_sumstats
 
 
 class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
     """Class to handle pre-calculated SNP-Gene weight matrix for quick mode."""
+
     def __init__(self, config: SpatialLDSCConfig, common_snp_among_all_sumstats_pos):
         self.config = config
-        mk_score = pd.read_feather(config.mkscore_feather_path).set_index(
+        mk_score = pd.read_feather(config.mkscore_feather_path).set_index("HUMAN_GENE_SYM")
         mk_score_genes = mk_score.index
         snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
         common_genes = mk_score_genes.intersection(snp_gene_weight_adata.var.index)
-        common_snps = snp_gene_weight_adata.obs.index
-        self.snp_gene_weight_matrix = snp_gene_weight_adata[
+        # common_snps = snp_gene_weight_adata.obs.index
+        self.snp_gene_weight_matrix = snp_gene_weight_adata[
+            common_snp_among_all_sumstats_pos, common_genes.to_list()
+        ].X
         self.mk_score_common = mk_score.loc[common_genes]
-        self.chunk_starts = list(
+        self.chunk_starts = list(
+            range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode)
+        )
 
     def fetch_ldscore_by_chunk(self, chunk_index):
         """Fetch LD score by chunk."""
         chunk_start = self.chunk_starts[chunk_index]
-        mk_score_chunk = self.mk_score_common.iloc[
-
-
+        mk_score_chunk = self.mk_score_common.iloc[
+            :, chunk_start : chunk_start + self.config.spots_per_chunk_quick_mode
+        ]
+        ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
+            mk_score_chunk, drop_dummy_na=False
+        )
+        spots_name = self.mk_score_common.columns[
+            chunk_start : chunk_start + self.config.spots_per_chunk_quick_mode
+        ]
         return ldscore_chunk, spots_name
 
-    def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
+    def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
+        self, mk_score_chunk, drop_dummy_na=True
+    ):
         """Calculate LD score using SNP-Gene weight matrix by chunk."""
         if drop_dummy_na:
             ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
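Two pandas idioms carry this hunk: `Index.get_indexer`, which turns SNP labels into integer positions within the baseline panel (the new `common_index_pos` column), and `Index.intersection`, chained across traits to keep only shared SNPs. A small sketch with hypothetical rsIDs:

```python
import pandas as pd

baseline_snps = pd.Index(["rs1", "rs2", "rs3", "rs4", "rs5"])  # hypothetical panel
trait_snps = pd.Index(["rs2", "rs4", "rs5"])

# Position of each trait SNP inside the baseline index; -1 would mean "absent".
print(baseline_snps.get_indexer(trait_snps))  # [1 3 4]

# Intersecting across traits keeps only SNPs present in every sumstats file.
trait_a = pd.Index(["rs1", "rs2", "rs4"])
trait_b = pd.Index(["rs2", "rs3", "rs4"])
print(trait_a.intersection(trait_b))  # Index(['rs2', 'rs4'], dtype='object')
```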
@@ -166,7 +213,7 @@ class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
 def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config):
     """Load LD score chunk from feather format."""
     sample_name = config.sample_name
-    ld_file_spatial = f
+    ld_file_spatial = f"{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}."
     ref_ld_spatial = _read_ref_ld_v2(ld_file_spatial)
     ref_ld_spatial = ref_ld_spatial.iloc[common_snp_among_all_sumstats_pos]
     ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
@@ -176,38 +223,42 @@ def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_p
 
 def run_spatial_ldsc(config: SpatialLDSCConfig):
     """Run spatial LDSC analysis."""
-    logger.info(f
+    logger.info(f"------Running Spatial LDSC for {config.sample_name}...")
     n_blocks = config.n_blocks
     sample_name = config.sample_name
 
     # Load regression weights
     w_ld = _read_w_ld(config.w_file)
-
-    w_ld.set_index('SNP', inplace=True)
+    w_ld.set_index("SNP", inplace=True)
 
-    ld_file_baseline = f
+    ld_file_baseline = f"{config.ldscore_save_dir}/baseline/baseline."
     ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
     baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)
-    baseline_and_w_ld_common_snp_pos = pd.Index(ref_ld_baseline.index).get_indexer(baseline_and_w_ld_common_snp)
 
-    sumstats_cleaned_dict, common_snp_among_all_sumstats =
-
-
+    sumstats_cleaned_dict, common_snp_among_all_sumstats = (
+        _get_sumstats_with_common_snp_from_sumstats_dict(
+            config.sumstats_config_dict, baseline_and_w_ld_common_snp, chisq_max=config.chisq_max
+        )
+    )
+    common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(
+        common_snp_among_all_sumstats
+    )
 
     if not pd.Series(common_snp_among_all_sumstats_pos).is_monotonic_increasing:
-        raise ValueError(
+        raise ValueError("common_snp_among_all_sumstats_pos is not monotonic increasing")
 
     if len(common_snp_among_all_sumstats) < 200000:
         logger.warning(
-            f
+            f"!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files."
+        )
 
     ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
     w_ld = w_ld.loc[common_snp_among_all_sumstats]
 
     # Load additional baseline annotations if needed
     if config.use_additional_baseline_annotation:
-        logger.info(
-        ld_file_baseline_additional = f
+        logger.info("Using additional baseline annotations")
+        ld_file_baseline_additional = f"{config.ldscore_save_dir}/additional_baseline/baseline."
         ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
         ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
         ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
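The rewritten block derives `common_snp_among_all_sumstats_pos` with `get_indexer` and then requires it to be monotonically increasing, since later positional slicing assumes the baseline row order is preserved. A sketch of why intersection-derived positions pass the guard while an arbitrarily ordered SNP list would not (rsIDs are made up):

```python
import pandas as pd

baseline = pd.Index(["rs1", "rs2", "rs3", "rs4", "rs5"])

# intersection() keeps the baseline's order, so positions come out increasing.
pos_ok = baseline.get_indexer(baseline.intersection(pd.Index(["rs5", "rs2"])))
print(pos_ok, pd.Series(pos_ok).is_monotonic_increasing)  # [1 4] True

# An arbitrary ordering would trip the ValueError guard in the hunk above.
pos_bad = baseline.get_indexer(pd.Index(["rs5", "rs2"]))
print(pos_bad, pd.Series(pos_bad).is_monotonic_increasing)  # [4 1] False
```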
@@ -215,10 +266,12 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
 
     # Initialize s_ldsc once if quick_mode
     s_ldsc = None
-    if config.ldscore_save_format ==
-        s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(
+    if config.ldscore_save_format == "quick_mode":
+        s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(
+            config, common_snp_among_all_sumstats_pos
+        )
         total_chunk_number_found = len(s_ldsc.chunk_starts)
-        logger.info(f
+        logger.info(f"Split data into {total_chunk_number_found} chunks")
     else:
         total_chunk_number_found = determine_total_chunks(config)
 
@@ -227,12 +280,12 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
 
     # Load zarr file if needed
     zarr_file, spots_name = None, None
-    if config.ldscore_save_format ==
-        zarr_path = Path(config.ldscore_save_dir) / f
+    if config.ldscore_save_format == "zarr":
+        zarr_path = Path(config.ldscore_save_dir) / f"{config.sample_name}.ldscore.zarr"
         if not zarr_path.exists():
-            raise FileNotFoundError(f
+            raise FileNotFoundError(f"{zarr_path} not found, which is required for zarr format")
         zarr_file = zarr.open(str(zarr_path))
-        spots_name = zarr_file.attrs[
+        spots_name = zarr_file.attrs["spot_names"]
 
     output_dict = defaultdict(list)
     for chunk_index in range(start_chunk, end_chunk + 1):
@@ -242,7 +295,7 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
             config,
             zarr_file,
             spots_name,
-            s_ldsc  # Pass s_ldsc to the function
+            s_ldsc,  # Pass s_ldsc to the function
         )
         ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
 
@@ -251,7 +304,9 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
             baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)
             w_ld_common_snp = w_ld.astype(np.float32, copy=False)
 
-            baseline_annotation =
+            baseline_annotation = (
+                baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
+            )
             baseline_annotation = append_intercept(baseline_annotation)
 
             Nbar = sumstats.N.mean()
@@ -265,7 +320,7 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
                 baseline_annotation=baseline_annotation,
                 w_ld_common_snp=w_ld_common_snp,
                 Nbar=Nbar,
-                n_blocks=n_blocks
+                n_blocks=n_blocks,
             )
 
             out_chunk = thread_map(
@@ -273,35 +328,39 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
                 range(chunk_size),
                 max_workers=config.num_processes,
                 chunksize=10,
-                desc=f
+                desc=f"Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}",
             )
 
-            out_chunk = pd.DataFrame.from_records(
+            out_chunk = pd.DataFrame.from_records(
+                out_chunk, columns=["beta", "se"], index=spatial_annotation_cnames
+            )
             nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
             if len(nan_spots) > 0:
-                logger.info(
+                logger.info(
+                    f"Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed."
+                )
             out_chunk = out_chunk.dropna()
-            out_chunk[
-            out_chunk[
+            out_chunk["z"] = out_chunk.beta / out_chunk.se
+            out_chunk["p"] = norm.sf(out_chunk["z"])
             output_dict[trait_name].append(out_chunk)
 
         del spatial_annotation, baseline_annotation, w_ld_common_snp
         gc.collect()
 
     save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk)
-    logger.info(f
+    logger.info(f"------Spatial LDSC for {sample_name} finished!")
 
 
 def determine_total_chunks(config):
     """Determine total number of chunks based on the ldscore save format."""
-    if config.ldscore_save_format ==
+    if config.ldscore_save_format == "quick_mode":
        s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, [])
        total_chunk_number_found = len(s_ldsc.chunk_starts)
-        logger.info(f
+        logger.info(f"Split data into {total_chunk_number_found} chunks")
     else:
        all_file = os.listdir(config.ldscore_save_dir)
-        total_chunk_number_found = sum(
-        logger.info(f
+        total_chunk_number_found = sum("chunk" in name for name in all_file)
+        logger.info(f"Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}")
     return total_chunk_number_found
 
 
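The restored statistics lines compute a per-spot z-score (`beta / se`) and a one-sided p-value with the normal survival function (`norm.sf`, i.e. the upper-tail probability 1 − Φ(z)). A self-contained sketch with made-up coefficients:

```python
import pandas as pd
from scipy.stats import norm

out_chunk = pd.DataFrame({"beta": [0.020, -0.010], "se": [0.005, 0.008]})
out_chunk["z"] = out_chunk.beta / out_chunk.se
# Survival function = upper-tail probability, so only positive z yields small p.
out_chunk["p"] = norm.sf(out_chunk["z"])
print(out_chunk)  # z = 4.0 -> p ~ 3.2e-05; z = -1.25 -> p ~ 0.89
```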
@@ -309,36 +368,49 @@ def determine_chunk_range(config, total_chunk_number_found):
     """Determine the range of chunks to process."""
     if config.all_chunk is None:
         if config.chunk_range is not None:
-            if not (1 <= config.chunk_range[0] <= total_chunk_number_found) or not (
-
+            if not (1 <= config.chunk_range[0] <= total_chunk_number_found) or not (
+                1 <= config.chunk_range[1] <= total_chunk_number_found
+            ):
+                raise ValueError("Chunk range out of bound. It should be in [1, all_chunk]")
             start_chunk, end_chunk = config.chunk_range
-            logger.info(
+            logger.info(
+                f"Chunk range provided, using chunked files from {start_chunk} to {end_chunk}"
+            )
         else:
             start_chunk, end_chunk = 1, total_chunk_number_found
     else:
         all_chunk = config.all_chunk
-        logger.info(f
+        logger.info(f"Using {all_chunk} chunked files by provided argument")
         start_chunk, end_chunk = 1, all_chunk
     return start_chunk, end_chunk
 
 
-def load_ldscore_chunk(
+def load_ldscore_chunk(
+    chunk_index,
+    common_snp_among_all_sumstats_pos,
+    config,
+    zarr_file=None,
+    spots_name=None,
+    s_ldsc=None,
+):
     """Load LD score chunk based on save format."""
-    if config.ldscore_save_format ==
-        return load_ldscore_chunk_from_feather(
-
+    if config.ldscore_save_format == "feather":
+        return load_ldscore_chunk_from_feather(
+            chunk_index, common_snp_among_all_sumstats_pos, config
+        )
+    elif config.ldscore_save_format == "zarr":
         ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
         start_spot = (chunk_index - 1) * zarr_file.chunks[1]
         ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
-        spatial_annotation_cnames = spots_name[start_spot:start_spot + zarr_file.chunks[1]]
+        spatial_annotation_cnames = spots_name[start_spot : start_spot + zarr_file.chunks[1]]
         return ref_ld_spatial, spatial_annotation_cnames
-    elif config.ldscore_save_format ==
+    elif config.ldscore_save_format == "quick_mode":
         # Use the pre-initialized s_ldsc
         if s_ldsc is None:
             raise ValueError("s_ldsc must be provided in quick_mode")
         return s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)
     else:
-        raise ValueError(f
+        raise ValueError(f"Invalid ld score save format: {config.ldscore_save_format}")
 
 
 def save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk):
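In the zarr branch, `zarr_file.blocks[:, chunk_index - 1]` fetches one column-block of the SNP × spot LD-score matrix, which is then row-indexed with the common-SNP positions. A runnable sketch on a tiny synthetic array (the 10 × 12 shape and chunk width of 4 are illustrative; written against the zarr-python 2.x `blocks` accessor that the diff itself uses):

```python
import numpy as np
import zarr

# Synthetic "SNPs x spots" matrix, chunked into three column-blocks of width 4.
z = zarr.zeros((10, 12), chunks=(10, 4), dtype="f4")
z[:] = np.arange(120, dtype="f4").reshape(10, 12)

chunk_index = 2  # chunk indices are 1-based in the hunk above
block = z.blocks[:, chunk_index - 1]  # second column-block, shape (10, 4)

common_snp_pos = np.array([0, 3, 7])  # hypothetical positional indices
print(block[common_snp_pos].shape)    # (3, 4), as in the diff's fancy indexing
```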
@@ -348,13 +420,15 @@ def save_results(output_dict, config, running_chunk_number, start_chunk, end_chu
         out_all = pd.concat(out_chunk_list, axis=0)
         sample_name = config.sample_name
         if running_chunk_number == end_chunk - start_chunk + 1:
-            out_file_name = out_dir / f
+            out_file_name = out_dir / f"{sample_name}_{trait_name}.csv.gz"
         else:
-            out_file_name =
-
-
+            out_file_name = (
+                out_dir / f"{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz"
+            )
+        out_all["spot"] = out_all.index
+        out_all = out_all[["spot", "beta", "se", "z", "p"]]
 
         # clip the p-values
-        out_all[
-        out_all.to_csv(out_file_name, compression=
-        logger.info(f
+        out_all["p"] = out_all["p"].clip(1e-300, 1)
+        out_all.to_csv(out_file_name, compression="gzip", index=False)
+        logger.info(f"Output saved to {out_file_name} for {trait_name}")