gsMap3D 0.1.0a1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/ldscore/pipeline.py
@@ -0,0 +1,615 @@
"""
Chromosome-wise pipeline for LD score calculation using NumPy/Scipy.

This module orchestrates the complete workflow:
1. Loop through chromosomes
2. Construct batches
3. Load genotypes once per chromosome
4. Process each batch to compute LD weights
5. Save results as AnnData
"""

import logging
import sys
from dataclasses import dataclass
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import pyranges as pr
import scipy.sparse
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)

from gsMap.config import LDScoreConfig

from .batch_construction import construct_batches
from .compute import (
    compute_batch_weights_sparse,
    compute_ld_scores,
)
from .io import PlinkBEDReader
from .mapping import create_snp_feature_map

logger = logging.getLogger(__name__)

@dataclass
class ChromosomeResult:
    """
    Results for a single chromosome.

    Attributes
    ----------
    hm3_snp_names : list[str]
        Names of HM3 SNPs processed
    hm3_snp_chr : list[int]
        Chromosome codes of the HM3 SNPs (from the BIM file)
    hm3_snp_bp : list[int]
        Base-pair positions of the HM3 SNPs (from the BIM file)
    weights : scipy.sparse.csr_matrix | np.ndarray
        Weight matrix for features (sparse or dense), shape (n_hm3_snps, n_features)
    feature_names : list[str]
        Names of mapped features
    mapping_df : pd.DataFrame | None
        Curated SNP-feature mapping (BED mode), if available
    ld_scores : np.ndarray | None
        Optional per-SNP LD scores
    """

    hm3_snp_names: list[str]
    hm3_snp_chr: list[int]
    hm3_snp_bp: list[int]
    weights: scipy.sparse.csr_matrix | np.ndarray
    feature_names: list[str]
    mapping_df: pd.DataFrame | None = None
    ld_scores: np.ndarray | None = None

class LDScorePipeline:
    """
    Pipeline for computing LD scores across chromosomes.
    """

    def __init__(self, config: LDScoreConfig):
        self.config = config
        self.hm3_dir = Path(config.hm3_snp_path)

        logger.info("=" * 80)
        logger.info("LD Score Pipeline Configuration (NumPy/Scipy)")
        logger.info("=" * 80)
        logger.info(f"PLINK template: {config.bfile_root}")
        logger.info(f"HM3 directory: {config.hm3_snp_path}")
        logger.info(f"Batch size (HM3): {config.batch_size_hm3}")
        logger.info(f"LD window: {config.ld_wind} {config.ld_unit}")
        logger.info(f"MAF filter: {config.maf_min}")
        logger.info(f"Chromosomes: {config.chromosomes}")
        logger.info(f"Output Directory: {config.output_dir}")
        logger.info(f"Output Filename: {config.output_filename}")
        logger.info("=" * 80)
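A minimal invocation sketch for orientation. The field names below mirror the attributes the constructor logs; the exact LDScoreConfig signature lives in gsMap/config/ldscore_config.py, so the constructor call and all values should be read as illustrative assumptions, not the package's documented API:

from gsMap.config import LDScoreConfig
from gsMap.ldscore.pipeline import LDScorePipeline

config = LDScoreConfig(
    bfile_root="resources/1000G.EUR.QC.{chr}",  # per-chromosome PLINK prefix template
    hm3_snp_path="resources/hapmap3_snps",      # directory of per-chromosome HM3 SNP lists
    mapping_file="annotations/features.bed",    # BED6 feature intervals (hypothetical path)
    mapping_type="bed",
    batch_size_hm3=1000,
    ld_wind=1.0,
    ld_unit="CM",
    maf_min=0.01,
    chromosomes=list(range(1, 23)),
    output_dir="output/ldscore",
    output_filename="ldscore_weights",
)
LDScorePipeline(config).run()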
    def run(self):
        """
        Main entry point. Dispatches to the appropriate pipeline mode based on configuration.
        """
        if self.config.annot_file:
            logger.info(f"Mode: Annotation File ({self.config.annot_file})")
            self._run_with_annotation()
        elif self.config.mapping_file:
            logger.info(f"Mode: SNP-Feature Mapping ({self.config.mapping_type})")
            self._run_with_mapping()
        else:
            raise ValueError("Invalid configuration: neither 'annot_file' nor 'mapping_file' specified.")

    def _run_with_mapping(self):
        """Run the pipeline using a SNP-feature mapping (BED/dictionary)."""
        # 1. Load mapping data
        mapping_data = self._load_mapping_data()

        results = {}
        for chrom in self.config.chromosomes:
            result = self._process_chromosome_from_mapping(
                chrom,
                mapping_data=mapping_data,
            )
            if result is not None:
                results[str(chrom)] = result

        if not results:
            logger.warning("No results generated. Skipping save.")
            return

        self._save_aggregated_results(results, self.config.output_dir, self.config.output_filename)

    def _run_with_annotation(self):
        """Run the pipeline using external annotation matrices."""
        results = {}
        for chrom in self.config.chromosomes:
            result = self._process_chromosome_from_annotation(chrom)
            if result is not None:
                results[str(chrom)] = result

        if not results:
            logger.warning("No results generated. Skipping save.")
            return

        self._save_aggregated_results(results, self.config.output_dir, self.config.output_filename)

    def _load_mapping_data(self) -> pd.DataFrame | dict[str, str]:
        """Helper to load the mapping file specified in the config."""
        logger.info(f"Loading mapping data from: {self.config.mapping_file}")

        if self.config.mapping_type == 'bed':
            # Use pyranges to read a standard BED file
            try:
                bed_pr = pr.read_bed(self.config.mapping_file)
                bed_df = bed_pr.df

                # Convert pyranges BED format to the expected format.
                # Standard BED columns: Chromosome, Start, End, Name, Score, Strand
                # Expected format: Feature, Chromosome, Start, End, [Score], [Strand]

                if 'Name' not in bed_df.columns:
                    logger.error("BED file must contain a 'Name' column (4th column in BED6 format)")
                    logger.error("Required format: standard BED6 (chr, start, end, name, score, strand)")
                    logger.error("Note: BED file should NOT have a header line")
                    sys.exit(1)

                # Rename 'Name' to 'Feature' for internal use
                bed_df = bed_df.rename(columns={'Name': 'Feature'})

                # Ensure required columns exist
                required_cols = ['Chromosome', 'Start', 'End', 'Feature']
                if not all(col in bed_df.columns for col in required_cols):
                    logger.error("BED file missing required columns after parsing")
                    logger.error("Required format: standard BED6 (chr, start, end, name, score, strand)")
                    logger.error("Note: BED file should NOT have a header line")
                    sys.exit(1)

                logger.info(f"Successfully loaded BED file: {len(bed_df)} features")
                logger.info(f"Columns: {list(bed_df.columns)}")
                return bed_df

            except Exception as e:
                logger.error(f"Failed to read BED file: {e}")
                logger.error("Required format: standard BED6 (chr, start, end, name, score, strand)")
                logger.error("Note: BED file should NOT have a header line")
                sys.exit(1)

        elif self.config.mapping_type == 'dict':
            df_map = pd.read_csv(self.config.mapping_file, sep=None, engine='python')

            if 'SNP' in df_map.columns and 'Feature' in df_map.columns:
                return dict(zip(df_map['SNP'], df_map['Feature'], strict=False))
            elif len(df_map.columns) >= 2:
                logger.info("Assuming the first column is SNP and the second is Feature.")
                return dict(zip(df_map.iloc[:, 0], df_map.iloc[:, 1], strict=False))
            else:
                raise ValueError(f"Dictionary mapping file {self.config.mapping_file} must have at least 2 columns.")
        else:
            raise ValueError(f"Unsupported mapping_type: {self.config.mapping_type}")

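Concretely, the two accepted inputs look like this (rows and feature names are illustrative placeholders):

# mapping_type='bed' — standard BED6, tab-separated, no header
chr1    1000000    1050000    GeneA    0    +
chr1    2000000    2100000    GeneB    0    -

# mapping_type='dict' — delimited table with 'SNP' and 'Feature' columns
SNP       Feature
rs1234    GeneA
rs5678    GeneB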
    def _load_hm3_snps(self, chromosome: int) -> list[str]:
        """Load the HM3 SNP list for a chromosome, trying several naming conventions."""
        possible_paths = [
            self.hm3_dir / f"hm.{chromosome}.snp",
            self.hm3_dir / f"hm3_snps.chr{chromosome}.txt",
            self.hm3_dir / f"hapmap3_snps.chr{chromosome}.txt",
            self.hm3_dir / f"chr{chromosome}.snplist",
            self.hm3_dir / f"w_hm3.snplist.chr{chromosome}",
        ]

        for path in possible_paths:
            if path.exists():
                try:
                    snps = pd.read_csv(path, header=None, names=["SNP"])["SNP"].tolist()
                    logger.info(f"Loaded {len(snps)} HM3 SNPs from {path}")
                    return snps
                except Exception:
                    logger.warning(f"Failed to parse {path}; trying the next candidate.")
                    continue

        logger.warning(f"No HM3 SNP file found for chromosome {chromosome}")
        return []

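Each candidate file is read with header=None and a single column, i.e. a plain one-rsID-per-line list, e.g.:

rs3094315
rs12124819
rs1048488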
    def _load_plink_reader(self, chromosome: int) -> PlinkBEDReader | None:
        """Helper to initialize the PLINK reader for a chromosome."""
        bfile_prefix = self.config.bfile_root.format(chr=chromosome)
        logger.info(f"Loading PLINK data from: {bfile_prefix}")

        try:
            reader = PlinkBEDReader(
                bfile_prefix,
                maf_min=self.config.maf_min,
                preload=True,
            )
            return reader
        except FileNotFoundError as e:
            logger.error(f"PLINK files not found for chromosome {chromosome}: {e}")
            return None

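Note that bfile_root is a str.format template keyed by chr, pointing at one PLINK binary fileset (.bed/.bim/.fam) per chromosome (path illustrative):

"resources/1000G.EUR.QC.{chr}".format(chr=22)
# -> "resources/1000G.EUR.QC.22", i.e. 1000G.EUR.QC.22.bed / .bim / .fam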
    def _process_chromosome_from_mapping(
        self,
        chromosome: int,
        mapping_data: pd.DataFrame | dict[str, str],
    ) -> ChromosomeResult | None:
        """Process a single chromosome using mapping rules (BED/dict)."""
        logger.info("=" * 80)
        logger.info(f"Processing Chromosome {chromosome}")
        logger.info("=" * 80)

        target_hm3_snps = self._load_hm3_snps(chromosome)
        if not target_hm3_snps:
            return None

        reader = self._load_plink_reader(chromosome)
        if not reader:
            return None

        logger.info(f"Creating SNP-feature mapping for chromosome {chromosome}...")
        # create_snp_feature_map returns (matrix, names, df)
        mapping_matrix, feature_names, mapping_df = create_snp_feature_map(
            bim_df=reader.bim,
            mapping_type=self.config.mapping_type,
            mapping_data=mapping_data,
            feature_window_size=self.config.feature_window_size,
            strategy=self.config.strategy,
        )

        # Capture the curated mapping if available (BED mode); it is saved after aggregation
        if mapping_df is not None:
            logger.debug(f"  Captured SNP-feature mapping ({len(mapping_df)} rows) for chromosome {chromosome}")

        # Ensure feature names align with the matrix width (handle the Unmapped bin)
        if mapping_matrix.shape[1] == len(feature_names) + 1:
            feature_names_full = feature_names + ["Unmapped"]
        else:
            feature_names_full = feature_names

        logger.info(f"  Features: {len(feature_names_full)} | Matrix nnz: {mapping_matrix.nnz}")

        if self.config.calculate_w_ld:
            self._compute_w_ld(chromosome, target_hm3_snps)

        return self._compute_chromosome_weights(
            chromosome=chromosome,
            reader=reader,
            target_hm3_snps=target_hm3_snps,
            mapping_matrix=mapping_matrix,
            feature_names=feature_names_full,
            mapping_df=mapping_df,
            output_format="sparse",
        )

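The mapping matrix relates reference SNPs to features, with shape (n_ref_snps, n_features). Judging by the shape check above, create_snp_feature_map may append a catch-all column for SNPs that overlap no feature, in which case the width exceeds len(feature_names) by one; the "Unmapped" padding accounts for that. A toy sketch of this alignment logic with hypothetical values:

import numpy as np
import scipy.sparse

feature_names = ["GeneA", "GeneB"]
# 4 reference SNPs x (2 features + 1 catch-all column)
rows = [0, 1, 2, 3]
cols = [0, 0, 1, 2]  # SNP 3 overlaps no feature -> last column
mapping_matrix = scipy.sparse.csr_matrix(
    (np.ones(4, dtype=np.float32), (rows, cols)), shape=(4, 3)
)
if mapping_matrix.shape[1] == len(feature_names) + 1:
    feature_names = feature_names + ["Unmapped"]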
    def _process_chromosome_from_annotation(
        self,
        chromosome: int,
    ) -> ChromosomeResult | None:
        """Process a single chromosome using external annotation files."""
        logger.info("=" * 80)
        logger.info(f"Processing Chromosome {chromosome} (Annotation Mode)")
        logger.info("=" * 80)

        target_hm3_snps = self._load_hm3_snps(chromosome)
        if not target_hm3_snps:
            return None

        reader = self._load_plink_reader(chromosome)
        if not reader:
            return None

        # Load the annotation file (template may be keyed by 'chr' or 'chromosome')
        try:
            annot_file = self.config.annot_file.format(chr=chromosome)
        except KeyError:
            annot_file = self.config.annot_file.format(chromosome=chromosome)

        if not Path(annot_file).exists():
            logger.error(f"Annotation file not found: {annot_file}")
            return None

        logger.info(f"Loading annotation from: {annot_file}")
        try:
            df_annot = pd.read_csv(annot_file, sep=r"\s+")
        except Exception as e:
            logger.error(f"Failed to read annotation file: {e}")
            return None

        if "SNP" in df_annot.columns:
            # Align to the filtered BIM by SNP ID
            df_annot = df_annot.set_index("SNP")
            df_annot = df_annot.reindex(reader.bim["SNP"], fill_value=0)
        else:
            logger.warning(
                "Annotation file missing 'SNP' column. Thin annotation format detected; "
                "assuming strict row alignment with the PLINK panel."
            )
            if len(df_annot) == reader.m_original:
                # Filter annotation rows down to the MAF/QC-filtered SNP set
                logger.info(
                    f"Filtering annotation from {len(df_annot)} to {reader.m} SNPs "
                    f"based on MAF/QC filters"
                )
                df_annot.index = reader.snp_ids_original
                df_annot = df_annot.reindex(reader.bim["SNP"], fill_value=0)
            else:
                logger.error(
                    f"Annotation row mismatch: without a 'SNP' column to align on, "
                    f"thin-format annotation rows ({len(df_annot)}) must match the SNP count "
                    f"({reader.m_original}) of the PLINK panel."
                )
                return None

        # Drop non-feature metadata columns
        feature_df = df_annot.drop(columns=[c for c in ["CHR", "BP", "CM"] if c in df_annot.columns])
        feature_names = feature_df.columns.tolist()

        # QC: flag near-empty annotations
        for feature in feature_names:
            if np.count_nonzero(feature_df[feature].values) / len(feature_df) < 0.0001:
                logger.warning(f"Feature '{feature}' has < 0.01% non-zero entries.")

        logger.info("Converting annotation to sparse matrix...")
        annot_matrix = scipy.sparse.csr_matrix(feature_df.values, dtype=np.float32)

        if self.config.calculate_w_ld:
            self._compute_w_ld(chromosome, target_hm3_snps)

        return self._compute_chromosome_weights(
            chromosome=chromosome,
            reader=reader,
            target_hm3_snps=target_hm3_snps,
            mapping_matrix=annot_matrix,
            feature_names=feature_names,
            mapping_df=None,
            output_format="dense",
        )

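For reference, the two layouts the loader accepts (rows illustrative, in the style of LDSC .annot files):

# Full format: metadata columns plus one column per annotation
CHR    BP        SNP           CM     Enhancer    Promoter
1      752566    rs3094315     0.0    1           0
1      776546    rs12124819    0.0    0           1

# Thin format: annotation columns only, rows aligned 1:1 with the PLINK panel
Enhancer    Promoter
1           0
0           1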
    def _compute_chromosome_weights(
        self,
        chromosome: int,
        reader: PlinkBEDReader,
        target_hm3_snps: list[str],
        mapping_matrix: scipy.sparse.csr_matrix | np.ndarray,
        feature_names: list[str],
        mapping_df: pd.DataFrame | None = None,
        output_format: str = "sparse",
    ) -> ChromosomeResult | None:
        """
        Core logic: construct batches -> load genotypes -> slice mapping -> compute weights.
        """
        logger.info("Constructing batches...")
        batch_infos = construct_batches(
            bim_df=reader.bim,
            hm3_snp_names=target_hm3_snps,
            batch_size_hm3=self.config.batch_size_hm3,
            ld_wind=self.config.ld_wind,
            ld_unit=self.config.ld_unit,
        )

        if len(batch_infos) == 0:
            logger.warning(f"No batches created for chromosome {chromosome}")
            return None

        total_snps = sum(len(b.hm3_indices) for b in batch_infos)
        logger.info(f"Actual HM3 SNPs found in BIM: {total_snps}")

        batch_weight_data = []
        all_hm3_snp_names = []
        all_hm3_snp_chr = []
        all_hm3_snp_bp = []

        with Progress(
            SpinnerColumn(), TextColumn("[progress.description]{task.description}"),
            BarColumn(), TaskProgressColumn(), TimeElapsedColumn(), TimeRemainingColumn(),
        ) as progress:
            task = progress.add_task(f"[cyan]Chr {chromosome}", total=total_snps)

            for batch_info in batch_infos:
                # Genotypes
                X_hm3 = reader.genotypes[:, batch_info.hm3_indices]

                ref_indices = np.arange(batch_info.ref_start_idx, batch_info.ref_end_idx)
                ref_indices = ref_indices[ref_indices < reader.m]
                X_ref_block = reader.genotypes[:, ref_indices]

                # Metadata
                batch_bim = reader.bim.iloc[batch_info.hm3_indices]
                batch_snp_names = batch_bim["SNP"].tolist()
                all_hm3_snp_names.extend(batch_snp_names)
                all_hm3_snp_chr.extend(batch_bim["CHR"].tolist())
                all_hm3_snp_bp.extend(batch_bim["BP"].tolist())

                # Slice the mapping to this reference block
                block_mapping = mapping_matrix[ref_indices, :]

                # Compute per-feature LD weights for the batch
                weights = compute_batch_weights_sparse(X_hm3, X_ref_block, block_mapping)

                batch_weight_data.append({
                    'weights': weights,
                    'hm3_start_idx': len(all_hm3_snp_names) - len(batch_snp_names),
                    'n_hm3': len(batch_snp_names),
                })
                progress.update(task, advance=len(batch_snp_names))

        # Aggregate
        n_hm3_total = len(all_hm3_snp_names)
        n_features = len(feature_names)

        if output_format == "sparse":
            weights_out = self._create_sparse_matrix_from_batches(batch_weight_data, n_hm3_total, n_features)
        else:
            weights_out = self._create_dense_matrix_from_batches(batch_weight_data, n_hm3_total, n_features)

        return ChromosomeResult(
            hm3_snp_names=all_hm3_snp_names,
            hm3_snp_chr=all_hm3_snp_chr,
            hm3_snp_bp=all_hm3_snp_bp,
            weights=weights_out,
            feature_names=feature_names,
            mapping_df=mapping_df,
        )

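compute_batch_weights_sparse is defined in gsMap/ldscore/compute.py and its internals are not shown in this diff. A minimal sketch of the stratified LD score computation it is assumed to perform — squared Pearson correlations between HM3 and reference SNPs, bias-adjusted and summed per feature via the mapping slice:

import numpy as np
import scipy.sparse

def batch_weights_sketch(X_hm3, X_ref, block_mapping):
    """Toy stand-in: rows are samples, columns are SNPs; assumes no monomorphic SNPs."""
    n = X_hm3.shape[0]
    # Standardize columns so Z_hm3.T @ Z_ref / n yields Pearson correlations
    Z_hm3 = (X_hm3 - X_hm3.mean(axis=0)) / X_hm3.std(axis=0)
    Z_ref = (X_ref - X_ref.mean(axis=0)) / X_ref.std(axis=0)
    r2 = (Z_hm3.T @ Z_ref / n) ** 2       # (n_hm3, n_ref)
    r2_adj = r2 - (1.0 - r2) / (n - 2.0)  # standard LDSC finite-sample adjustment
    # Sum adjusted r^2 within each feature column -> (n_hm3, n_features)
    return (block_mapping.T @ r2_adj.T).T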
    def _create_sparse_matrix_from_batches(self, batch_data: list[dict], n_rows: int, n_cols: int) -> scipy.sparse.csr_matrix:
        """Helper to stack batch results into a sparse CSR matrix."""
        if not batch_data:
            return scipy.sparse.csr_matrix((n_rows, n_cols), dtype=np.float32)
        matrices = [scipy.sparse.csr_matrix(b['weights']) for b in batch_data]
        return scipy.sparse.vstack(matrices, format='csr')

    def _create_dense_matrix_from_batches(self, batch_data: list[dict], n_rows: int, n_cols: int) -> np.ndarray:
        """Helper to stack batch results into a dense NumPy array."""
        if not batch_data:
            return np.zeros((n_rows, n_cols), dtype=np.float32)
        return np.vstack([b['weights'] for b in batch_data])

    def _compute_w_ld(self, chromosome: int, hm3_snps: list[str]):
        """
        Compute 'w_ld' (regression-weight LD scores) for a chromosome.
        w_ld is typically the LD score of a SNP computed against the set of regression SNPs (HM3 SNPs).
        """
        logger.info(f"Computing w_ld for chromosome {chromosome}...")

        # 1. Initialize a reader restricted to HM3 SNPs
        bfile_prefix = self.config.bfile_root.format(chr=chromosome)
        try:
            # Filter for HM3 SNPs specifically
            reader = PlinkBEDReader(
                bfile_prefix,
                maf_min=self.config.maf_min,
                keep_snps=hm3_snps,
                preload=True,
            )
        except Exception as e:
            logger.error(f"Failed to load PLINK for w_ld (chr {chromosome}): {e}")
            return

        # 2. Construct batches (using the filtered BIM)
        # Note: reader.bim now ONLY contains HM3 SNPs (and those passing MAF);
        # all SNPs available in the reader are used as targets
        available_hm3 = reader.bim["SNP"].tolist()

        batch_infos = construct_batches(
            bim_df=reader.bim,
            hm3_snp_names=available_hm3,
            batch_size_hm3=self.config.batch_size_hm3,
            ld_wind=self.config.ld_wind,
            ld_unit=self.config.ld_unit,
        )

        logger.info(f"  w_ld: Processing {len(batch_infos)} batches for {len(available_hm3)} SNPs")

        w_ld_values = []
        w_ld_snps = []

        with Progress(
            SpinnerColumn(), TextColumn("[progress.description]{task.description}"),
            BarColumn(), TaskProgressColumn(), TimeElapsedColumn(), TimeRemainingColumn(),
        ) as progress:
            task = progress.add_task(f"[magenta]w_ld Chr {chromosome}", total=len(available_hm3))

            for batch in batch_infos:
                # Genotypes (filtered to HM3)
                X_hm3 = reader.genotypes[:, batch.hm3_indices]

                # The reference block also comes from the filtered reader (so it is HM3 only)
                ref_indices = np.arange(batch.ref_start_idx, batch.ref_end_idx)
                # Clip to the valid range (construct_batches should handle this, but double-check)
                ref_indices = ref_indices[ref_indices < reader.m]
                X_ref = reader.genotypes[:, ref_indices]

                # Compute LD scores (L2 sum)
                ld_scores_batch = compute_ld_scores(X_hm3, X_ref)

                w_ld_values.append(ld_scores_batch)

                batch_snps = reader.bim.iloc[batch.hm3_indices]["SNP"].tolist()
                w_ld_snps.extend(batch_snps)

                progress.update(task, advance=len(batch.hm3_indices))

        # 3. Aggregate and save
        if w_ld_values:
            w_ld_arr = np.concatenate(w_ld_values)

            # Match with metadata
            df_w_ld = reader.bim[reader.bim["SNP"].isin(w_ld_snps)].copy()
            # Ensure order
            df_w_ld = df_w_ld.set_index("SNP").reindex(w_ld_snps).reset_index()
            df_w_ld["L2"] = w_ld_arr

            # Select output columns; CM may be missing or zero
            out_cols = ["CHR", "SNP", "BP", "CM", "L2"]
            if "CM" not in df_w_ld.columns:
                df_w_ld["CM"] = 0.0

            df_w_ld = df_w_ld[out_cols]

            # Determine the output path
            w_ld_base = Path(self.config.w_ld_dir) if self.config.w_ld_dir else Path(self.config.output_dir) / "w_ld"
            w_ld_base.mkdir(parents=True, exist_ok=True)

            out_file = w_ld_base / f"weights.{chromosome}.l2.ldscore.gz"
            df_w_ld.to_csv(out_file, sep="\t", index=False, compression="gzip")
            logger.info(f"  Saved w_ld to: {out_file}")

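The resulting weights.{chr}.l2.ldscore.gz is a gzip-compressed, tab-separated table in the usual LDSC ldscore layout (values illustrative):

CHR    SNP          BP          CM     L2
22     rs9617528    17062588    0.0    4.127
22     rs4911642    17065079    0.0    5.883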
    def _save_aggregated_results(
        self,
        results: dict[str, ChromosomeResult],
        output_dir: str,
        output_filename: str,
    ):
        """Helper to concatenate and save results."""
        logger.info("\n" + "=" * 80)
        logger.info("Processing Complete. Concatenating and Saving...")
        logger.info("=" * 80)

        all_snp_names = []
        all_snp_chr = []  # numeric CHR from BIM
        all_snp_bp = []  # numeric BP from BIM
        matrices = []

        # Check feature consistency against the first result
        first_res = next(iter(results.values()))
        feature_names = first_res.feature_names

        # Sort numeric chromosomes numerically; non-numeric keys (e.g. 'X') after them
        sorted_chroms = sorted(results.keys(), key=lambda x: (0, int(x)) if x.isdigit() else (1, x))

        for chrom in sorted_chroms:
            res = results[chrom]
            if res.feature_names != feature_names:
                logger.error(f"Feature mismatch in chromosome {chrom}. Skipping.")
                continue

            all_snp_names.extend(res.hm3_snp_names)
            all_snp_chr.extend(res.hm3_snp_chr)
            all_snp_bp.extend(res.hm3_snp_bp)
            matrices.append(res.weights)

        if not matrices:
            logger.error("No chromosome results with consistent features; nothing to save.")
            return

        # Concatenate
        if scipy.sparse.issparse(matrices[0]):
            X_full = scipy.sparse.vstack(matrices, format='csr')
        else:
            X_full = np.vstack(matrices)

        # AnnData
        obs = pd.DataFrame({
            'CHR': all_snp_chr,
            'BP': all_snp_bp,
        }, index=all_snp_names)
        obs.index.name = 'SNP'

        var = pd.DataFrame(index=feature_names)
        var.index.name = 'Feature'

        logger.info(f"Creating AnnData object: {X_full.shape[0]} SNPs x {X_full.shape[1]} Features")
        adata = ad.AnnData(X=X_full, obs=obs, var=var)

        out_path = Path(output_dir)
        out_path.mkdir(parents=True, exist_ok=True)
        out_file = out_path / f"{output_filename}.h5ad"

        adata.write(out_file)
        logger.info(f"Successfully saved AnnData to {out_file}")

        # Save the combined mapping CSV if available
        all_mapping_dfs = [res.mapping_df for res in results.values() if res.mapping_df is not None]
        if all_mapping_dfs:
            combined_mapping_df = pd.concat(all_mapping_dfs, ignore_index=True)
            mapping_out_file = out_path / f"{output_filename}.csv"
            combined_mapping_df.to_csv(mapping_out_file, index=False)
            logger.info(f"Successfully saved combined SNP-Feature mapping to {mapping_out_file}")