gsMap 1.71.2__py3-none-any.whl → 1.72.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,9 +14,9 @@ from tqdm.contrib.concurrent import thread_map
 
 import gsMap.utils.jackknife as jk
 from gsMap.config import SpatialLDSCConfig
-from gsMap.utils.regression_read import _read_sumstats, _read_w_ld, _read_ref_ld_v2
+from gsMap.utils.regression_read import _read_ref_ld_v2, _read_sumstats, _read_w_ld
 
-logger = logging.getLogger('gsMap.spatial_ldsc')
+logger = logging.getLogger("gsMap.spatial_ldsc")
 
 
 def _coef_new(jknife, Nbar):
@@ -39,14 +39,16 @@ def filter_sumstats_by_chisq(sumstats, chisq_max):
     before_len = len(sumstats)
     if chisq_max is None:
         chisq_max = max(0.001 * sumstats.N.max(), 80)
-        logger.info(f'No chi^2 threshold provided, using {chisq_max} as default')
-    sumstats['chisq'] = sumstats.Z ** 2
+        logger.info(f"No chi^2 threshold provided, using {chisq_max} as default")
+    sumstats["chisq"] = sumstats.Z**2
     sumstats = sumstats[sumstats.chisq < chisq_max]
     after_len = len(sumstats)
     if after_len < before_len:
-        logger.info(f'Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)')
+        logger.info(
+            f"Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)"
+        )
     else:
-        logger.info(f'No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)')
+        logger.info(f"No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)")
     return sumstats
 
 
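
The reformatted filter keeps its original behaviour: without an explicit threshold it defaults to max(0.001 * N.max(), 80) and drops SNPs whose chi^2 = Z^2 exceeds it. A minimal sketch of that logic on made-up data (the column names N and Z follow the function above):

    import numpy as np
    import pandas as pd

    # Toy sumstats: per-SNP sample size N and Z-score.
    sumstats = pd.DataFrame({"N": [50_000] * 5, "Z": [1.2, -0.8, 3.0, 15.0, 0.1]})

    chisq_max = max(0.001 * sumstats.N.max(), 80)  # max(50, 80) -> 80
    sumstats["chisq"] = sumstats.Z**2
    kept = sumstats[sumstats.chisq < chisq_max]  # drops the Z=15 SNP (chi^2 = 225)
    print(len(sumstats) - len(kept))  # -> 1
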
@@ -73,17 +75,37 @@ def weights(ld, w_ld, N, M, hsq, intercept=1):
 def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
     """Optimized function to calculate initial weights."""
     tot_agg = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
-    initial_w = weights(x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept)
+    initial_w = weights(
+        x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept
+    )
     initial_w = np.sqrt(initial_w)
     return initial_w
 
 
-def jackknife_for_processmap(spot_id, spatial_annotation, ref_ld_baseline_column_sum, sumstats, baseline_annotation, w_ld_common_snp, Nbar, n_blocks):
+def jackknife_for_processmap(
+    spot_id,
+    spatial_annotation,
+    ref_ld_baseline_column_sum,
+    sumstats,
+    baseline_annotation,
+    w_ld_common_snp,
+    Nbar,
+    n_blocks,
+):
     """Perform jackknife resampling for a given spot."""
     spot_spatial_annotation = spatial_annotation[:, spot_id]
     spot_x_tot_precomputed = spot_spatial_annotation + ref_ld_baseline_column_sum
-    initial_w = get_weight_optimized(sumstats, x_tot_precomputed=spot_x_tot_precomputed,
-                                     M_tot=10000, w_ld=w_ld_common_snp, intercept=1).astype(np.float32).reshape((-1, 1))
+    initial_w = (
+        get_weight_optimized(
+            sumstats,
+            x_tot_precomputed=spot_x_tot_precomputed,
+            M_tot=10000,
+            w_ld=w_ld_common_snp,
+            intercept=1,
+        )
+        .astype(np.float32)
+        .reshape((-1, 1))
+    )
     initial_w_scaled = initial_w / np.sum(initial_w)
     baseline_annotation_spot = baseline_annotation * initial_w_scaled
     spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
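
jackknife_for_processmap, now with one parameter per line, still builds per-spot regression weights the same way: square-root the LDSC weights returned by get_weight_optimized, reshape to a column, normalize to sum to one, then scale both annotation matrices row-wise. A toy sketch of just that weighting step (the weight values are invented):

    import numpy as np

    initial_w = np.sqrt(np.array([0.5, 2.0, 1.5], dtype=np.float32)).reshape((-1, 1))
    initial_w_scaled = initial_w / np.sum(initial_w)  # column of weights summing to 1

    baseline_annotation = np.ones((3, 2), dtype=np.float32)  # (n_snp, n_annot) dummy values
    baseline_annotation_spot = baseline_annotation * initial_w_scaled  # row-wise reweighting
    print(float(initial_w_scaled.sum()))  # -> 1.0
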
@@ -93,68 +115,93 @@ def jackknife_for_processmap(spot_id, spatial_annotation, ref_ld_baseline_column
     try:
         jknife = jk.LstsqJackknifeFast(x_focal, y, n_blocks)
     except np.linalg.LinAlgError as e:
-        logger.warning(f'LinAlgError: {e}')
+        logger.warning(f"LinAlgError: {e}")
         return np.nan, np.nan
     return _coef_new(jknife, Nbar)
 
 
-def _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
+def _preprocess_sumstats(
+    trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None
+):
     """Preprocess summary statistics."""
     sumstats = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
-    sumstats.set_index('SNP', inplace=True)
+    sumstats.set_index("SNP", inplace=True)
     sumstats = sumstats.astype(np.float32)
     sumstats = filter_sumstats_by_chisq(sumstats, chisq_max)
     common_snp = baseline_and_w_ld_common_snp.intersection(sumstats.index)
     if len(common_snp) < 200000:
-        logger.warning(f'WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad.')
+        logger.warning(
+            f"WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad."
+        )
     sumstats = sumstats.loc[common_snp]
-    sumstats['common_index_pos'] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(sumstats.index)
+    sumstats["common_index_pos"] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(
+        sumstats.index
+    )
     return sumstats
 
 
-def _get_sumstats_with_common_snp_from_sumstats_dict(sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None):
+def _get_sumstats_with_common_snp_from_sumstats_dict(
+    sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None
+):
     """Get summary statistics with common SNPs among all traits."""
-    logger.info('Validating sumstats files...')
+    logger.info("Validating sumstats files...")
     for trait_name, sumstat_file_path in sumstats_config_dict.items():
         if not os.path.exists(sumstat_file_path):
-            raise FileNotFoundError(f'{sumstat_file_path} not found')
+            raise FileNotFoundError(f"{sumstat_file_path} not found")
     sumstats_cleaned_dict = {}
     for trait_name, sumstat_file_path in sumstats_config_dict.items():
-        sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(trait_name, sumstat_file_path, baseline_and_w_ld_common_snp, chisq_max)
+        sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(
+            trait_name, sumstat_file_path, baseline_and_w_ld_common_snp, chisq_max
+        )
     common_snp_among_all_sumstats = None
     for trait_name, sumstats in sumstats_cleaned_dict.items():
         if common_snp_among_all_sumstats is None:
             common_snp_among_all_sumstats = sumstats.index
         else:
-            common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(sumstats.index)
+            common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(
+                sumstats.index
+            )
     for trait_name, sumstats in sumstats_cleaned_dict.items():
         sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]
-    logger.info(f'Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}')
+    logger.info(f"Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}")
    return sumstats_cleaned_dict, common_snp_among_all_sumstats
 
 
 class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
     """Class to handle pre-calculated SNP-Gene weight matrix for quick mode."""
+
     def __init__(self, config: SpatialLDSCConfig, common_snp_among_all_sumstats_pos):
         self.config = config
-        mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM')
+        mk_score = pd.read_feather(config.mkscore_feather_path).set_index("HUMAN_GENE_SYM")
         mk_score_genes = mk_score.index
         snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
         common_genes = mk_score_genes.intersection(snp_gene_weight_adata.var.index)
-        common_snps = snp_gene_weight_adata.obs.index
-        self.snp_gene_weight_matrix = snp_gene_weight_adata[common_snp_among_all_sumstats_pos, common_genes.to_list()].X
+        # common_snps = snp_gene_weight_adata.obs.index
+        self.snp_gene_weight_matrix = snp_gene_weight_adata[
+            common_snp_among_all_sumstats_pos, common_genes.to_list()
+        ].X
         self.mk_score_common = mk_score.loc[common_genes]
-        self.chunk_starts = list(range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode))
+        self.chunk_starts = list(
+            range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode)
+        )
 
     def fetch_ldscore_by_chunk(self, chunk_index):
         """Fetch LD score by chunk."""
         chunk_start = self.chunk_starts[chunk_index]
-        mk_score_chunk = self.mk_score_common.iloc[:, chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
-        ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(mk_score_chunk, drop_dummy_na=False)
-        spots_name = self.mk_score_common.columns[chunk_start:chunk_start + self.config.spots_per_chunk_quick_mode]
+        mk_score_chunk = self.mk_score_common.iloc[
+            :, chunk_start : chunk_start + self.config.spots_per_chunk_quick_mode
+        ]
+        ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
+            mk_score_chunk, drop_dummy_na=False
+        )
+        spots_name = self.mk_score_common.columns[
+            chunk_start : chunk_start + self.config.spots_per_chunk_quick_mode
+        ]
         return ldscore_chunk, spots_name
 
-    def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(self, mk_score_chunk, drop_dummy_na=True):
+    def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
+        self, mk_score_chunk, drop_dummy_na=True
+    ):
         """Calculate LD score using SNP-Gene weight matrix by chunk."""
         if drop_dummy_na:
             ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
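
The quick-mode class wraps its long slices, but the chunking scheme is unchanged: spot columns are cut into fixed-width chunks from precomputed start offsets, and the last chunk may be short. A standalone sketch (the chunk width of 4 is invented; the real one comes from config.spots_per_chunk_quick_mode):

    import numpy as np
    import pandas as pd

    n_spots, spots_per_chunk = 10, 4
    mk_score = pd.DataFrame(np.zeros((3, n_spots)), columns=[f"spot{i}" for i in range(n_spots)])

    chunk_starts = list(range(0, n_spots, spots_per_chunk))  # [0, 4, 8]
    start = chunk_starts[2]
    chunk = mk_score.iloc[:, start : start + spots_per_chunk]  # last chunk holds 2 spots
    print(list(chunk.columns))  # -> ['spot8', 'spot9']
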
@@ -166,7 +213,7 @@ class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
 def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config):
     """Load LD score chunk from feather format."""
     sample_name = config.sample_name
-    ld_file_spatial = f'{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.'
+    ld_file_spatial = f"{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}."
     ref_ld_spatial = _read_ref_ld_v2(ld_file_spatial)
     ref_ld_spatial = ref_ld_spatial.iloc[common_snp_among_all_sumstats_pos]
     ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
@@ -176,38 +223,42 @@ def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_p
 
 def run_spatial_ldsc(config: SpatialLDSCConfig):
     """Run spatial LDSC analysis."""
-    logger.info(f'------Running Spatial LDSC for {config.sample_name}...')
+    logger.info(f"------Running Spatial LDSC for {config.sample_name}...")
     n_blocks = config.n_blocks
     sample_name = config.sample_name
 
     # Load regression weights
     w_ld = _read_w_ld(config.w_file)
-    w_ld_cname = w_ld.columns[1]
-    w_ld.set_index('SNP', inplace=True)
+    w_ld.set_index("SNP", inplace=True)
 
-    ld_file_baseline = f'{config.ldscore_save_dir}/baseline/baseline.'
+    ld_file_baseline = f"{config.ldscore_save_dir}/baseline/baseline."
     ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
     baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)
-    baseline_and_w_ld_common_snp_pos = pd.Index(ref_ld_baseline.index).get_indexer(baseline_and_w_ld_common_snp)
 
-    sumstats_cleaned_dict, common_snp_among_all_sumstats = _get_sumstats_with_common_snp_from_sumstats_dict(
-        config.sumstats_config_dict, baseline_and_w_ld_common_snp, chisq_max=config.chisq_max)
-    common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(common_snp_among_all_sumstats)
+    sumstats_cleaned_dict, common_snp_among_all_sumstats = (
+        _get_sumstats_with_common_snp_from_sumstats_dict(
+            config.sumstats_config_dict, baseline_and_w_ld_common_snp, chisq_max=config.chisq_max
+        )
+    )
+    common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(
+        common_snp_among_all_sumstats
+    )
 
     if not pd.Series(common_snp_among_all_sumstats_pos).is_monotonic_increasing:
-        raise ValueError('common_snp_among_all_sumstats_pos is not monotonic increasing')
+        raise ValueError("common_snp_among_all_sumstats_pos is not monotonic increasing")
 
     if len(common_snp_among_all_sumstats) < 200000:
         logger.warning(
-            f'!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files.')
+            f"!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files."
+        )
 
     ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
     w_ld = w_ld.loc[common_snp_among_all_sumstats]
 
     # Load additional baseline annotations if needed
     if config.use_additional_baseline_annotation:
-        logger.info('Using additional baseline annotations')
-        ld_file_baseline_additional = f'{config.ldscore_save_dir}/additional_baseline/baseline.'
+        logger.info("Using additional baseline annotations")
+        ld_file_baseline_additional = f"{config.ldscore_save_dir}/additional_baseline/baseline."
         ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
         ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
         ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
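
The reflowed alignment block in run_spatial_ldsc keeps the same contract: common SNPs are mapped back to integer row positions in the baseline LD score index with get_indexer, and those positions must be monotonically increasing so later positional slicing stays valid. A toy illustration:

    import pandas as pd

    baseline_index = pd.Index(["rs1", "rs2", "rs3", "rs4"])
    common_snps = pd.Index(["rs2", "rs4"])

    pos = baseline_index.get_indexer(common_snps)  # array([1, 3])
    assert pd.Series(pos).is_monotonic_increasing  # the same check as above
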
@@ -215,10 +266,12 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
 
     # Initialize s_ldsc once if quick_mode
     s_ldsc = None
-    if config.ldscore_save_format == 'quick_mode':
-        s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, common_snp_among_all_sumstats_pos)
+    if config.ldscore_save_format == "quick_mode":
+        s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(
+            config, common_snp_among_all_sumstats_pos
+        )
         total_chunk_number_found = len(s_ldsc.chunk_starts)
-        logger.info(f'Split data into {total_chunk_number_found} chunks')
+        logger.info(f"Split data into {total_chunk_number_found} chunks")
     else:
         total_chunk_number_found = determine_total_chunks(config)
 
@@ -227,12 +280,12 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
 
     # Load zarr file if needed
     zarr_file, spots_name = None, None
-    if config.ldscore_save_format == 'zarr':
-        zarr_path = Path(config.ldscore_save_dir) / f'{config.sample_name}.ldscore.zarr'
+    if config.ldscore_save_format == "zarr":
+        zarr_path = Path(config.ldscore_save_dir) / f"{config.sample_name}.ldscore.zarr"
         if not zarr_path.exists():
-            raise FileNotFoundError(f'{zarr_path} not found, which is required for zarr format')
+            raise FileNotFoundError(f"{zarr_path} not found, which is required for zarr format")
         zarr_file = zarr.open(str(zarr_path))
-        spots_name = zarr_file.attrs['spot_names']
+        spots_name = zarr_file.attrs["spot_names"]
 
     output_dict = defaultdict(list)
     for chunk_index in range(start_chunk, end_chunk + 1):
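
The zarr branch reads one chunk-column of the LD score matrix at a time via block indexing, with spot names stored in the array attributes. A minimal sketch of that layout, assuming the zarr-python 2 block-indexing API and invented shapes:

    import numpy as np
    import zarr

    z = zarr.open("demo.ldscore.zarr", mode="w", shape=(6, 8), chunks=(6, 4), dtype="f4")
    z[:] = np.arange(48, dtype="f4").reshape(6, 8)
    z.attrs["spot_names"] = [f"spot{i}" for i in range(8)]

    block = z.blocks[:, 1]  # second chunk-column, shape (6, 4)
    start_spot = 1 * z.chunks[1]  # -> 4
    print(block.shape, z.attrs["spot_names"][start_spot : start_spot + z.chunks[1]])
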
@@ -242,7 +295,7 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
             config,
             zarr_file,
             spots_name,
-            s_ldsc  # Pass s_ldsc to the function
+            s_ldsc,  # Pass s_ldsc to the function
         )
         ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
 
@@ -251,7 +304,9 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
             baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)
             w_ld_common_snp = w_ld.astype(np.float32, copy=False)
 
-            baseline_annotation = baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
+            baseline_annotation = (
+                baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
+            )
             baseline_annotation = append_intercept(baseline_annotation)
 
             Nbar = sumstats.N.mean()
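
The wrapped expression above is a per-SNP sample-size scaling: each baseline row is multiplied by N_snp / mean(N), so SNPs measured in larger samples contribute proportionally more. Worked through on toy numbers:

    import numpy as np

    N = np.array([80_000, 120_000], dtype=np.float32)  # per-SNP sample sizes, mean 100k
    baseline = np.ones((2, 3), dtype=np.float32)
    scaled = baseline * N.reshape((-1, 1)) / N.mean()  # rows scaled by 0.8 and 1.2
    print(scaled[:, 0])  # -> [0.8 1.2]
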
@@ -265,7 +320,7 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
                 baseline_annotation=baseline_annotation,
                 w_ld_common_snp=w_ld_common_snp,
                 Nbar=Nbar,
-                n_blocks=n_blocks
+                n_blocks=n_blocks,
             )
 
             out_chunk = thread_map(
@@ -273,35 +328,39 @@ def run_spatial_ldsc(config: SpatialLDSCConfig):
                 range(chunk_size),
                 max_workers=config.num_processes,
                 chunksize=10,
-                desc=f'Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}',
+                desc=f"Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}",
             )
 
-            out_chunk = pd.DataFrame.from_records(out_chunk, columns=['beta', 'se'], index=spatial_annotation_cnames)
+            out_chunk = pd.DataFrame.from_records(
+                out_chunk, columns=["beta", "se"], index=spatial_annotation_cnames
+            )
             nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
             if len(nan_spots) > 0:
-                logger.info(f'Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed.')
+                logger.info(
+                    f"Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed."
+                )
             out_chunk = out_chunk.dropna()
-            out_chunk['z'] = out_chunk.beta / out_chunk.se
-            out_chunk['p'] = norm.sf(out_chunk['z'])
+            out_chunk["z"] = out_chunk.beta / out_chunk.se
+            out_chunk["p"] = norm.sf(out_chunk["z"])
             output_dict[trait_name].append(out_chunk)
 
         del spatial_annotation, baseline_annotation, w_ld_common_snp
         gc.collect()
 
     save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk)
-    logger.info(f'------Spatial LDSC for {sample_name} finished!')
+    logger.info(f"------Spatial LDSC for {sample_name} finished!")
 
 
 def determine_total_chunks(config):
     """Determine total number of chunks based on the ldscore save format."""
-    if config.ldscore_save_format == 'quick_mode':
+    if config.ldscore_save_format == "quick_mode":
         s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, [])
         total_chunk_number_found = len(s_ldsc.chunk_starts)
-        logger.info(f'Split data into {total_chunk_number_found} chunks')
+        logger.info(f"Split data into {total_chunk_number_found} chunks")
     else:
         all_file = os.listdir(config.ldscore_save_dir)
-        total_chunk_number_found = sum('chunk' in name for name in all_file)
-        logger.info(f'Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}')
+        total_chunk_number_found = sum("chunk" in name for name in all_file)
+        logger.info(f"Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}")
     return total_chunk_number_found
 
 
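
The per-chunk post-processing is unchanged apart from quoting: z = beta / se, and the one-sided p-value is the normal survival function of z. For example:

    import pandas as pd
    from scipy.stats import norm

    out_chunk = pd.DataFrame({"beta": [0.02, -0.01], "se": [0.01, 0.02]})
    out_chunk["z"] = out_chunk.beta / out_chunk.se  # -> [2.0, -0.5]
    out_chunk["p"] = norm.sf(out_chunk["z"])  # one-sided P(Z > z) -> [~0.023, ~0.691]
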
@@ -309,36 +368,49 @@ def determine_chunk_range(config, total_chunk_number_found):
     """Determine the range of chunks to process."""
     if config.all_chunk is None:
         if config.chunk_range is not None:
-            if not (1 <= config.chunk_range[0] <= total_chunk_number_found) or not (1 <= config.chunk_range[1] <= total_chunk_number_found):
-                raise ValueError('Chunk range out of bound. It should be in [1, all_chunk]')
+            if not (1 <= config.chunk_range[0] <= total_chunk_number_found) or not (
+                1 <= config.chunk_range[1] <= total_chunk_number_found
+            ):
+                raise ValueError("Chunk range out of bound. It should be in [1, all_chunk]")
             start_chunk, end_chunk = config.chunk_range
-            logger.info(f'Chunk range provided, using chunked files from {start_chunk} to {end_chunk}')
+            logger.info(
+                f"Chunk range provided, using chunked files from {start_chunk} to {end_chunk}"
+            )
         else:
             start_chunk, end_chunk = 1, total_chunk_number_found
     else:
         all_chunk = config.all_chunk
-        logger.info(f'Using {all_chunk} chunked files by provided argument')
+        logger.info(f"Using {all_chunk} chunked files by provided argument")
         start_chunk, end_chunk = 1, all_chunk
     return start_chunk, end_chunk
 
 
-def load_ldscore_chunk(chunk_index, common_snp_among_all_sumstats_pos, config, zarr_file=None, spots_name=None, s_ldsc=None):
+def load_ldscore_chunk(
+    chunk_index,
+    common_snp_among_all_sumstats_pos,
+    config,
+    zarr_file=None,
+    spots_name=None,
+    s_ldsc=None,
+):
     """Load LD score chunk based on save format."""
-    if config.ldscore_save_format == 'feather':
-        return load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config)
-    elif config.ldscore_save_format == 'zarr':
+    if config.ldscore_save_format == "feather":
+        return load_ldscore_chunk_from_feather(
+            chunk_index, common_snp_among_all_sumstats_pos, config
+        )
+    elif config.ldscore_save_format == "zarr":
         ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
         start_spot = (chunk_index - 1) * zarr_file.chunks[1]
         ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
-        spatial_annotation_cnames = spots_name[start_spot:start_spot + zarr_file.chunks[1]]
+        spatial_annotation_cnames = spots_name[start_spot : start_spot + zarr_file.chunks[1]]
         return ref_ld_spatial, spatial_annotation_cnames
-    elif config.ldscore_save_format == 'quick_mode':
+    elif config.ldscore_save_format == "quick_mode":
         # Use the pre-initialized s_ldsc
         if s_ldsc is None:
             raise ValueError("s_ldsc must be provided in quick_mode")
         return s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)
     else:
-        raise ValueError(f'Invalid ld score save format: {config.ldscore_save_format}')
+        raise ValueError(f"Invalid ld score save format: {config.ldscore_save_format}")
 
 
 def save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk):
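
determine_chunk_range validates a 1-based, inclusive range against the number of chunk files found. The bound check, extracted into a standalone sketch for illustration (validate_chunk_range is a hypothetical helper, not part of gsMap):

    def validate_chunk_range(chunk_range, total_chunk_number_found):
        # Same 1-based inclusive bound check as in determine_chunk_range above.
        if not (1 <= chunk_range[0] <= total_chunk_number_found) or not (
            1 <= chunk_range[1] <= total_chunk_number_found
        ):
            raise ValueError("Chunk range out of bound. It should be in [1, all_chunk]")
        return chunk_range

    print(validate_chunk_range((2, 5), 10))  # -> (2, 5)
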
@@ -348,13 +420,15 @@ def save_results(output_dict, config, running_chunk_number, start_chunk, end_chu
         out_all = pd.concat(out_chunk_list, axis=0)
         sample_name = config.sample_name
         if running_chunk_number == end_chunk - start_chunk + 1:
-            out_file_name = out_dir / f'{sample_name}_{trait_name}.csv.gz'
+            out_file_name = out_dir / f"{sample_name}_{trait_name}.csv.gz"
         else:
-            out_file_name = out_dir / f'{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz'
-        out_all['spot'] = out_all.index
-        out_all = out_all[['spot', 'beta', 'se', 'z', 'p']]
+            out_file_name = (
+                out_dir / f"{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz"
+            )
+        out_all["spot"] = out_all.index
+        out_all = out_all[["spot", "beta", "se", "z", "p"]]
 
         # clip the p-values
-        out_all['p'] = out_all['p'].clip(1e-300, 1)
-        out_all.to_csv(out_file_name, compression='gzip', index=False)
-        logger.info(f'Output saved to {out_file_name} for {trait_name}')
+        out_all["p"] = out_all["p"].clip(1e-300, 1)
+        out_all.to_csv(out_file_name, compression="gzip", index=False)
+        logger.info(f"Output saved to {out_file_name} for {trait_name}")
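
save_results still writes one gzipped CSV per trait with columns spot, beta, se, z, p, with p clipped into [1e-300, 1]; reading a result back is a plain pandas call. The file name below only illustrates the f"{sample_name}_{trait_name}.csv.gz" pattern with invented sample and trait names:

    import pandas as pd

    df = pd.read_csv("sampleA_height.csv.gz", compression="gzip")  # hypothetical output file
    print(df.columns.tolist())  # -> ['spot', 'beta', 'se', 'z', 'p']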