gsMap3D 0.1.0a1 (gsmap3d-0.1.0a1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py
@@ -0,0 +1,439 @@
+ import gc
+ import logging
+ import os
+ from collections import defaultdict
+ from functools import partial
+ from pathlib import Path
+
+ import anndata as ad
+ import numpy as np
+ import pandas as pd
+ import zarr
+ from scipy.stats import norm
+ from tqdm.contrib.concurrent import thread_map
+
+ import gsMap.utils.jackknife as jk
+ from gsMap.config import SpatialLDSCConfig
+ from gsMap.utils.regression_read import _read_ref_ld_v2, _read_sumstats, _read_w_ld
+
+ logger = logging.getLogger("gsMap.spatial_ldsc")
+
+
+ def _coef_new(jknife, Nbar):
+     """Calculate coefficients adjusted by Nbar."""
+     est_ = jknife.jknife_est[0, 0] / Nbar
+     se_ = jknife.jknife_se[0, 0] / Nbar
+     return est_, se_
+
+
+ def append_intercept(x):
+     """Append an intercept term to the design matrix."""
+     n_row = x.shape[0]
+     intercept = np.ones((n_row, 1))
+     x_new = np.concatenate((x, intercept), axis=1)
+     return x_new
+
+
+ def filter_sumstats_by_chisq(sumstats, chisq_max):
+     """Filter summary statistics based on chi-squared threshold."""
+     before_len = len(sumstats)
+     if chisq_max is None:
+         chisq_max = max(0.001 * sumstats.N.max(), 80)
+         logger.info(f"No chi^2 threshold provided, using {chisq_max} as default")
+     sumstats["chisq"] = sumstats.Z**2
+     sumstats = sumstats[sumstats.chisq < chisq_max]
+     after_len = len(sumstats)
+     if after_len < before_len:
+         logger.info(
+             f"Removed {before_len - after_len} SNPs with chi^2 > {chisq_max} ({after_len} SNPs remain)"
+         )
+     else:
+         logger.info(f"No SNPs removed with chi^2 > {chisq_max} ({after_len} SNPs remain)")
+     return sumstats
+
+
+ def aggregate(y, x, N, M, intercept=1):
+     """Aggregate function used in weight calculation."""
+     num = M * (np.mean(y) - intercept)
+     denom = np.mean(np.multiply(x, N))
+     return num / denom
+
+
+ def weights(ld, w_ld, N, M, hsq, intercept=1):
+     """Calculate weights for regression."""
+     M = float(M)
+     hsq = np.clip(hsq, 0.0, 1.0)
+     ld = np.maximum(ld, 1.0)
+     w_ld = np.maximum(w_ld, 1.0)
+     c = hsq * N / M
+     het_w = 1.0 / (2 * np.square(intercept + np.multiply(c, ld)))
+     oc_w = 1.0 / w_ld
+     w = np.multiply(het_w, oc_w)
+     return w
+
+
+ def get_weight_optimized(sumstats, x_tot_precomputed, M_tot, w_ld, intercept=1):
+     """Optimized function to calculate initial weights."""
+     tot_agg = aggregate(sumstats.chisq, x_tot_precomputed, sumstats.N, M_tot, intercept)
+     initial_w = weights(
+         x_tot_precomputed, w_ld.LD_weights.values, sumstats.N.values, M_tot, tot_agg, intercept
+     )
+     initial_w = np.sqrt(initial_w)
+     return initial_w
+
+
+ def jackknife_for_processmap(
+     spot_id,
+     spatial_annotation,
+     ref_ld_baseline_column_sum,
+     sumstats,
+     baseline_annotation,
+     w_ld_common_snp,
+     Nbar,
+     n_blocks,
+ ):
+     """Perform jackknife resampling for a given spot."""
+     spot_spatial_annotation = spatial_annotation[:, spot_id]
+     spot_x_tot_precomputed = spot_spatial_annotation + ref_ld_baseline_column_sum
+     initial_w = (
+         get_weight_optimized(
+             sumstats,
+             x_tot_precomputed=spot_x_tot_precomputed,
+             M_tot=10000,
+             w_ld=w_ld_common_snp,
+             intercept=1,
+         )
+         .astype(np.float32)
+         .reshape((-1, 1))
+     )
+     initial_w_scaled = initial_w / np.sum(initial_w)
+     baseline_annotation_spot = baseline_annotation * initial_w_scaled
+     spatial_annotation_spot = spot_spatial_annotation.reshape((-1, 1)) * initial_w_scaled
+     CHISQ = sumstats.chisq.values.reshape((-1, 1))
+     y = CHISQ * initial_w_scaled
+     x_focal = np.concatenate((spatial_annotation_spot, baseline_annotation_spot), axis=1)
+     try:
+         jknife = jk.LstsqJackknifeFast(x_focal, y, n_blocks)
+     except np.linalg.LinAlgError as e:
+         logger.warning(f"LinAlgError: {e}")
+         return np.nan, np.nan
+     return _coef_new(jknife, Nbar)
+
+
+ def _preprocess_sumstats(
+     trait_name, sumstat_file_path, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None
+ ):
+     """Preprocess summary statistics."""
+     sumstats = _read_sumstats(fh=sumstat_file_path, alleles=False, dropna=False)
+     sumstats.set_index("SNP", inplace=True)
+     sumstats = sumstats.astype(np.float32)
+     sumstats = filter_sumstats_by_chisq(sumstats, chisq_max)
+     common_snp = baseline_and_w_ld_common_snp.intersection(sumstats.index)
+     if len(common_snp) < 200000:
+         logger.warning(
+             f"WARNING: number of SNPs less than 200k; for {trait_name} this is almost always bad."
+         )
+     sumstats = sumstats.loc[common_snp]
+     sumstats["common_index_pos"] = pd.Index(baseline_and_w_ld_common_snp).get_indexer(
+         sumstats.index
+     )
+     return sumstats
+
+
+ def _get_sumstats_with_common_snp_from_sumstats_dict(
+     sumstats_config_dict: dict, baseline_and_w_ld_common_snp: pd.Index, chisq_max=None
+ ):
+     """Get summary statistics with common SNPs among all traits."""
+     logger.info("Validating sumstats files...")
+     for trait_name, sumstat_file_path in sumstats_config_dict.items():
+         if not os.path.exists(sumstat_file_path):
+             raise FileNotFoundError(f"{sumstat_file_path} not found")
+     sumstats_cleaned_dict = {}
+     for trait_name, sumstat_file_path in sumstats_config_dict.items():
+         sumstats_cleaned_dict[trait_name] = _preprocess_sumstats(
+             trait_name, sumstat_file_path, baseline_and_w_ld_common_snp, chisq_max
+         )
+     common_snp_among_all_sumstats = None
+     for trait_name, sumstats in sumstats_cleaned_dict.items():
+         if common_snp_among_all_sumstats is None:
+             common_snp_among_all_sumstats = sumstats.index
+         else:
+             common_snp_among_all_sumstats = common_snp_among_all_sumstats.intersection(
+                 sumstats.index
+             )
+     for trait_name, sumstats in sumstats_cleaned_dict.items():
+         sumstats_cleaned_dict[trait_name] = sumstats.loc[common_snp_among_all_sumstats]
+     logger.info(f"Common SNPs among all sumstats: {len(common_snp_among_all_sumstats)}")
+     return sumstats_cleaned_dict, common_snp_among_all_sumstats
+
+
+ class S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix:
+     """Class to handle pre-calculated SNP-Gene weight matrix for quick mode."""
+
+     def __init__(self, config: SpatialLDSCConfig, common_snp_among_all_sumstats_pos):
+         self.config = config
+         mk_score = pd.read_feather(config.mkscore_feather_path).set_index("HUMAN_GENE_SYM")
+         mk_score_genes = mk_score.index
+         snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
+         common_genes = mk_score_genes.intersection(snp_gene_weight_adata.var.index)
+         # common_snps = snp_gene_weight_adata.obs.index
+         self.snp_gene_weight_matrix = snp_gene_weight_adata[
+             common_snp_among_all_sumstats_pos, common_genes.to_list()
+         ].X
+         self.mk_score_common = mk_score.loc[common_genes]
+         self.chunk_starts = list(
+             range(0, self.mk_score_common.shape[1], self.config.spots_per_chunk_quick_mode)
+         )
+
+     def fetch_ldscore_by_chunk(self, chunk_index):
+         """Fetch LD score by chunk."""
+         chunk_start = self.chunk_starts[chunk_index]
+         mk_score_chunk = self.mk_score_common.iloc[
+             :, chunk_start : chunk_start + self.config.spots_per_chunk_quick_mode
+         ]
+         ldscore_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
+             mk_score_chunk, drop_dummy_na=False
+         )
+         spots_name = self.mk_score_common.columns[
+             chunk_start : chunk_start + self.config.spots_per_chunk_quick_mode
+         ]
+         return ldscore_chunk, spots_name
+
+     def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
+         self, mk_score_chunk, drop_dummy_na=True
+     ):
+         """Calculate LD score using SNP-Gene weight matrix by chunk."""
+         if drop_dummy_na:
+             ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
+         else:
+             ldscore_chr_chunk = self.snp_gene_weight_matrix @ mk_score_chunk
+         return ldscore_chr_chunk
+
+
+ def load_ldscore_chunk_from_feather(chunk_index, common_snp_among_all_sumstats_pos, config):
+     """Load LD score chunk from feather format."""
+     sample_name = config.sample_name
+     ld_file_spatial = f"{config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}."
+     ref_ld_spatial = _read_ref_ld_v2(ld_file_spatial)
+     ref_ld_spatial = ref_ld_spatial.iloc[common_snp_among_all_sumstats_pos]
+     ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
+     spatial_annotation_cnames = ref_ld_spatial.columns
+     return ref_ld_spatial.values, spatial_annotation_cnames
+
+
+ def run_spatial_ldsc(config: SpatialLDSCConfig):
+     """Run spatial LDSC analysis."""
+     logger.info(f"------Running Spatial LDSC for {config.sample_name}...")
+     n_blocks = config.n_blocks
+     sample_name = config.sample_name
+
+     # Load regression weights
+     w_ld = _read_w_ld(config.w_file)
+     w_ld.set_index("SNP", inplace=True)
+
+     ld_file_baseline = f"{config.ldscore_save_dir}/baseline/baseline."
+     ref_ld_baseline = _read_ref_ld_v2(ld_file_baseline)
+     baseline_and_w_ld_common_snp = ref_ld_baseline.index.intersection(w_ld.index)
+
+     sumstats_cleaned_dict, common_snp_among_all_sumstats = (
+         _get_sumstats_with_common_snp_from_sumstats_dict(
+             config.sumstats_config_dict, baseline_and_w_ld_common_snp, chisq_max=config.chisq_max
+         )
+     )
+     common_snp_among_all_sumstats_pos = ref_ld_baseline.index.get_indexer(
+         common_snp_among_all_sumstats
+     )
+
+     if not pd.Series(common_snp_among_all_sumstats_pos).is_monotonic_increasing:
+         raise ValueError("common_snp_among_all_sumstats_pos is not monotonic increasing")
+
+     if len(common_snp_among_all_sumstats) < 200000:
+         logger.warning(
+             f"!!!!! WARNING: number of SNPs less than 200k; for {sample_name} this is almost always bad. Please check the sumstats files."
+         )
+
+     ref_ld_baseline = ref_ld_baseline.loc[common_snp_among_all_sumstats]
+     w_ld = w_ld.loc[common_snp_among_all_sumstats]
+
+     # Load additional baseline annotations if needed
+     if config.use_additional_baseline_annotation:
+         logger.info("Using additional baseline annotations")
+         ld_file_baseline_additional = f"{config.ldscore_save_dir}/additional_baseline/baseline."
+         ref_ld_baseline_additional = _read_ref_ld_v2(ld_file_baseline_additional)
+         ref_ld_baseline_additional = ref_ld_baseline_additional.loc[common_snp_among_all_sumstats]
+         ref_ld_baseline = pd.concat([ref_ld_baseline, ref_ld_baseline_additional], axis=1)
+         del ref_ld_baseline_additional
+
+     # Initialize s_ldsc once if quick_mode
+     s_ldsc = None
+     if config.ldscore_save_format == "quick_mode":
+         s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(
+             config, common_snp_among_all_sumstats_pos
+         )
+         total_chunk_number_found = len(s_ldsc.chunk_starts)
+         logger.info(f"Split data into {total_chunk_number_found} chunks")
+     else:
+         total_chunk_number_found = determine_total_chunks(config)
+
+     start_chunk, end_chunk = determine_cell_indices_range(config, total_chunk_number_found)
+     running_chunk_number = end_chunk - start_chunk + 1
+
+     # Load zarr file if needed
+     zarr_file, spots_name = None, None
+     if config.ldscore_save_format == "zarr":
+         zarr_path = Path(config.ldscore_save_dir) / f"{config.sample_name}.ldscore.zarr"
+         if not zarr_path.exists():
+             raise FileNotFoundError(f"{zarr_path} not found, which is required for zarr format")
+         zarr_file = zarr.open(str(zarr_path))
+         spots_name = zarr_file.attrs["spot_names"]
+
+     output_dict = defaultdict(list)
+     for chunk_index in range(start_chunk, end_chunk + 1):
+         ref_ld_spatial, spatial_annotation_cnames = load_ldscore_chunk(
+             chunk_index,
+             common_snp_among_all_sumstats_pos,
+             config,
+             zarr_file,
+             spots_name,
+             s_ldsc,  # Pass s_ldsc to the function
+         )
+         ref_ld_baseline_column_sum = ref_ld_baseline.sum(axis=1).values
+
+         for trait_name, sumstats in sumstats_cleaned_dict.items():
+             spatial_annotation = ref_ld_spatial.astype(np.float32, copy=False)
+             baseline_annotation = ref_ld_baseline.copy().astype(np.float32, copy=False)
+             w_ld_common_snp = w_ld.astype(np.float32, copy=False)
+
+             baseline_annotation = (
+                 baseline_annotation * sumstats.N.values.reshape((-1, 1)) / sumstats.N.mean()
+             )
+             baseline_annotation = append_intercept(baseline_annotation)
+
+             Nbar = sumstats.N.mean()
+             chunk_size = spatial_annotation.shape[1]
+
+             jackknife_func = partial(
+                 jackknife_for_processmap,
+                 spatial_annotation=spatial_annotation,
+                 ref_ld_baseline_column_sum=ref_ld_baseline_column_sum,
+                 sumstats=sumstats,
+                 baseline_annotation=baseline_annotation,
+                 w_ld_common_snp=w_ld_common_snp,
+                 Nbar=Nbar,
+                 n_blocks=n_blocks,
+             )
+
+             out_chunk = thread_map(
+                 jackknife_func,
+                 range(chunk_size),
+                 max_workers=config.num_processes,
+                 chunksize=10,
+                 desc=f"Chunk-{chunk_index}/Total-chunk-{running_chunk_number} for {trait_name}",
+             )
+
+             out_chunk = pd.DataFrame.from_records(
+                 out_chunk, columns=["beta", "se"], index=spatial_annotation_cnames
+             )
+             nan_spots = out_chunk[out_chunk.isna().any(axis=1)].index
+             if len(nan_spots) > 0:
+                 logger.info(
+                     f"Nan spots: {nan_spots} in chunk-{chunk_index} for {trait_name}. They are removed."
+                 )
+             out_chunk = out_chunk.dropna()
+             out_chunk["z"] = out_chunk.beta / out_chunk.se
+             out_chunk["p"] = norm.sf(out_chunk["z"])
+             output_dict[trait_name].append(out_chunk)
+
+         del spatial_annotation, baseline_annotation, w_ld_common_snp
+         gc.collect()
+
+     save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk)
+     logger.info(f"------Spatial LDSC for {sample_name} finished!")
+
+
+ def determine_total_chunks(config):
+     """Determine total number of chunks based on the ldscore save format."""
+     if config.ldscore_save_format == "quick_mode":
+         s_ldsc = S_LDSC_Boost_with_pre_calculate_SNP_Gene_weight_matrix(config, [])
+         total_chunk_number_found = len(s_ldsc.chunk_starts)
+         logger.info(f"Split data into {total_chunk_number_found} chunks")
+     else:
+         all_file = os.listdir(config.ldscore_save_dir)
+         total_chunk_number_found = sum("chunk" in name for name in all_file)
+         logger.info(f"Find {total_chunk_number_found} chunked files in {config.ldscore_save_dir}")
+     return total_chunk_number_found
+
+
+ def determine_cell_indices_range(config, total_chunk_number_found):
+     """Determine the range of cell indices (chunks) to process."""
+     if config.cell_indices_range is not None:
+         # Convert cell indices to chunk indices (both 0-based)
+         start_cell, end_cell = config.cell_indices_range  # 0-based [start, end)
+         chunk_size = config.spots_per_chunk_quick_mode if hasattr(config, 'spots_per_chunk_quick_mode') else 1000
+
+         # Calculate which chunks contain these cells
+         start_chunk = start_cell // chunk_size  # 0-based chunk index
+         end_chunk = (end_cell - 1) // chunk_size if end_cell > 0 else 0  # 0-based, inclusive
+
+         # Validate chunk indices
+         if start_chunk >= total_chunk_number_found:
+             raise ValueError(f"cell_indices_range start ({start_cell}) maps to chunk {start_chunk} which is >= total chunks ({total_chunk_number_found})")
+         if end_chunk >= total_chunk_number_found:
+             logger.warning(f"cell_indices_range end ({end_cell}) maps to chunk {end_chunk} which is >= total chunks ({total_chunk_number_found}). Capping to last chunk.")
+             end_chunk = total_chunk_number_found - 1
+
+         logger.info(f"Cell range [{start_cell}, {end_cell}) maps to chunks {start_chunk}-{end_chunk} (0-based)")
+         # Convert to 1-based for compatibility with existing code that expects 1-based chunks
+         return start_chunk + 1, end_chunk + 1
+     else:
+         # Process all chunks (return 1-based for compatibility)
+         return 1, total_chunk_number_found
+
+
+ def load_ldscore_chunk(
+     chunk_index,
+     common_snp_among_all_sumstats_pos,
+     config,
+     zarr_file=None,
+     spots_name=None,
+     s_ldsc=None,
+ ):
+     """Load LD score chunk based on save format."""
+     if config.ldscore_save_format == "feather":
+         return load_ldscore_chunk_from_feather(
+             chunk_index, common_snp_among_all_sumstats_pos, config
+         )
+     elif config.ldscore_save_format == "zarr":
+         ref_ld_spatial = zarr_file.blocks[:, chunk_index - 1][common_snp_among_all_sumstats_pos]
+         start_spot = (chunk_index - 1) * zarr_file.chunks[1]
+         ref_ld_spatial = ref_ld_spatial.astype(np.float32, copy=False)
+         spatial_annotation_cnames = spots_name[start_spot : start_spot + zarr_file.chunks[1]]
+         return ref_ld_spatial, spatial_annotation_cnames
+     elif config.ldscore_save_format == "quick_mode":
+         # Use the pre-initialized s_ldsc
+         if s_ldsc is None:
+             raise ValueError("s_ldsc must be provided in quick_mode")
+         return s_ldsc.fetch_ldscore_by_chunk(chunk_index - 1)
+     else:
+         raise ValueError(f"Invalid ld score save format: {config.ldscore_save_format}")
+
+
+ def save_results(output_dict, config, running_chunk_number, start_chunk, end_chunk):
+     """Save the results to the specified directory."""
+     out_dir = config.ldsc_save_dir
+     for trait_name, out_chunk_list in output_dict.items():
+         out_all = pd.concat(out_chunk_list, axis=0)
+         sample_name = config.sample_name
+         if running_chunk_number == determine_total_chunks(config):
+             out_file_name = out_dir / f"{sample_name}_{trait_name}.csv.gz"
+         else:
+             out_file_name = (
+                 out_dir / f"{sample_name}_{trait_name}_chunk{start_chunk}-{end_chunk}.csv.gz"
+             )
+         out_all["spot"] = out_all.index
+         out_all = out_all[["spot", "beta", "se", "z", "p"]]
+
+         # clip the p-values
+         out_all["p"] = out_all["p"].clip(1e-300, 1)
+         out_all.to_csv(out_file_name, compression="gzip", index=False)
+         logger.info(f"Output saved to {out_file_name} for {trait_name}")
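Note on the regression core: jackknife_for_processmap fits, for each spot, a weighted least-squares regression of GWAS chi^2 on [spot LD score | baseline LD scores | intercept], with initial weights derived from an aggregate heritability estimate (aggregate/weights/get_weight_optimized), and turns the focal coefficient into beta and se via a block jackknife. The sketch below reproduces that flow in plain numpy on synthetic data so the steps can be run in isolation. It is an illustration only, not package code: the delete-one-block jackknife stands in for jk.LstsqJackknifeFast (whose internals are not part of this diff), the data are random, and M_tot=10000 is copied from the call site above.

import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
n_snp, n_baseline, n_blocks, M_tot = 5000, 3, 10, 10000

ld_spot = rng.gamma(2.0, 2.0, n_snp)                # spatial LD score for one spot
ld_base = rng.gamma(2.0, 2.0, (n_snp, n_baseline))  # baseline LD scores
w_ld = rng.gamma(2.0, 2.0, n_snp)                   # regression-weight LD scores
N = np.full(n_snp, 50_000.0)                        # GWAS sample size per SNP
chisq = (1.0 + 0.05 * ld_spot) * rng.standard_normal(n_snp) ** 2  # toy chi^2 with signal

# Initial weights, as in aggregate()/weights(): an aggregate h2 guess feeds a
# heteroskedasticity term, combined with 1/w_ld to correct for SNP over-counting.
x_tot = ld_spot + ld_base.sum(axis=1)               # spot_x_tot_precomputed
hsq = np.clip(M_tot * (chisq.mean() - 1.0) / np.mean(x_tot * N), 0.0, 1.0)
c = hsq * N / M_tot
het_w = 1.0 / (2 * (1.0 + c * np.maximum(x_tot, 1.0)) ** 2)
w = np.sqrt(het_w / np.maximum(w_ld, 1.0)).reshape(-1, 1)
w /= w.sum()                                        # initial_w_scaled

# Weighted design matrix [spot LD | baseline LD | intercept] and response chi^2
x = np.column_stack([ld_spot, ld_base, np.ones(n_snp)]) * w
y = chisq.reshape(-1, 1) * w

# Delete-one-block jackknife on the focal (spot) coefficient
blocks = np.array_split(np.arange(n_snp), n_blocks)
full = np.linalg.lstsq(x, y, rcond=None)[0][0, 0]
loo = np.array([
    np.linalg.lstsq(np.delete(x, b, 0), np.delete(y, b, 0), rcond=None)[0][0, 0]
    for b in blocks
])
pseudo = n_blocks * full - (n_blocks - 1) * loo
est, se = pseudo.mean(), pseudo.std(ddof=1) / np.sqrt(n_blocks)

# _coef_new's Nbar adjustment, then the z and one-sided p used in the module
Nbar = N.mean()
beta, beta_se = est / Nbar, se / Nbar
z = beta / beta_se
print(f"beta={beta:.3e} se={beta_se:.3e} z={z:.2f} p={norm.sf(z):.3g}")

Because N is constant in this toy, the baseline rescaling by N/Nbar done in run_spatial_ldsc is a no-op and is omitted here; the final division by Nbar from _coef_new is kept.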
File without changes