gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1018 @@
1
+ """
2
+ Module for generating LD scores for each spot in spatial transcriptomics data.
3
+
4
+ This module is responsible for assigning gene specificity scores to SNPs
5
+ and computing stratified LD scores that will be used for spatial LDSC analysis.
6
+ """
7
+
8
+ import gc
9
+ import logging
10
+ import warnings
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ import pyranges as pr
16
+ from scipy.sparse import csr_matrix
17
+ from tqdm import trange
18
+
19
+ from gsMap.config import GenerateLDScoreConfig
20
+ from gsMap.utils.generate_r2_matrix import PlinkBEDFile
21
+
22
+ # Silence FutureWarnings raised from pandas (other warnings stay visible)
23
+ warnings.filterwarnings("ignore", category=FutureWarning, module="pandas")
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def load_gtf(
28
+ gtf_file: str, mk_score: pd.DataFrame, window_size: int
29
+ ) -> tuple[pr.PyRanges, pd.DataFrame]:
30
+ """
31
+ Load and process the gene annotation file (GTF).
32
+
33
+ Parameters
34
+ ----------
35
+ gtf_file : str
36
+ Path to the GTF file
37
+ mk_score : pd.DataFrame
38
+ DataFrame containing marker scores
39
+ window_size : int
40
+ Window size around gene bodies in base pairs
41
+
42
+ Returns
43
+ -------
44
+ tuple
45
+ A tuple containing (gtf_pr, mk_score) where:
46
+ - gtf_pr is a PyRanges object with gene coordinates
47
+ - mk_score is the filtered marker score DataFrame
48
+ """
49
+ logger.info("Loading GTF data from %s", gtf_file)
50
+
51
+ # Load GTF file
52
+ gtf = pr.read_gtf(gtf_file, as_df=True)
53
+
54
+ # Filter for gene features
55
+ gtf = gtf[gtf["Feature"] == "gene"]
56
+
57
+ # Find common genes between GTF and marker scores
58
+ # common_gene = np.intersect1d(mk_score.index, gtf.gene_name)
59
+ common_gene = list(set(mk_score.index) & set(gtf.gene_name))
60
+ logger.info(f"Found {len(common_gene)} common genes between GTF and marker scores")
61
+
62
+ # Filter GTF and marker scores to common genes
63
+ gtf = gtf[gtf.gene_name.isin(common_gene)]
64
+ mk_score = mk_score[mk_score.index.isin(common_gene)]
65
+
66
+ # Remove duplicated gene entries
67
+ gtf = gtf.drop_duplicates(subset="gene_name", keep="first")
68
+
69
+ # Process the GTF (extend a window around each gene body)
70
+ gtf_bed = gtf[["Chromosome", "Start", "End", "gene_name", "Strand"]].copy()
71
+ gtf_bed["Chromosome"] = gtf_bed["Chromosome"].apply(
72
+ lambda x: f"chr{x}" if not str(x).startswith("chr") else x
73
+ )
74
+ gtf_bed.loc[:, "TSS"] = gtf_bed["Start"]
75
+ gtf_bed.loc[:, "TED"] = gtf_bed["End"]
76
+
77
+ # Create windows around genes
78
+ gtf_bed.loc[:, "Start"] = gtf_bed["TSS"] - window_size
79
+ gtf_bed.loc[:, "End"] = gtf_bed["TED"] + window_size
80
+ gtf_bed.loc[gtf_bed["Start"] < 0, "Start"] = 0
81
+
82
+ # Handle genes on negative strand (swap TSS and TED)
83
+ tss_neg = gtf_bed.loc[gtf_bed["Strand"] == "-", "TSS"]
84
+ ted_neg = gtf_bed.loc[gtf_bed["Strand"] == "-", "TED"]
85
+ gtf_bed.loc[gtf_bed["Strand"] == "-", "TSS"] = ted_neg
86
+ gtf_bed.loc[gtf_bed["Strand"] == "-", "TED"] = tss_neg
87
+ gtf_bed = gtf_bed.drop("Strand", axis=1)
88
+
89
+ # Convert to PyRanges
90
+ gtf_pr = pr.PyRanges(gtf_bed)
91
+
92
+ return gtf_pr, mk_score
93
+
94
+
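# --- Illustrative sketch (annotation, not part of this diff) ---
# How load_gtf extends a window around each gene body and swaps TSS/TED for
# negative-strand genes, replayed on a tiny pandas frame. Gene names, positions
# and the window size below are invented for illustration.
import pandas as pd

toy = pd.DataFrame({
    "Chromosome": ["1", "chr2"],
    "Start": [1_000, 5_000],
    "End": [2_000, 6_000],
    "gene_name": ["GENE_A", "GENE_B"],
    "Strand": ["+", "-"],
})
window_size = 500

toy["Chromosome"] = toy["Chromosome"].apply(
    lambda x: f"chr{x}" if not str(x).startswith("chr") else x
)
toy["TSS"] = toy["Start"]
toy["TED"] = toy["End"]
toy["Start"] = (toy["TSS"] - window_size).clip(lower=0)   # never below 0
toy["End"] = toy["TED"] + window_size
neg = toy["Strand"] == "-"
toy.loc[neg, ["TSS", "TED"]] = toy.loc[neg, ["TED", "TSS"]].values  # swap on "-" strand
print(toy.drop(columns="Strand"))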
95
+ def load_marker_score(mk_score_file: str) -> pd.DataFrame:
96
+ """
97
+ Load marker scores from a feather file.
98
+
99
+ Parameters
100
+ ----------
101
+ mk_score_file : str
102
+ Path to the marker score feather file
103
+
104
+ Returns
105
+ -------
106
+ pd.DataFrame
107
+ DataFrame with marker scores indexed by gene names
108
+ """
109
+ mk_score = pd.read_feather(mk_score_file).set_index("HUMAN_GENE_SYM").rename_axis("gene_name")
110
+ mk_score = mk_score.astype(np.float32, copy=False)
111
+ return mk_score
112
+
113
+
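# --- Illustrative sketch (annotation, not part of this diff) ---
# Layout that load_marker_score expects: a HUMAN_GENE_SYM column plus one
# float column per spot. Spot names and values are invented; writing/reading
# feather assumes pyarrow is installed.
import numpy as np
import pandas as pd

mk = pd.DataFrame({
    "HUMAN_GENE_SYM": ["GENE_A", "GENE_B"],
    "spot_1": [1.2, 0.3],
    "spot_2": [0.0, 2.1],
})
mk.to_feather("toy_mkscore.feather")

loaded = (
    pd.read_feather("toy_mkscore.feather")
    .set_index("HUMAN_GENE_SYM")
    .rename_axis("gene_name")
    .astype(np.float32, copy=False)
)
print(loaded)  # genes as index, one float32 column per spot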
114
+ def overlaps_gtf_bim(gtf_pr: pr.PyRanges, bim_pr: pr.PyRanges) -> pd.DataFrame:
115
+ """
116
+ Find overlaps between GTF and BIM data, and select nearest gene for each SNP.
117
+
118
+ Parameters
119
+ ----------
120
+ gtf_pr : pr.PyRanges
121
+ PyRanges object with gene coordinates
122
+ bim_pr : pr.PyRanges
123
+ PyRanges object with SNP coordinates
124
+
125
+ Returns
126
+ -------
127
+ pd.DataFrame
128
+ DataFrame with SNP-gene pairs where each SNP is matched to its closest gene
129
+ """
130
+ # Join the PyRanges objects to find overlaps
131
+ overlaps = gtf_pr.join(bim_pr)
132
+ overlaps = overlaps.df
133
+
134
+ # Calculate distance to TSS
135
+ overlaps["Distance"] = np.abs(overlaps["Start_b"] - overlaps["TSS"])
136
+
137
+ # For each SNP, select the closest gene
138
+ nearest_genes = overlaps.loc[overlaps.groupby("SNP").Distance.idxmin()]
139
+
140
+ return nearest_genes
141
+
142
+
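# --- Illustrative sketch (annotation, not part of this diff) ---
# The nearest-gene rule in overlaps_gtf_bim: a SNP overlapping several gene
# windows keeps the row whose TSS is closest to the SNP position (Start_b is
# the SNP coordinate contributed by the join). Values are invented.
import numpy as np
import pandas as pd

overlaps = pd.DataFrame({
    "SNP": ["rs1", "rs1", "rs2"],
    "gene_name": ["GENE_A", "GENE_B", "GENE_A"],
    "TSS": [1_000, 8_000, 1_000],
    "Start_b": [1_500, 1_500, 900],   # SNP base-pair position
})
overlaps["Distance"] = np.abs(overlaps["Start_b"] - overlaps["TSS"])
nearest = overlaps.loc[overlaps.groupby("SNP").Distance.idxmin()]
print(nearest)  # rs1 keeps GENE_A (distance 500); rs2 keeps its single match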
143
+ class LDScoreCalculator:
144
+ """
145
+ Class for calculating LD scores from gene specificity scores.
146
+ """
147
+
148
+ def __init__(self, config: GenerateLDScoreConfig):
149
+ """Initialize LDScoreCalculator."""
150
+ self.config = config
151
+ self.validate_config()
152
+
153
+ # Load marker scores
154
+ self.mk_score = load_marker_score(config.mkscore_feather_path)
155
+
156
+ # Load GTF and get common markers
157
+ self.gtf_pr, self.mk_score_common = load_gtf(
158
+ config.gtf_annotation_file, self.mk_score, window_size=config.gene_window_size
159
+ )
160
+
161
+ # Initialize enhancer data if provided
162
+ self.enhancer_pr = self._initialize_enhancer() if config.enhancer_annotation_file else None
163
+
164
+ def validate_config(self):
165
+ """Validate configuration parameters."""
166
+ if not Path(self.config.mkscore_feather_path).exists():
167
+ raise FileNotFoundError(
168
+ f"Marker score file not found: {self.config.mkscore_feather_path}"
169
+ )
170
+
171
+ if not Path(self.config.gtf_annotation_file).exists():
172
+ raise FileNotFoundError(
173
+ f"GTF annotation file not found: {self.config.gtf_annotation_file}"
174
+ )
175
+
176
+ if (
177
+ self.config.enhancer_annotation_file
178
+ and not Path(self.config.enhancer_annotation_file).exists()
179
+ ):
180
+ raise FileNotFoundError(
181
+ f"Enhancer annotation file not found: {self.config.enhancer_annotation_file}"
182
+ )
183
+
184
+ def _initialize_enhancer(self) -> pr.PyRanges:
185
+ """
186
+ Initialize enhancer data.
187
+
188
+ Returns
189
+ -------
190
+ pr.PyRanges
191
+ PyRanges object with enhancer data
192
+ """
193
+ # Load enhancer data
194
+ enhancer_df = pr.read_bed(self.config.enhancer_annotation_file, as_df=True)
195
+ enhancer_df.set_index("Name", inplace=True)
196
+ enhancer_df.index.name = "gene_name"
197
+
198
+ # Keep common genes and add marker score information
199
+ avg_mkscore = pd.DataFrame(self.mk_score_common.mean(axis=1), columns=["avg_mkscore"])
200
+ enhancer_df = enhancer_df.join(
201
+ avg_mkscore,
202
+ how="inner",
203
+ on="gene_name",
204
+ )
205
+
206
+ # Add TSS information
207
+ enhancer_df["TSS"] = self.gtf_pr.df.set_index("gene_name").reindex(enhancer_df.index)[
208
+ "TSS"
209
+ ]
210
+
211
+ # Convert to PyRanges
212
+ return pr.PyRanges(enhancer_df.reset_index())
213
+
214
+ def process_chromosome(self, chrom: int):
215
+ """
216
+ Process a single chromosome to calculate LD scores.
217
+
218
+ Parameters
219
+ ----------
220
+ chrom : int
221
+ Chromosome number
222
+ """
223
+ logger.info(f"Processing chromosome {chrom}")
224
+
225
+ # Initialize PlinkBEDFile once for this chromosome
226
+ plink_bed = PlinkBEDFile(f"{self.config.bfile_root}.{chrom}")
227
+
228
+ # Get SNPs passing MAF filter using built-in method
229
+ self.snp_pass_maf = plink_bed.get_snps_by_maf(0.05)
230
+
231
+ # Get SNP-gene dummy pairs
232
+ self.snp_gene_pair_dummy = self._get_snp_gene_dummy(chrom, plink_bed)
233
+
234
+ # Apply SNP filter if provided
235
+ self._apply_snp_filter(chrom)
236
+
237
+ # Process additional baseline annotations if provided
238
+ if self.config.additional_baseline_annotation:
239
+ self._process_additional_baseline(chrom, plink_bed)
240
+ else:
241
+ # Calculate SNP-gene weight matrix using built-in methods
242
+ ld_scores = plink_bed.get_ldscore(
243
+ annot_matrix=self.snp_gene_pair_dummy.values,
244
+ ld_wind=self.config.ld_wind,
245
+ ld_unit=self.config.ld_unit,
246
+ )
247
+
248
+ self.snp_gene_weight_matrix = pd.DataFrame(
249
+ ld_scores,
250
+ index=self.snp_gene_pair_dummy.index,
251
+ columns=self.snp_gene_pair_dummy.columns,
252
+ )
253
+
254
+ # Apply SNP filter if needed
255
+ if self.keep_snp_mask is not None:
256
+ self.snp_gene_weight_matrix = self.snp_gene_weight_matrix[self.keep_snp_mask]
257
+
258
+ # Generate w_ld file if keep_snp_root is provided
259
+ if self.config.keep_snp_root:
260
+ self._generate_w_ld(chrom, plink_bed)
261
+
262
+ # Save pre-calculated SNP-gene weight matrix if requested
263
+ self._save_snp_gene_weight_matrix_if_needed(chrom)
264
+
265
+ # Convert to sparse matrix for memory efficiency
266
+ self.snp_gene_weight_matrix = csr_matrix(self.snp_gene_weight_matrix)
267
+ logger.info(f"SNP-gene weight matrix shape: {self.snp_gene_weight_matrix.shape}")
268
+
269
+ # Calculate baseline LD scores
270
+ logger.info(f"Calculating baseline LD scores for chr{chrom}")
271
+ self._calculate_baseline_ldscores(chrom, plink_bed)
272
+
273
+ # Calculate LD scores for annotation
274
+ logger.info(f"Calculating annotation LD scores for chr{chrom}")
275
+ self._calculate_annotation_ldscores(chrom, plink_bed)
276
+
277
+ # Clear memory
278
+ self._clear_memory()
279
+
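# --- Illustrative sketch (annotation, not part of this diff) ---
# Why the SNP-gene weight matrix is converted to scipy CSR above: it is mostly
# zeros, so the sparse form stores only the non-zero LD weights. Shapes and the
# non-zero fraction here are arbitrary.
import numpy as np
from scipy.sparse import csr_matrix

rng = np.random.default_rng(0)
dense = np.zeros((10_000, 2_000), dtype=np.float32)
rows = rng.integers(0, 10_000, size=50_000)
cols = rng.integers(0, 2_000, size=50_000)
dense[rows, cols] = rng.random(50_000, dtype=np.float32)

sparse = csr_matrix(dense)
sparse_bytes = sparse.data.nbytes + sparse.indices.nbytes + sparse.indptr.nbytes
print(f"dense:  {dense.nbytes / 1e6:.1f} MB")   # ~80 MB
print(f"sparse: {sparse_bytes / 1e6:.1f} MB")   # a small fraction of that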
280
+ def _generate_w_ld(self, chrom: int, plink_bed):
281
+ """
282
+ Generate w_ld file for the chromosome using filtered SNPs.
283
+
284
+ Parameters
285
+ ----------
286
+ chrom : int
287
+ Chromosome number
288
+ plink_bed : PlinkBEDFile
289
+ Initialized PlinkBEDFile object
290
+ """
291
+ if not self.config.keep_snp_root:
292
+ logger.info(
293
+ f"Skipping w_ld generation for chr{chrom} as keep_snp_root is not provided"
294
+ )
295
+ return
296
+
297
+ logger.info(f"Generating w_ld for chr{chrom}")
298
+
299
+ # Get the indices of SNPs to keep based on the keep_snp
300
+ keep_snps_indices = plink_bed.bim_df[
301
+ plink_bed.bim_df.SNP.isin(self.snp_name)
302
+ ].index.tolist()
303
+
304
+ # Create a simple unit annotation (all ones) for the filtered SNPs
305
+ unit_annotation = np.ones((len(keep_snps_indices), 1), dtype="float32")
306
+
307
+ # Calculate LD scores
308
+ w_ld_scores = plink_bed.get_ldscore(
309
+ annot_matrix=unit_annotation,
310
+ ld_wind=self.config.ld_wind,
311
+ ld_unit=self.config.ld_unit,
312
+ keep_snps_index=keep_snps_indices,
313
+ )
314
+
315
+ # Create the w_ld DataFrame
316
+ bim_subset = plink_bed.bim_df.loc[keep_snps_indices]
317
+ w_ld_df = pd.DataFrame(
318
+ {
319
+ "SNP": bim_subset.SNP,
320
+ "L2": w_ld_scores.flatten(),
321
+ "CHR": bim_subset.CHR,
322
+ "BP": bim_subset.BP,
323
+ "CM": bim_subset.CM,
324
+ }
325
+ )
326
+
327
+ # Reorder columns
328
+ w_ld_df = w_ld_df[["CHR", "SNP", "BP", "CM", "L2"]]
329
+
330
+ # Save to file
331
+ w_ld_dir = Path(self.config.ldscore_save_dir) / "w_ld"
332
+ w_ld_dir.mkdir(parents=True, exist_ok=True)
333
+ w_ld_file = w_ld_dir / f"weights.{chrom}.l2.ldscore.gz"
334
+ w_ld_df.to_csv(w_ld_file, sep="\t", index=False, compression="gzip")
335
+
336
+ logger.info(f"Saved w_ld for chr{chrom} to {w_ld_file}")
337
+
338
+ def _apply_snp_filter(self, chrom: int):
339
+ """
340
+ Apply SNP filter based on keep_snp_root.
341
+
342
+ Parameters
343
+ ----------
344
+ chrom : int
345
+ Chromosome number
346
+ """
347
+ if self.config.keep_snp_root is not None:
348
+ keep_snp_file = f"{self.config.keep_snp_root}.{chrom}.snp"
349
+ keep_snp = pd.read_csv(keep_snp_file, header=None)[0].to_list()
350
+ self.keep_snp_mask = self.snp_gene_pair_dummy.index.isin(keep_snp)
351
+ self.snp_name = self.snp_gene_pair_dummy.index[self.keep_snp_mask].to_list()
352
+ logger.info(f"Kept {len(self.snp_name)} SNPs after filtering with {keep_snp_file}")
353
+ logger.info("These filtered SNPs will be used to calculate w_ld")
354
+ else:
355
+ self.keep_snp_mask = None
356
+ self.snp_name = self.snp_gene_pair_dummy.index.to_list()
357
+ logger.info(f"Using all {len(self.snp_name)} SNPs (no filter applied)")
358
+ logger.warning("No keep_snp_root provided, all SNPs will be used to calculate w_ld.")
359
+
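# --- Illustrative sketch (annotation, not part of this diff) ---
# The keep-SNP filter above reads a header-less file of rsIDs named
# "{keep_snp_root}.{chrom}.snp" and turns it into a boolean mask over the SNP
# index. The file name and rsIDs below are invented.
import pandas as pd

with open("hm3.22.snp", "w") as fh:          # one rsID per line, no header
    fh.write("rs1\nrs3\n")

snp_index = pd.Index(["rs1", "rs2", "rs3", "rs4"], name="SNP")
keep_snp = pd.read_csv("hm3.22.snp", header=None)[0].to_list()
keep_mask = snp_index.isin(keep_snp)
print(snp_index[keep_mask].to_list())        # ['rs1', 'rs3']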
360
+ def _process_additional_baseline(self, chrom: int, plink_bed):
361
+ """
362
+ Process additional baseline annotations.
363
+
364
+ Parameters
365
+ ----------
366
+ chrom : int
367
+ Chromosome number
368
+ plink_bed : PlinkBEDFile
369
+ Initialized PlinkBEDFile object
370
+ """
371
+ # Load additional baseline annotations
372
+ additional_baseline_path = Path(self.config.additional_baseline_annotation)
373
+ annot_file_path = additional_baseline_path / f"baseline.{chrom}.annot.gz"
374
+
375
+ # Verify file existence
376
+ if not annot_file_path.exists():
377
+ raise FileNotFoundError(
378
+ f"Additional baseline annotation file not found: {annot_file_path}"
379
+ )
380
+
381
+ # Load annotations
382
+ additional_baseline_df = pd.read_csv(annot_file_path, sep="\t")
383
+ additional_baseline_df.set_index("SNP", inplace=True)
384
+
385
+ # Drop unnecessary columns
386
+ for col in ["CHR", "BP", "CM"]:
387
+ if col in additional_baseline_df.columns:
388
+ additional_baseline_df.drop(col, axis=1, inplace=True)
389
+
390
+ # Check for SNPs not in the additional baseline
391
+ missing_snps = ~self.snp_gene_pair_dummy.index.isin(additional_baseline_df.index)
392
+ missing_count = missing_snps.sum()
393
+
394
+ if missing_count > 0:
395
+ logger.warning(
396
+ f"{missing_count} SNPs not found in additional baseline annotations. "
397
+ "Setting their values to 0."
398
+ )
399
+ additional_baseline_df = additional_baseline_df.reindex(
400
+ self.snp_gene_pair_dummy.index, fill_value=0
401
+ )
402
+
403
+ # Combine annotations into a single matrix
404
+ combined_annotations = pd.concat(
405
+ [self.snp_gene_pair_dummy, additional_baseline_df], axis=1
406
+ )
407
+
408
+ # Calculate LD scores
409
+ ld_scores = plink_bed.get_ldscore(
410
+ annot_matrix=combined_annotations.values.astype(np.float32, copy=False),
411
+ ld_wind=self.config.ld_wind,
412
+ ld_unit=self.config.ld_unit,
413
+ )
414
+
415
+ # Split results
416
+ # total_cols = combined_annotations.shape[1]
417
+ gene_cols = self.snp_gene_pair_dummy.shape[1]
418
+ # baseline_cols = additional_baseline_df.shape[1]
419
+
420
+ # Create DataFrames with proper indices and columns
421
+ self.snp_gene_weight_matrix = pd.DataFrame(
422
+ ld_scores[:, :gene_cols],
423
+ index=combined_annotations.index,
424
+ columns=self.snp_gene_pair_dummy.columns,
425
+ )
426
+
427
+ additional_ldscore = pd.DataFrame(
428
+ ld_scores[:, gene_cols:],
429
+ index=combined_annotations.index,
430
+ columns=additional_baseline_df.columns,
431
+ )
432
+
433
+ # Filter by keep_snp_mask if specified
434
+ if self.keep_snp_mask is not None:
435
+ additional_ldscore = additional_ldscore[self.keep_snp_mask]
436
+ self.snp_gene_weight_matrix = self.snp_gene_weight_matrix[self.keep_snp_mask]
437
+
438
+ # Save additional baseline LD scores
439
+ ld_score_file = f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.ldscore.feather"
440
+ m_file_path = f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.M"
441
+ m_5_file_path = (
442
+ f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.M_5_50"
443
+ )
444
+ Path(m_file_path).parent.mkdir(parents=True, exist_ok=True)
445
+
446
+ # Save LD scores
447
+ self._save_ldscore_to_feather(
448
+ additional_ldscore.values,
449
+ column_names=additional_ldscore.columns,
450
+ save_file_name=ld_score_file,
451
+ )
452
+
453
+ # Calculate and save M values
454
+ m_chr_chunk = additional_baseline_df.values.sum(axis=0, keepdims=True)
455
+ m_5_chr_chunk = additional_baseline_df.loc[self.snp_pass_maf].values.sum(
456
+ axis=0, keepdims=True
457
+ )
458
+
459
+ # Save M statistics
460
+ np.savetxt(m_file_path, m_chr_chunk, delimiter="\t")
461
+ np.savetxt(m_5_file_path, m_5_chr_chunk, delimiter="\t")
462
+
463
+ def _save_snp_gene_weight_matrix_if_needed(self, chrom: int):
464
+ """
465
+ Save pre-calculated SNP-gene weight matrix if requested.
466
+
467
+ Parameters
468
+ ----------
469
+ chrom : int
470
+ Chromosome number
471
+ """
472
+ if self.config.save_pre_calculate_snp_gene_weight_matrix:
473
+ save_dir = Path(self.config.ldscore_save_dir) / "snp_gene_weight_matrix"
474
+ save_dir.mkdir(parents=True, exist_ok=True)
475
+
476
+ logger.info(f"Saving SNP-gene weight matrix for chr{chrom}")
477
+
478
+ save_path = save_dir / f"{chrom}.snp_gene_weight_matrix.feather"
479
+ self.snp_gene_weight_matrix.reset_index().to_feather(save_path)
480
+
481
+ def _calculate_baseline_ldscores(self, chrom: int, plink_bed):
482
+ """
483
+ Calculate and save baseline LD scores.
484
+
485
+ Parameters
486
+ ----------
487
+ chrom : int
488
+ Chromosome number
489
+ plink_bed : PlinkBEDFile
490
+ Initialized PlinkBEDFile object
491
+ """
492
+ # Create baseline scores
493
+ baseline_mk_score = np.ones((self.snp_gene_pair_dummy.shape[1], 2))
494
+ baseline_mk_score[-1, 0] = 0 # exclude the dummy NA "gene" from the all_gene column
495
+
496
+ baseline_df = pd.DataFrame(
497
+ baseline_mk_score, index=self.snp_gene_pair_dummy.columns, columns=["all_gene", "base"]
498
+ )
499
+
500
+ # Define file paths
501
+ ld_score_file = (
502
+ f"{self.config.ldscore_save_dir}/baseline/baseline.{chrom}.l2.ldscore.feather"
503
+ )
504
+ m_file = f"{self.config.ldscore_save_dir}/baseline/baseline.{chrom}.l2.M"
505
+ m_5_file = f"{self.config.ldscore_save_dir}/baseline/baseline.{chrom}.l2.M_5_50"
506
+
507
+ # Calculate LD scores
508
+ ldscore_chunk = self._calculate_ldscore_from_weights(
509
+ baseline_df, plink_bed, drop_dummy_na=False
510
+ )
511
+
512
+ # Save LD scores and M values
513
+ self._save_ldscore_to_feather(
514
+ ldscore_chunk,
515
+ column_names=baseline_df.columns,
516
+ save_file_name=ld_score_file,
517
+ )
518
+
519
+ self._calculate_and_save_m_values(
520
+ baseline_df,
521
+ m_file,
522
+ m_5_file,
523
+ drop_dummy_na=False,
524
+ )
525
+
526
+ # If keep_snp_root is not provided, use the first column of baseline ldscore as w_ld
527
+ if not self.config.keep_snp_root:
528
+ self._save_baseline_as_w_ld(chrom, ldscore_chunk, plink_bed)
529
+
530
+ def _save_baseline_as_w_ld(self, chrom: int, ldscore_chunk: np.ndarray, plink_bed):
531
+ """
532
+ Save the first column of baseline ldscore as w_ld.
533
+
534
+ Parameters
535
+ ----------
536
+ chrom : int
537
+ Chromosome number
538
+ ldscore_chunk : np.ndarray
539
+ Array with baseline LD scores
540
+ plink_bed : PlinkBEDFile
541
+ Initialized PlinkBEDFile object
542
+ """
543
+ logger.info(f"Using first column of baseline ldscore as w_ld for chr{chrom}")
544
+
545
+ # Create w_ld directory
546
+ w_ld_dir = Path(self.config.ldscore_save_dir) / "w_ld"
547
+ w_ld_dir.mkdir(parents=True, exist_ok=True)
548
+
549
+ # Define file path
550
+ w_ld_file = w_ld_dir / f"weights.{chrom}.l2.ldscore.gz"
551
+
552
+ # Extract the first column
553
+ w_ld_values = ldscore_chunk[:, 0]
554
+
555
+ # Create a DataFrame with SNP information from the BIM file
556
+ snp_indices = (
557
+ plink_bed.kept_snps
558
+ if hasattr(plink_bed, "kept_snps")
559
+ else np.arange(len(self.snp_name))
560
+ )
561
+ bim_subset = plink_bed.bim_df.iloc[snp_indices]
562
+
563
+ w_ld_df = pd.DataFrame(
564
+ {
565
+ "SNP": self.snp_name,
566
+ "L2": w_ld_values,
567
+ "CHR": bim_subset.CHR.values[: len(self.snp_name)], # Ensure length matches
568
+ "BP": bim_subset.BP.values[: len(self.snp_name)],
569
+ "CM": bim_subset.CM.values[: len(self.snp_name)],
570
+ }
571
+ )
572
+
573
+ # Reorder columns
574
+ w_ld_df = w_ld_df[["CHR", "SNP", "BP", "CM", "L2"]]
575
+
576
+ w_ld_df.to_csv(w_ld_file, sep="\t", index=False, compression="gzip")
577
+
578
+ logger.info(f"Saved w_ld for chr{chrom} to {w_ld_file}")
579
+
580
+ def _calculate_annotation_ldscores(self, chrom: int, plink_bed):
581
+ """
582
+ Calculate and save LD scores for spatial annotations.
583
+
584
+ Parameters
585
+ ----------
586
+ chrom : int
587
+ Chromosome number
588
+ plink_bed : PlinkBEDFile
589
+ Initialized PlinkBEDFile object
590
+ """
591
+ # Get marker scores for gene columns (excluding dummy NA column)
592
+ mk_scores = self.mk_score_common.loc[self.snp_gene_pair_dummy.columns[:-1]]
593
+
594
+ # Process in chunks
595
+ chunk_index = 1
596
+ for i in trange(
597
+ 0,
598
+ mk_scores.shape[1],
599
+ self.config.spots_per_chunk,
600
+ desc=f"Calculating LD scores for chr{chrom}",
601
+ ):
602
+ # Get marker scores for current chunk
603
+ mk_score_chunk = mk_scores.iloc[:, i : i + self.config.spots_per_chunk]
604
+
605
+ # Define file paths
606
+ sample_name = self.config.sample_name
607
+ ld_score_file = f"{self.config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.ldscore.feather"
608
+ m_file = f"{self.config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.M"
609
+ m_5_file = f"{self.config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.M_5_50"
610
+
611
+ # Calculate LD scores
612
+ ldscore_chunk = self._calculate_ldscore_from_weights(mk_score_chunk, plink_bed)
613
+
614
+ # Save LD scores based on format
615
+ if self.config.ldscore_save_format == "feather":
616
+ self._save_ldscore_to_feather(
617
+ ldscore_chunk,
618
+ column_names=mk_score_chunk.columns,
619
+ save_file_name=ld_score_file,
620
+ )
621
+ else:
622
+ raise ValueError(f"Invalid ldscore_save_format: {self.config.ldscore_save_format}")
623
+
624
+ # Calculate and save M values
625
+ self._calculate_and_save_m_values(
626
+ mk_score_chunk,
627
+ m_file,
628
+ m_5_file,
629
+ drop_dummy_na=True,
630
+ )
631
+
632
+ chunk_index += 1
633
+
634
+ # Clear memory
635
+ del ldscore_chunk
636
+ gc.collect()
637
+
638
+ def _calculate_ldscore_from_weights(
639
+ self, marker_scores: pd.DataFrame, plink_bed, drop_dummy_na: bool = True
640
+ ) -> np.ndarray:
641
+ """
642
+ Calculate LD scores using SNP-gene weight matrix.
643
+
644
+ Parameters
645
+ ----------
646
+ marker_scores : pd.DataFrame
647
+ DataFrame with marker scores
648
+ plink_bed : PlinkBEDFile
649
+ Initialized PlinkBEDFile object
650
+ drop_dummy_na : bool, optional
651
+ Whether to drop the dummy NA column, by default True
652
+
653
+ Returns
654
+ -------
655
+ np.ndarray
656
+ Array with calculated LD scores
657
+ """
658
+ weight_matrix = self.snp_gene_weight_matrix
659
+
660
+ if drop_dummy_na:
661
+ # Use all columns except the last one (dummy NA)
662
+ ldscore = weight_matrix[:, :-1] @ marker_scores
663
+ else:
664
+ ldscore = weight_matrix @ marker_scores
665
+
666
+ return ldscore
667
+
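# --- Illustrative sketch (annotation, not part of this diff) ---
# The method above reduces the spatial LD score to one matrix product:
# ldscore[snp, spot] = sum over genes g of ld_weight[snp, g] * marker_score[g, spot],
# with the trailing dummy-NA column dropped when drop_dummy_na=True. Toy numbers.
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

# SNP x (genes + dummy NA) LD weights; the last column is the NA bucket
weights = csr_matrix(np.array([
    [1.5, 0.0, 2.0],
    [0.0, 3.0, 0.0],
], dtype=np.float32))

# genes x spots marker scores (no NA row)
mk = pd.DataFrame(
    [[1.0, 0.5],
     [2.0, 0.0]],
    index=["GENE_A", "GENE_B"], columns=["spot_1", "spot_2"],
)

ldscore = weights[:, :-1] @ mk.values        # drop the dummy NA column
print(ldscore)                               # shape (2 SNPs, 2 spots)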
668
+ def _save_ldscore_to_feather(
669
+ self, ldscore_data: np.ndarray, column_names: list[str], save_file_name: str
670
+ ):
671
+ """
672
+ Save LD scores to a feather file.
673
+
674
+ Parameters
675
+ ----------
676
+ ldscore_data : np.ndarray
677
+ Array with LD scores
678
+ column_names : list
679
+ List of column names
680
+ save_file_name : str
681
+ Path to save the feather file
682
+ """
683
+ # Create directory if needed
684
+ save_dir = Path(save_file_name).parent
685
+ save_dir.mkdir(parents=True, exist_ok=True)
686
+
687
+ # Convert to float16 for storage efficiency
688
+ ldscore_data = ldscore_data.astype(np.float16, copy=False)
689
+
690
+ # Handle numerical overflow
691
+ ldscore_data[np.isinf(ldscore_data)] = np.finfo(np.float16).max
692
+
693
+ # Create DataFrame and save
694
+ df = pd.DataFrame(
695
+ ldscore_data,
696
+ index=self.snp_name,
697
+ columns=column_names,
698
+ )
699
+ df.index.name = "SNP"
700
+ df.reset_index().to_feather(save_file_name)
701
+
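# --- Illustrative sketch (annotation, not part of this diff) ---
# Why the inf clamp above is needed: LD scores above ~65504 overflow to inf
# when cast to float16 for storage, so they are clipped to the float16 max.
import numpy as np

scores = np.array([12.5, 70_000.0], dtype=np.float32)
scores16 = scores.astype(np.float16)
print(scores16)                                   # second value became inf
scores16[np.isinf(scores16)] = np.finfo(np.float16).max
print(scores16)                                   # clamped to 65504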
702
+ def _calculate_and_save_m_values(
703
+ self,
704
+ marker_scores: pd.DataFrame,
705
+ m_file_path: str,
706
+ m_5_file_path: str,
707
+ drop_dummy_na: bool = True,
708
+ ):
709
+ """
710
+ Calculate and save M statistics.
711
+
712
+ Parameters
713
+ ----------
714
+ marker_scores : pd.DataFrame
715
+ DataFrame with marker scores
716
+ m_file_path : str
717
+ Path to save M values
718
+ m_5_file_path : str
719
+ Path to save M_5_50 values
720
+ drop_dummy_na : bool, optional
721
+ Whether to drop the dummy NA column, by default True
722
+ """
723
+ # Create directory if needed
724
+ save_dir = Path(m_file_path).parent
725
+ save_dir.mkdir(parents=True, exist_ok=True)
726
+
727
+ # Get sum of SNP-gene pairs
728
+ snp_gene_sum = self.snp_gene_pair_dummy.values.sum(axis=0, keepdims=True)
729
+ snp_gene_sum_maf = self.snp_gene_pair_dummy.loc[self.snp_pass_maf].values.sum(
730
+ axis=0, keepdims=True
731
+ )
732
+
733
+ # Drop dummy NA column if requested
734
+ if drop_dummy_na:
735
+ snp_gene_sum = snp_gene_sum[:, :-1]
736
+ snp_gene_sum_maf = snp_gene_sum_maf[:, :-1]
737
+
738
+ # Calculate M values
739
+ m_values = snp_gene_sum @ marker_scores
740
+ m_5_values = snp_gene_sum_maf @ marker_scores
741
+
742
+ # Save M values
743
+ np.savetxt(m_file_path, m_values, delimiter="\t")
744
+ np.savetxt(m_5_file_path, m_5_values, delimiter="\t")
745
+
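# --- Illustrative sketch (annotation, not part of this diff) ---
# The M / M_5_50 statistics above: a (1 x genes) count of SNPs per gene
# annotation (all SNPs vs. SNPs passing the 5% MAF filter) times the
# genes x spots marker scores, giving one value per spot. Toy numbers.
import numpy as np
import pandas as pd

snp_gene = pd.DataFrame(                         # SNP x gene dummy annotation
    [[1, 0], [1, 0], [0, 1]],
    index=["rs1", "rs2", "rs3"], columns=["GENE_A", "GENE_B"],
)
snp_pass_maf = ["rs1", "rs3"]                    # SNPs with MAF > 0.05

mk = pd.DataFrame([[1.0, 0.5], [2.0, 0.0]],
                  index=["GENE_A", "GENE_B"], columns=["spot_1", "spot_2"])

m = snp_gene.values.sum(axis=0, keepdims=True) @ mk.values
m_5_50 = snp_gene.loc[snp_pass_maf].values.sum(axis=0, keepdims=True) @ mk.values
np.savetxt("toy.l2.M", m, delimiter="\t")
np.savetxt("toy.l2.M_5_50", m_5_50, delimiter="\t")
print(m, m_5_50)                                 # one row, one value per spot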
746
+ def _get_snp_gene_dummy(self, chrom: int, plink_bed) -> pd.DataFrame:
747
+ """
748
+ Get dummy matrix for SNP-gene pairs.
749
+
750
+ Parameters
751
+ ----------
752
+ chrom : int
753
+ Chromosome number
754
+ plink_bed : PlinkBEDFile
755
+
756
+ Returns
757
+ -------
758
+ pd.DataFrame
759
+ DataFrame with dummy variables for SNP-gene pairs
760
+ """
761
+ logger.info(f"Creating SNP-gene mappings for chromosome {chrom}")
762
+
763
+ # Load BIM file
764
+ bim = plink_bed.bim_df
765
+ bim_pr = plink_bed.convert_bim_to_pyrange(bim)
766
+
767
+ # Determine mapping strategy
768
+ if self.config.gene_window_enhancer_priority in ["gene_window_first", "enhancer_first"]:
769
+ # Use both gene window and enhancer
770
+ snp_gene_pair = self._combine_gtf_and_enhancer_mappings(bim, bim_pr)
771
+
772
+ elif self.config.gene_window_enhancer_priority is None:
773
+ # Use only gene window
774
+ snp_gene_pair = self._get_snp_gene_pair_from_gtf(bim, bim_pr)
775
+
776
+ elif self.config.gene_window_enhancer_priority == "enhancer_only":
777
+ # Use only enhancer
778
+ snp_gene_pair = self._get_snp_gene_pair_from_enhancer(bim, bim_pr)
779
+
780
+ else:
781
+ raise ValueError(
782
+ f"Invalid gene_window_enhancer_priority: {self.config.gene_window_enhancer_priority}"
783
+ )
784
+
785
+ # Save SNP-gene pair mapping
786
+ self._save_snp_gene_pair_mapping(snp_gene_pair, chrom)
787
+
788
+ # Create dummy variables
789
+ snp_gene_dummy = pd.get_dummies(snp_gene_pair["gene_name"], dummy_na=True)
790
+
791
+ return snp_gene_dummy
792
+
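# --- Illustrative sketch (annotation, not part of this diff) ---
# What _get_snp_gene_dummy returns: one indicator column per gene plus a
# trailing NaN column for SNPs that map to no gene (dummy_na=True). That NaN
# column is the "dummy NA" dropped elsewhere in this module. Toy mapping.
import numpy as np
import pandas as pd

snp_gene_pair = pd.DataFrame(
    {"gene_name": ["GENE_A", np.nan, "GENE_B"]},
    index=pd.Index(["rs1", "rs2", "rs3"], name="SNP"),
)
dummies = pd.get_dummies(snp_gene_pair["gene_name"], dummy_na=True)
print(dummies)   # columns: GENE_A, GENE_B, NaN; rs2 falls only in the NaN column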
793
+ def _combine_gtf_and_enhancer_mappings(
794
+ self, bim: pd.DataFrame, bim_pr: pr.PyRanges
795
+ ) -> pd.DataFrame:
796
+ """
797
+ Combine gene window and enhancer mappings.
798
+
799
+ Parameters
800
+ ----------
801
+ bim : pd.DataFrame
802
+ BIM DataFrame
803
+ bim_pr : pr.PyRanges
804
+ BIM PyRanges object
805
+
806
+ Returns
807
+ -------
808
+ pd.DataFrame
809
+ Combined SNP-gene pair mapping
810
+ """
811
+ # Get mappings from both sources
812
+ gtf_mapping = self._get_snp_gene_pair_from_gtf(bim, bim_pr)
813
+ enhancer_mapping = self._get_snp_gene_pair_from_enhancer(bim, bim_pr)
814
+
815
+ # Find SNPs with missing mappings in each source
816
+ mask_of_nan_gtf = gtf_mapping.gene_name.isna()
817
+ mask_of_nan_enhancer = enhancer_mapping.gene_name.isna()
818
+
819
+ # Combine based on priority
820
+ if self.config.gene_window_enhancer_priority == "gene_window_first":
821
+ # Use gene window mappings first, fill missing with enhancer mappings
822
+ combined_mapping = gtf_mapping.copy()
823
+ combined_mapping.loc[mask_of_nan_gtf, "gene_name"] = enhancer_mapping.loc[
824
+ mask_of_nan_gtf, "gene_name"
825
+ ]
826
+ logger.info(
827
+ f"Filled {mask_of_nan_gtf.sum()} SNPs with no GTF mapping using enhancer mappings"
828
+ )
829
+
830
+ elif self.config.gene_window_enhancer_priority == "enhancer_first":
831
+ # Use enhancer mappings first, fill missing with gene window mappings
832
+ combined_mapping = enhancer_mapping.copy()
833
+ combined_mapping.loc[mask_of_nan_enhancer, "gene_name"] = gtf_mapping.loc[
834
+ mask_of_nan_enhancer, "gene_name"
835
+ ]
836
+ logger.info(
837
+ f"Filled {mask_of_nan_enhancer.sum()} SNPs with no enhancer mapping using GTF mappings"
838
+ )
839
+
840
+ else:
841
+ raise ValueError(
842
+ f"Invalid gene_window_enhancer_priority for combining: {self.config.gene_window_enhancer_priority}"
843
+ )
844
+
845
+ return combined_mapping
846
+
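# --- Illustrative sketch (annotation, not part of this diff) ---
# The "gene_window_first" combine above: keep the GTF mapping and fill only the
# SNPs whose GTF gene_name is missing from the enhancer mapping
# ("enhancer_first" is the mirror image). Toy frames.
import numpy as np
import pandas as pd

idx = pd.Index(["rs1", "rs2", "rs3"], name="SNP")
gtf_mapping = pd.DataFrame({"gene_name": ["GENE_A", np.nan, np.nan]}, index=idx)
enhancer_mapping = pd.DataFrame({"gene_name": ["GENE_B", "GENE_B", np.nan]}, index=idx)

combined = gtf_mapping.copy()
mask_nan_gtf = combined["gene_name"].isna()
combined.loc[mask_nan_gtf, "gene_name"] = enhancer_mapping.loc[mask_nan_gtf, "gene_name"]
print(combined)  # rs1 keeps GENE_A (GTF wins); rs2 filled from enhancer; rs3 stays NaN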
847
+ def _get_snp_gene_pair_from_gtf(self, bim: pd.DataFrame, bim_pr: pr.PyRanges) -> pd.DataFrame:
848
+ """
849
+ Get SNP-gene pairs based on GTF annotations.
850
+
851
+ Parameters
852
+ ----------
853
+ bim : pd.DataFrame
854
+ BIM DataFrame
855
+ bim_pr : pr.PyRanges
856
+ BIM PyRanges object
857
+
858
+ Returns
859
+ -------
860
+ pd.DataFrame
861
+ SNP-gene pairs based on GTF
862
+ """
863
+ logger.info(
864
+ "Getting SNP-gene pairs from GTF. SNPs in multiple genes will be assigned to the nearest gene (by TSS)"
865
+ )
866
+
867
+ # Find overlaps between SNPs and gene windows
868
+ overlaps = overlaps_gtf_bim(self.gtf_pr, bim_pr)
869
+
870
+ # Get SNP information
871
+ annot = bim[["CHR", "BP", "SNP", "CM"]]
872
+
873
+ # Create SNP-gene pairs DataFrame
874
+ snp_gene_pair = (
875
+ overlaps[["SNP", "gene_name"]]
876
+ .set_index("SNP")
877
+ .join(annot.set_index("SNP"), how="right")
878
+ )
879
+
880
+ logger.info(f"Found {overlaps.shape[0]} SNP-gene pairs from GTF")
881
+
882
+ return snp_gene_pair
883
+
884
+ def _get_snp_gene_pair_from_enhancer(
885
+ self, bim: pd.DataFrame, bim_pr: pr.PyRanges
886
+ ) -> pd.DataFrame:
887
+ """
888
+ Get SNP-gene pairs based on enhancer annotations.
889
+
890
+ Parameters
891
+ ----------
892
+ bim : pd.DataFrame
893
+ BIM DataFrame
894
+ bim_pr : pr.PyRanges
895
+ BIM PyRanges object
896
+
897
+ Returns
898
+ -------
899
+ pd.DataFrame
900
+ SNP-gene pairs based on enhancer
901
+ """
902
+ if self.enhancer_pr is None:
903
+ raise ValueError("Enhancer annotation file is required but not provided")
904
+
905
+ # Find overlaps between SNPs and enhancers
906
+ overlaps = self.enhancer_pr.join(bim_pr).df
907
+
908
+ # Get SNP information
909
+ annot = bim[["CHR", "BP", "SNP", "CM"]]
910
+
911
+ if self.config.snp_multiple_enhancer_strategy == "max_mkscore":
912
+ logger.info(
913
+ "SNPs in multiple enhancers will be assigned to the gene with highest marker score"
914
+ )
915
+ overlaps = overlaps.loc[overlaps.groupby("SNP").avg_mkscore.idxmax()]
916
+
917
+ elif self.config.snp_multiple_enhancer_strategy == "nearest_TSS":
918
+ logger.info("SNPs in multiple enhancers will be assigned to the gene with nearest TSS")
919
+ overlaps["Distance"] = np.abs(overlaps["Start_b"] - overlaps["TSS"])
920
+ overlaps = overlaps.loc[overlaps.groupby("SNP").Distance.idxmin()]
921
+
922
+ # Create SNP-gene pairs DataFrame
923
+ snp_gene_pair = (
924
+ overlaps[["SNP", "gene_name"]]
925
+ .set_index("SNP")
926
+ .join(annot.set_index("SNP"), how="right")
927
+ )
928
+
929
+ logger.info(f"Found {overlaps.shape[0]} SNP-gene pairs from enhancers")
930
+
931
+ return snp_gene_pair
932
+
933
+ def _save_snp_gene_pair_mapping(self, snp_gene_pair: pd.DataFrame, chrom: int):
934
+ """
935
+ Save SNP-gene pair mapping to a feather file.
936
+
937
+ Parameters
938
+ ----------
939
+ snp_gene_pair : pd.DataFrame
940
+ SNP-gene pair mapping
941
+ chrom : int
942
+ Chromosome number
943
+ """
944
+ save_path = (
945
+ Path(self.config.ldscore_save_dir) / f"SNP_gene_pair/SNP_gene_pair_chr{chrom}.feather"
946
+ )
947
+ save_path.parent.mkdir(parents=True, exist_ok=True)
948
+ snp_gene_pair.reset_index().to_feather(save_path)
949
+
950
+ def _clear_memory(self):
951
+ """Clear memory to prevent leaks."""
952
+ gc.collect()
953
+
954
+
955
+ def run_generate_ldscore(config: GenerateLDScoreConfig):
956
+ """
957
+ Main function to run the LD score generation.
958
+
959
+ Parameters
960
+ ----------
961
+ config : GenerateLDScoreConfig
962
+ Configuration object
963
+ """
964
+ # Create output directory
965
+ Path(config.ldscore_save_dir).mkdir(parents=True, exist_ok=True)
966
+
967
+ if config.ldscore_save_format == "quick_mode":
968
+ logger.info(
969
+ "Running in quick_mode. Skip the process of generating ldscore. Using the pre-calculated ldscore."
970
+ )
971
+ ldscore_save_dir = Path(config.ldscore_save_dir)
972
+
973
+ # Set up symbolic links
974
+ baseline_dir = ldscore_save_dir / "baseline"
975
+ baseline_dir.parent.mkdir(parents=True, exist_ok=True)
976
+ if not baseline_dir.exists():
977
+ baseline_dir.symlink_to(config.baseline_annotation_dir, target_is_directory=True)
978
+
979
+ snp_gene_pair_dir = ldscore_save_dir / "SNP_gene_pair"
980
+ snp_gene_pair_dir.parent.mkdir(parents=True, exist_ok=True)
981
+ if not snp_gene_pair_dir.exists():
982
+ snp_gene_pair_dir.symlink_to(config.SNP_gene_pair_dir, target_is_directory=True)
983
+
984
+ # Create a done file to mark completion
985
+ done_file = ldscore_save_dir / f"{config.sample_name}_generate_ldscore.done"
986
+ done_file.touch()
987
+
988
+ return
989
+
990
+ # Initialize calculator
991
+ calculator = LDScoreCalculator(config)
992
+
993
+ # Process chromosomes
994
+ if config.chrom == "all":
995
+ # Process all chromosomes
996
+ for chrom in range(1, 23):
997
+ try:
998
+ calculator.process_chromosome(chrom)
999
+ except Exception as e:
1000
+ logger.error(f"Error processing chromosome {chrom}: {e}")
1001
+ raise
1002
+ else:
1003
+ # Process one chromosome
1004
+ try:
1005
+ chrom = int(config.chrom)
1006
+ except ValueError:
1007
+ logger.error(f"Invalid chromosome: {config.chrom}")
1008
+ raise ValueError(
1009
+ f"Invalid chromosome: {config.chrom}. Must be an integer between 1-22 or 'all'"
1010
+ ) from None
1011
+ else:
1012
+ calculator.process_chromosome(chrom)
1013
+
1014
+ # Create a done file to mark completion
1015
+ done_file = Path(config.ldscore_save_dir) / f"{config.sample_name}_generate_ldscore.done"
1016
+ done_file.touch()
1017
+
1018
+ logger.info(f"LD score generation completed for {config.sample_name}")