gsMap 1.72.3-py3-none-any.whl → 1.73.1-py3-none-any.whl

gsMap/generate_ldscore.py CHANGED
@@ -1,3 +1,11 @@
1
+ """
2
+ Module for generating LD scores for each spot in spatial transcriptomics data.
3
+
4
+ This module is responsible for assigning gene specificity scores to SNPs
5
+ and computing stratified LD scores that will be used for spatial LDSC analysis.
6
+ """
7
+
8
+ import gc
1
9
  import logging
2
10
  import warnings
3
11
  from pathlib import Path
@@ -10,113 +18,205 @@ from scipy.sparse import csr_matrix
10
18
  from tqdm import trange
11
19
 
12
20
  from gsMap.config import GenerateLDScoreConfig
13
- from gsMap.utils.generate_r2_matrix import ID_List_Factory, PlinkBEDFileWithR2Cache, getBlockLefts
21
+ from gsMap.utils.generate_r2_matrix import getBlockLefts, load_bfile
14
22
 
15
- warnings.filterwarnings("ignore", category=FutureWarning)
23
+ # Configure warning behavior more precisely
24
+ warnings.filterwarnings("ignore", category=FutureWarning, module="pandas")
16
25
  logger = logging.getLogger(__name__)
17
26
 
18
27
 
19
- # %%
20
- # load gtf
21
- def load_gtf(gtf_file, mk_score, window_size):
28
+ def load_gtf(
29
+ gtf_file: str, mk_score: pd.DataFrame, window_size: int
30
+ ) -> tuple[pr.PyRanges, pd.DataFrame]:
22
31
  """
23
- Load the gene annotation file (gtf).
32
+ Load and process the gene annotation file (GTF).
33
+
34
+ Parameters
35
+ ----------
36
+ gtf_file : str
37
+ Path to the GTF file
38
+ mk_score : pd.DataFrame
39
+ DataFrame containing marker scores
40
+ window_size : int
41
+ Window size around gene bodies in base pairs
42
+
43
+ Returns
44
+ -------
45
+ tuple
46
+ A tuple containing (gtf_pr, mk_score) where:
47
+ - gtf_pr is a PyRanges object with gene coordinates
48
+ - mk_score is the filtered marker score DataFrame
24
49
  """
25
- print("Loading gtf data")
26
- #
50
+ logger.info("Loading GTF data from %s", gtf_file)
51
+
27
52
  # Load GTF file
28
- gtf = pr.read_gtf(
29
- gtf_file,
30
- )
53
+ gtf = pr.read_gtf(gtf_file)
31
54
  gtf = gtf.df
32
- #
33
- # Select the common genes
55
+
56
+ # Filter for gene features
34
57
  gtf = gtf[gtf["Feature"] == "gene"]
58
+
59
+ # Find common genes between GTF and marker scores
35
60
  common_gene = np.intersect1d(mk_score.index, gtf.gene_name)
36
- #
61
+ logger.info(f"Found {len(common_gene)} common genes between GTF and marker scores")
62
+
63
+ # Filter GTF and marker scores to common genes
37
64
  gtf = gtf[gtf.gene_name.isin(common_gene)]
38
65
  mk_score = mk_score[mk_score.index.isin(common_gene)]
39
- #
40
- # Remove duplicated lines
66
+
67
+ # Remove duplicated gene entries
41
68
  gtf = gtf.drop_duplicates(subset="gene_name", keep="first")
42
- #
43
- # Process the GTF (open 100-KB window: Tss - Ted)
69
+
70
+ # Process the GTF (open window around gene coordinates)
44
71
  gtf_bed = gtf[["Chromosome", "Start", "End", "gene_name", "Strand"]].copy()
45
72
  gtf_bed.loc[:, "TSS"] = gtf_bed["Start"]
46
73
  gtf_bed.loc[:, "TED"] = gtf_bed["End"]
47
74
 
75
+ # Create windows around genes
48
76
  gtf_bed.loc[:, "Start"] = gtf_bed["TSS"] - window_size
49
77
  gtf_bed.loc[:, "End"] = gtf_bed["TED"] + window_size
50
78
  gtf_bed.loc[gtf_bed["Start"] < 0, "Start"] = 0
51
- #
52
- # Correct the negative strand
79
+
80
+ # Handle genes on negative strand (swap TSS and TED)
53
81
  tss_neg = gtf_bed.loc[gtf_bed["Strand"] == "-", "TSS"]
54
82
  ted_neg = gtf_bed.loc[gtf_bed["Strand"] == "-", "TED"]
55
83
  gtf_bed.loc[gtf_bed["Strand"] == "-", "TSS"] = ted_neg
56
84
  gtf_bed.loc[gtf_bed["Strand"] == "-", "TED"] = tss_neg
57
85
  gtf_bed = gtf_bed.drop("Strand", axis=1)
58
- #
59
- # Transform the GTF to PyRanges
86
+
87
+ # Convert to PyRanges
60
88
  gtf_pr = pr.PyRanges(gtf_bed)
89
+
61
90
  return gtf_pr, mk_score
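
A minimal usage sketch of the refactored `load_gtf`, assuming gsMap is installed and using placeholder paths for the GTF and marker-score files; it shows the windowed gene ranges and the score matrix filtered to their common genes coming back as a pair.

```python
# Sketch only: both paths are placeholders, not files shipped with gsMap.
from gsMap.generate_ldscore import load_gtf, load_marker_score

mk_score = load_marker_score("sample.mkscore.feather")  # genes x spots
gtf_pr, mk_score = load_gtf(
    gtf_file="gencode.v46.annotation.gtf",
    mk_score=mk_score,
    window_size=50_000,  # bases added on each side of the gene body
)
print(gtf_pr.df[["Chromosome", "Start", "End", "gene_name", "TSS", "TED"]].head())
```
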
62
91
 
63
92
 
64
- # %%
65
- def load_marker_score(mk_score_file):
93
+ def load_marker_score(mk_score_file: str) -> pd.DataFrame:
66
94
  """
67
- Load marker scores of each cell.
95
+ Load marker scores from a feather file.
96
+
97
+ Parameters
98
+ ----------
99
+ mk_score_file : str
100
+ Path to the marker score feather file
101
+
102
+ Returns
103
+ -------
104
+ pd.DataFrame
105
+ DataFrame with marker scores indexed by gene names
68
106
  """
69
107
  mk_score = pd.read_feather(mk_score_file).set_index("HUMAN_GENE_SYM").rename_axis("gene_name")
70
108
  mk_score = mk_score.astype(np.float32, copy=False)
71
109
  return mk_score
72
110
 
73
111
 
74
- # %%
75
- # load mkscore get common gene
76
- # %%
77
- # load bim
78
- def load_bim(bfile_root, chrom):
112
+ def load_bim(bfile_root: str, chrom: int) -> tuple[pd.DataFrame, pr.PyRanges]:
79
113
  """
80
- Load the bim file.
114
+ Load PLINK BIM file and convert to a PyRanges object.
115
+
116
+ Parameters
117
+ ----------
118
+ bfile_root : str
119
+ Root path for PLINK bfiles
120
+ chrom : int
121
+ Chromosome number
122
+
123
+ Returns
124
+ -------
125
+ tuple
126
+ A tuple containing (bim_df, bim_pr) where:
127
+ - bim_df is a pandas DataFrame with BIM data
128
+ - bim_pr is a PyRanges object with BIM data
81
129
  """
82
- bim = pd.read_csv(f"{bfile_root}.{chrom}.bim", sep="\t", header=None)
130
+ bim_file = f"{bfile_root}.{chrom}.bim"
131
+ logger.debug(f"Loading BIM file: {bim_file}")
132
+
133
+ bim = pd.read_csv(bim_file, sep="\t", header=None)
83
134
  bim.columns = ["CHR", "SNP", "CM", "BP", "A1", "A2"]
84
- #
85
- # Transform bim to PyRanges
135
+
136
+ # Convert to PyRanges
86
137
  bim_pr = bim.copy()
87
138
  bim_pr.columns = ["Chromosome", "SNP", "CM", "Start", "A1", "A2"]
88
139
 
140
+ # Adjust coordinates (BIM is 1-based, PyRanges uses 0-based)
89
141
  bim_pr["End"] = bim_pr["Start"].copy()
90
- bim_pr["Start"] = bim_pr["Start"] - 1 # Due to bim file is 1-based
142
+ bim_pr["Start"] = bim_pr["Start"] - 1
91
143
 
92
144
  bim_pr = pr.PyRanges(bim_pr)
93
145
  bim_pr.Chromosome = f"chr{chrom}"
146
+
94
147
  return bim, bim_pr
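
The coordinate adjustment is the only subtle step here: a BIM position is 1-based, while PyRanges intervals are 0-based and half-open, so a SNP at base pair `BP` becomes the interval `[BP - 1, BP)`. A self-contained toy illustration of that conversion (the real function then wraps the result in `pr.PyRanges`):

```python
import pandas as pd

# Toy BIM records (1-based positions).
bim = pd.DataFrame(
    {"CHR": [1, 1], "SNP": ["rs1", "rs2"], "CM": [0.0, 0.1],
     "BP": [12345, 67890], "A1": ["A", "C"], "A2": ["G", "T"]}
)

bim_pr = bim.rename(columns={"CHR": "Chromosome", "BP": "Start"})
bim_pr["End"] = bim_pr["Start"].copy()   # exclusive end stays at BP
bim_pr["Start"] = bim_pr["Start"] - 1    # shift to 0-based start
bim_pr["Chromosome"] = "chr1"
print(bim_pr[["Chromosome", "Start", "End", "SNP"]])
```
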
95
148
 
96
149
 
97
- # %%
98
- def Overlaps_gtf_bim(gtf_pr, bim_pr):
150
+ def overlaps_gtf_bim(gtf_pr: pr.PyRanges, bim_pr: pr.PyRanges) -> pd.DataFrame:
99
151
  """
100
- Find overlaps between gtf and bim file.
152
+ Find overlaps between GTF and BIM data, and select nearest gene for each SNP.
153
+
154
+ Parameters
155
+ ----------
156
+ gtf_pr : pr.PyRanges
157
+ PyRanges object with gene coordinates
158
+ bim_pr : pr.PyRanges
159
+ PyRanges object with SNP coordinates
160
+
161
+ Returns
162
+ -------
163
+ pd.DataFrame
164
+ DataFrame with SNP-gene pairs where each SNP is matched to its closest gene
101
165
  """
102
- # Select the overlapped regions (SNPs in gene windows)
166
+ # Join the PyRanges objects to find overlaps
103
167
  overlaps = gtf_pr.join(bim_pr)
104
168
  overlaps = overlaps.df
169
+
170
+ # Calculate distance to TSS
105
171
  overlaps["Distance"] = np.abs(overlaps["Start_b"] - overlaps["TSS"])
106
- overlaps_small = overlaps.copy()
107
- overlaps_small = overlaps_small.loc[overlaps_small.groupby("SNP").Distance.idxmin()]
108
- return overlaps_small
172
+
173
+ # For each SNP, select the closest gene
174
+ nearest_genes = overlaps.loc[overlaps.groupby("SNP").Distance.idxmin()]
175
+
176
+ return nearest_genes
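
The nearest-gene choice rests on the `groupby(...).idxmin()` pattern: after the range join, a SNP that falls inside several gene windows keeps only the row with the smallest distance to a TSS. A self-contained toy example of that selection:

```python
import numpy as np
import pandas as pd

# Toy join result: rs1 overlaps two gene windows, rs2 overlaps one.
overlaps = pd.DataFrame(
    {"SNP": ["rs1", "rs1", "rs2"],
     "gene_name": ["GENE_A", "GENE_B", "GENE_C"],
     "TSS": [1_000, 9_000, 20_000],
     "Start_b": [1_500, 1_500, 21_000]}  # SNP coordinate from the BIM side of the join
)
overlaps["Distance"] = np.abs(overlaps["Start_b"] - overlaps["TSS"])
nearest_genes = overlaps.loc[overlaps.groupby("SNP").Distance.idxmin()]
print(nearest_genes[["SNP", "gene_name", "Distance"]])  # rs1 -> GENE_A, rs2 -> GENE_C
```
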
109
177
 
110
178
 
111
- # %%
112
- def filter_snps_by_keep_snp(bim_df, keep_snp_file):
113
- # Load the keep_snp file and filter the BIM DataFrame
179
+ def filter_snps_by_keep_snp(bim_df: pd.DataFrame, keep_snp_file: str) -> pd.DataFrame:
180
+ """
181
+ Filter BIM DataFrame to keep only SNPs in a provided list.
182
+
183
+ Parameters
184
+ ----------
185
+ bim_df : pd.DataFrame
186
+ DataFrame with BIM data
187
+ keep_snp_file : str
188
+ Path to a file with SNP IDs to keep
189
+
190
+ Returns
191
+ -------
192
+ pd.DataFrame
193
+ Filtered BIM DataFrame
194
+ """
195
+ # Read SNPs to keep
114
196
  keep_snp = pd.read_csv(keep_snp_file, header=None)[0].to_list()
197
+
198
+ # Filter the BIM DataFrame
115
199
  filtered_bim_df = bim_df[bim_df["SNP"].isin(keep_snp)]
200
+
201
+ logger.info(f"Kept {len(filtered_bim_df)} SNPs out of {len(bim_df)} after filtering")
202
+
116
203
  return filtered_bim_df
117
204
 
118
205
 
119
- def get_snp_counts(config):
206
+ def get_snp_counts(config: GenerateLDScoreConfig) -> dict:
207
+ """
208
+ Count SNPs per chromosome and calculate start positions for zarr arrays.
209
+
210
+ Parameters
211
+ ----------
212
+ config : GenerateLDScoreConfig
213
+ Configuration object
214
+
215
+ Returns
216
+ -------
217
+ dict
218
+ Dictionary with SNP counts and start positions
219
+ """
120
220
  snp_counts = {}
121
221
  total_snp = 0
122
222
 
@@ -134,65 +234,84 @@ def get_snp_counts(config):
134
234
 
135
235
  snp_counts["total"] = total_snp
136
236
 
237
+ # Calculate cumulative SNP counts for zarr array indexing
137
238
  chrom_snp_length_array = np.array([snp_counts[chrom] for chrom in range(1, 23)]).cumsum()
138
-
139
239
  snp_counts["chrom_snp_start_point"] = [0] + chrom_snp_length_array.tolist()
140
240
 
141
241
  return snp_counts
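
The start points are simply a cumulative sum of the per-chromosome SNP counts with a leading zero, so chromosome `c` owns rows `start[c-1]` up to (but not including) `start[c]` of the zarr array. The arithmetic on toy counts:

```python
import numpy as np

snp_counts = {1: 100, 2: 80, 3: 120}  # toy per-chromosome SNP counts
lengths = np.array([snp_counts[c] for c in sorted(snp_counts)])
chrom_snp_start_point = [0] + lengths.cumsum().tolist()
print(chrom_snp_start_point)  # [0, 100, 180, 300]; chr2 occupies rows 100..179
```
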
142
242
 
143
243
 
144
- # %%
145
- def get_snp_pass_maf(bfile_root, chrom, maf_min=0.05):
244
+ def get_snp_pass_maf(bfile_root: str, chrom: int, maf_min: float = 0.05) -> list[str]:
146
245
  """
147
- Get the dummy matrix of SNP-gene pairs.
246
+ Get SNPs that pass the minimum minor allele frequency (MAF) threshold.
247
+
248
+ Parameters
249
+ ----------
250
+ bfile_root : str
251
+ Root path for PLINK bfiles
252
+ chrom : int
253
+ Chromosome number
254
+ maf_min : float, optional
255
+ Minimum MAF threshold, by default 0.05
256
+
257
+ Returns
258
+ -------
259
+ list
260
+ List of SNP IDs that pass the MAF threshold
148
261
  """
149
- # Load the bim file
150
- PlinkBIMFile = ID_List_Factory(
151
- ["CHR", "SNP", "CM", "BP", "A1", "A2"], 1, ".bim", usecols=[0, 1, 2, 3, 4, 5]
262
+ array_snps, array_indivs, geno_array = load_bfile(
263
+ bfile_chr_prefix=f"{bfile_root}.{chrom}", mafMin=maf_min
152
264
  )
153
- PlinkFAMFile = ID_List_Factory(["IID"], 0, ".fam", usecols=[1])
154
265
 
155
- bfile = f"{bfile_root}.{chrom}"
156
- snp_file, snp_obj = bfile + ".bim", PlinkBIMFile
157
- array_snps = snp_obj(snp_file)
158
- # m = len(array_snps.IDList)
159
-
160
- # Load fam
161
- ind_file, ind_obj = bfile + ".fam", PlinkFAMFile
162
- array_indivs = ind_obj(ind_file)
266
+ m = len(array_snps.IDList)
163
267
  n = len(array_indivs.IDList)
164
- array_file, array_obj = bfile + ".bed", PlinkBEDFileWithR2Cache
165
- geno_array = array_obj(
166
- array_file, n, array_snps, keep_snps=None, keep_indivs=None, mafMin=None
268
+ logger.info(
269
+ f"Loading genotype data for {m} SNPs and {n} individuals from {bfile_root}.{chrom}"
167
270
  )
168
- ii = geno_array.maf > maf_min
169
- snp_pass_maf = array_snps.IDList[ii]
170
- print(f"After filtering SNPs with MAF < {maf_min}, {len(snp_pass_maf)} SNPs remain.")
171
- return snp_pass_maf.SNP.to_list()
172
271
 
272
+ # Filter SNPs by MAF
273
+ snp_pass_maf = array_snps.IDList.iloc[geno_array.kept_snps]
274
+ logger.info(f"After filtering SNPs with MAF < {maf_min}, {len(snp_pass_maf)} SNPs remain")
173
275
 
174
- def get_ldscore(bfile_root, chrom, annot_matrix, ld_wind, ld_unit="CM"):
175
- PlinkBIMFile = ID_List_Factory(
176
- ["CHR", "SNP", "CM", "BP", "A1", "A2"], 1, ".bim", usecols=[0, 1, 2, 3, 4, 5]
177
- )
178
- PlinkFAMFile = ID_List_Factory(["IID"], 0, ".fam", usecols=[1])
276
+ return snp_pass_maf.SNP.to_list()
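
A hedged sketch of the new `load_bfile`-based MAF filter, assuming PLINK files named `<prefix>.{bed,bim,fam}` exist at the placeholder prefix; `geno_array.kept_snps` holds the positional indices of SNPs passing the threshold, which are mapped back to rsIDs through `array_snps.IDList`.

```python
# Sketch only: "1000G.EUR.QC.22" is a placeholder PLINK bfile prefix.
from gsMap.utils.generate_r2_matrix import load_bfile

array_snps, array_indivs, geno_array = load_bfile(
    bfile_chr_prefix="1000G.EUR.QC.22", mafMin=0.05
)
snp_pass_maf = array_snps.IDList.iloc[geno_array.kept_snps].SNP.to_list()
print(f"{len(snp_pass_maf)} SNPs pass the 0.05 MAF filter on this chromosome")
```
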
179
277
 
180
- bfile = f"{bfile_root}.{chrom}"
181
- snp_file, snp_obj = bfile + ".bim", PlinkBIMFile
182
- array_snps = snp_obj(snp_file)
183
- m = len(array_snps.IDList)
184
- print(f"Read list of {m} SNPs from {snp_file}")
185
278
 
186
- # Load fam
187
- ind_file, ind_obj = bfile + ".fam", PlinkFAMFile
188
- array_indivs = ind_obj(ind_file)
189
- n = len(array_indivs.IDList)
190
- print(f"Read list of {n} individuals from {ind_file}")
191
- array_file, array_obj = bfile + ".bed", PlinkBEDFileWithR2Cache
192
- geno_array = array_obj(
193
- array_file, n, array_snps, keep_snps=None, keep_indivs=None, mafMin=None
279
+ def get_ldscore(
280
+ bfile_root: str,
281
+ chrom: int,
282
+ annot_matrix: np.ndarray,
283
+ ld_wind: float,
284
+ ld_unit: str = "CM",
285
+ keep_snps_index: list[int] = None,
286
+ ) -> pd.DataFrame:
287
+ """
288
+ Calculate LD scores using PLINK data and an annotation matrix.
289
+
290
+ Parameters
291
+ ----------
292
+ bfile_root : str
293
+ Root path for PLINK bfiles
294
+ chrom : int
295
+ Chromosome number
296
+ annot_matrix : np.ndarray
297
+ Annotation matrix
298
+ ld_wind : float
299
+ LD window size
300
+ ld_unit : str, optional
301
+ Unit for the LD window, by default "CM"
302
+ keep_snps_index : list[int], optional
303
+ Indices of SNPs to keep, by default None
304
+
305
+ Returns
306
+ -------
307
+ pd.DataFrame
308
+ DataFrame with calculated LD scores
309
+ """
310
+ array_snps, array_indivs, geno_array = load_bfile(
311
+ bfile_chr_prefix=f"{bfile_root}.{chrom}", keep_snps=keep_snps_index
194
312
  )
195
- # Load the annotations of the baseline
313
+
314
+ # Configure LD window based on specified unit
196
315
  if ld_unit == "SNP":
197
316
  max_dist = ld_wind
198
317
  coords = np.array(range(geno_array.m))
@@ -202,61 +321,141 @@ def get_ldscore(bfile_root, chrom, annot_matrix, ld_wind, ld_unit="CM"):
202
321
  elif ld_unit == "CM":
203
322
  max_dist = ld_wind
204
323
  coords = np.array(array_snps.df["CM"])[geno_array.kept_snps]
324
+ # Check if the CM is all 0
325
+ if np.all(coords == 0):
326
+ logger.warning(
327
+ "All CM values are 0 in the BIM file. Using 1MB window size for LD score calculation."
328
+ )
329
+ max_dist = 1_000_000
330
+ coords = np.array(array_snps.df["BP"])[geno_array.kept_snps]
205
331
  else:
206
- raise ValueError(f"Invalid ld_wind_unit: {ld_unit}")
332
+ raise ValueError(f"Invalid ld_wind_unit: {ld_unit}. Must be one of: SNP, KB, CM")
333
+
334
+ # Calculate blocks for LD computation
207
335
  block_left = getBlockLefts(coords, max_dist)
208
- # Calculate the LD score
209
- lN_df = pd.DataFrame(geno_array.ldScoreVarBlocks(block_left, 100, annot=annot_matrix))
210
- return lN_df
336
+ assert block_left.sum() > 0, "Invalid window size, please check the ld_wind parameter."
337
+
338
+ # Calculate LD scores
339
+ ld_scores = pd.DataFrame(geno_array.ldScoreVarBlocks(block_left, 100, annot=annot_matrix))
340
+
341
+ return ld_scores
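
The unit only decides which coordinate vector and window size are handed to `getBlockLefts`; the CM branch now falls back to a 1 Mb physical window when the BIM file carries no genetic map (all CM values equal zero). A small sketch of that fallback on toy coordinates (the KB branch, not visible in this hunk, is assumed to convert kilobases to base pairs):

```python
import numpy as np
from gsMap.utils.generate_r2_matrix import getBlockLefts

# Toy sorted coordinates for 5 SNPs: physical positions and an empty genetic map.
coords_bp = np.array([100, 50_000, 400_000, 900_000, 2_000_000])
coords_cm = np.zeros(5)

ld_wind, ld_unit = 1, "CM"
if ld_unit == "CM" and np.all(coords_cm == 0):
    coords, max_dist = coords_bp, 1_000_000  # same fallback as get_ldscore
else:
    coords, max_dist = coords_cm, ld_wind

block_left = getBlockLefts(coords, max_dist)
print(block_left)  # index of the left-most SNP inside each SNP's window
```
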
211
342
 
212
343
 
213
- # %%
214
344
  def calculate_ldscore_from_annotation(
215
- SNP_annotation_df, chrom, bfile_root, ld_wind=1, ld_unit="CM"
216
- ):
345
+ snp_annotation_df: pd.DataFrame,
346
+ chrom: int,
347
+ bfile_root: str,
348
+ ld_wind: float = 1,
349
+ ld_unit: str = "CM",
350
+ ) -> pd.DataFrame:
217
351
  """
218
- Calculate the SNP-gene weight matrix.
352
+ Calculate LD scores from SNP annotation DataFrame.
353
+
354
+ Parameters
355
+ ----------
356
+ snp_annotation_df : pd.DataFrame
357
+ DataFrame with SNP annotations
358
+ chrom : int
359
+ Chromosome number
360
+ bfile_root : str
361
+ Root path for PLINK bfiles
362
+ ld_wind : float, optional
363
+ LD window size, by default 1
364
+ ld_unit : str, optional
365
+ Unit for the LD window, by default "CM"
366
+
367
+ Returns
368
+ -------
369
+ pd.DataFrame
370
+ DataFrame with calculated LD scores
219
371
  """
220
- # Get the dummy matrix
221
- # Get the SNP-gene weight matrix
372
+ # Calculate LD scores
222
373
  snp_gene_weight_matrix = get_ldscore(
223
- bfile_root, chrom, SNP_annotation_df.values, ld_wind=ld_wind, ld_unit=ld_unit
374
+ bfile_root, chrom, snp_annotation_df.values, ld_wind=ld_wind, ld_unit=ld_unit
224
375
  )
376
+
377
+ # Set proper data types and indices
225
378
  snp_gene_weight_matrix = snp_gene_weight_matrix.astype(np.float32, copy=False)
226
- snp_gene_weight_matrix.index = SNP_annotation_df.index
227
- snp_gene_weight_matrix.columns = SNP_annotation_df.columns
379
+ snp_gene_weight_matrix.index = snp_annotation_df.index
380
+ snp_gene_weight_matrix.columns = snp_annotation_df.columns
381
+
228
382
  return snp_gene_weight_matrix
229
383
 
230
384
 
231
385
  def calculate_ldscore_from_multiple_annotation(
232
- SNP_annotation_df_list, chrom, bfile_root, ld_wind=1, ld_unit="CM"
233
- ):
234
- SNP_annotation_df = pd.concat(SNP_annotation_df_list, axis=1).astype(np.float32, copy=False)
386
+ snp_annotation_df_list: list[pd.DataFrame],
387
+ chrom: int,
388
+ bfile_root: str,
389
+ ld_wind: float = 1,
390
+ ld_unit: str = "CM",
391
+ ) -> list[pd.DataFrame]:
392
+ """
393
+ Calculate LD scores from multiple SNP annotation DataFrames.
394
+
395
+ Parameters
396
+ ----------
397
+ snp_annotation_df_list : list
398
+ List of DataFrames with SNP annotations
399
+ chrom : int
400
+ Chromosome number
401
+ bfile_root : str
402
+ Root path for PLINK bfiles
403
+ ld_wind : float, optional
404
+ LD window size, by default 1
405
+ ld_unit : str, optional
406
+ Unit for the LD window, by default "CM"
407
+
408
+ Returns
409
+ -------
410
+ list
411
+ List of DataFrames with calculated LD scores
412
+ """
413
+ # Combine annotations
414
+ combined_annotations = pd.concat(snp_annotation_df_list, axis=1).astype(np.float32, copy=False)
235
415
 
236
- snp_gene_weight_matrix = get_ldscore(
237
- bfile_root, chrom, SNP_annotation_df.values, ld_wind=ld_wind, ld_unit=ld_unit
416
+ # Calculate LD scores
417
+ combined_ld_scores = get_ldscore(
418
+ bfile_root, chrom, combined_annotations.values, ld_wind=ld_wind, ld_unit=ld_unit
238
419
  )
239
- snp_gene_weight_matrix = snp_gene_weight_matrix.astype(np.float32, copy=False)
240
- snp_gene_weight_matrix.index = SNP_annotation_df.index
241
- snp_gene_weight_matrix.columns = SNP_annotation_df.columns
242
-
243
- # split to each annotation
244
- snp_annotation_len_list = [len(df.columns) for df in SNP_annotation_df_list]
245
- snp_gene_weight_matrix_list = []
246
- start = 0
247
- for snp_annotation_len in snp_annotation_len_list:
248
- snp_gene_weight_matrix_list.append(
249
- snp_gene_weight_matrix.iloc[:, start : start + snp_annotation_len]
250
- )
251
- start += snp_annotation_len
252
- return snp_gene_weight_matrix_list
420
+
421
+ # Apply proper indices and columns
422
+ combined_ld_scores.index = combined_annotations.index
423
+ combined_ld_scores.columns = combined_annotations.columns
424
+
425
+ # Split back into separate DataFrames
426
+ annotation_lengths = [len(df.columns) for df in snp_annotation_df_list]
427
+ result_dataframes = []
428
+ start_col = 0
429
+
430
+ for length in annotation_lengths:
431
+ end_col = start_col + length
432
+ result_dataframes.append(combined_ld_scores.iloc[:, start_col:end_col])
433
+ start_col = end_col
434
+
435
+ return result_dataframes
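
Concatenating the annotation frames, scoring them in one pass, and slicing the result apart again means the r² computation is done only once for both the gene annotation and the additional baseline. The column bookkeeping is plain pandas; a self-contained illustration of the split step:

```python
import numpy as np
import pandas as pd

# Two toy annotation frames over the same 3 SNPs.
genes = pd.DataFrame(np.ones((3, 2)), columns=["gene1", "gene2"])
baseline = pd.DataFrame(np.zeros((3, 1)), columns=["base"])
combined = pd.concat([genes, baseline], axis=1)

# ... LD scores would be computed once on `combined` here ...
lengths = [genes.shape[1], baseline.shape[1]]
parts, start = [], 0
for n in lengths:
    parts.append(combined.iloc[:, start:start + n])
    start += n
print([p.columns.tolist() for p in parts])  # [['gene1', 'gene2'], ['base']]
```
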
253
436
 
254
437
 
255
- # %%
256
- class S_LDSC_Boost:
438
+ class LDScoreCalculator:
439
+ """
440
+ Class for calculating LD scores from gene specificity scores.
441
+
442
+ This class handles the assignment of gene specificity scores to SNPs
443
+ and the calculation of LD scores.
444
+ """
445
+
257
446
  def __init__(self, config: GenerateLDScoreConfig):
447
+ """
448
+ Initialize LDScoreCalculator.
449
+
450
+ Parameters
451
+ ----------
452
+ config : GenerateLDScoreConfig
453
+ Configuration object
454
+ """
258
455
  self.config = config
456
+ self.validate_config()
259
457
 
458
+ # Load marker scores
260
459
  self.mk_score = load_marker_score(config.mkscore_feather_path)
261
460
 
262
461
  # Load GTF and get common markers
@@ -264,499 +463,895 @@ class S_LDSC_Boost:
264
463
  config.gtf_annotation_file, self.mk_score, window_size=config.gene_window_size
265
464
  )
266
465
 
267
- # Load enhancer
268
- if config.enhancer_annotation_file is not None:
269
- enhancer_df = pr.read_bed(config.enhancer_annotation_file, as_df=True)
270
- enhancer_df.set_index("Name", inplace=True)
271
- enhancer_df.index.name = "gene_name"
272
-
273
- # keep the common genes and add the enhancer score
274
- avg_mkscore = pd.DataFrame(self.mk_score_common.mean(axis=1), columns=["avg_mkscore"])
275
- enhancer_df = enhancer_df.join(
276
- avg_mkscore,
277
- how="inner",
278
- on="gene_name",
466
+ # Initialize enhancer data if provided
467
+ self.enhancer_pr = self._initialize_enhancer() if config.enhancer_annotation_file else None
468
+
469
+ # Initialize zarr file if needed
470
+ self._initialize_zarr_if_needed()
471
+
472
+ def validate_config(self):
473
+ """Validate configuration parameters."""
474
+ if not Path(self.config.mkscore_feather_path).exists():
475
+ raise FileNotFoundError(
476
+ f"Marker score file not found: {self.config.mkscore_feather_path}"
279
477
  )
280
478
 
281
- # add distance to TSS
282
- enhancer_df["TSS"] = self.gtf_pr.df.set_index("gene_name").reindex(enhancer_df.index)[
283
- "TSS"
284
- ]
479
+ if not Path(self.config.gtf_annotation_file).exists():
480
+ raise FileNotFoundError(
481
+ f"GTF annotation file not found: {self.config.gtf_annotation_file}"
482
+ )
285
483
 
286
- # convert to pyranges
287
- self.enhancer_pr = pr.PyRanges(enhancer_df.reset_index())
484
+ if (
485
+ self.config.enhancer_annotation_file
486
+ and not Path(self.config.enhancer_annotation_file).exists()
487
+ ):
488
+ raise FileNotFoundError(
489
+ f"Enhancer annotation file not found: {self.config.enhancer_annotation_file}"
490
+ )
288
491
 
289
- else:
290
- self.enhancer_pr = None
492
+ def _initialize_enhancer(self) -> pr.PyRanges:
493
+ """
494
+ Initialize enhancer data.
291
495
 
292
- # create tha zarr file
293
- if config.ldscore_save_format == "zarr":
294
- chrom_snp_length_dict = get_snp_counts(config)
496
+ Returns
497
+ -------
498
+ pr.PyRanges
499
+ PyRanges object with enhancer data
500
+ """
501
+ # Load enhancer data
502
+ enhancer_df = pr.read_bed(self.config.enhancer_annotation_file, as_df=True)
503
+ enhancer_df.set_index("Name", inplace=True)
504
+ enhancer_df.index.name = "gene_name"
505
+
506
+ # Keep common genes and add marker score information
507
+ avg_mkscore = pd.DataFrame(self.mk_score_common.mean(axis=1), columns=["avg_mkscore"])
508
+ enhancer_df = enhancer_df.join(
509
+ avg_mkscore,
510
+ how="inner",
511
+ on="gene_name",
512
+ )
513
+
514
+ # Add TSS information
515
+ enhancer_df["TSS"] = self.gtf_pr.df.set_index("gene_name").reindex(enhancer_df.index)[
516
+ "TSS"
517
+ ]
518
+
519
+ # Convert to PyRanges
520
+ return pr.PyRanges(enhancer_df.reset_index())
521
+
522
+ def _initialize_zarr_if_needed(self):
523
+ """Initialize zarr file if zarr format is specified."""
524
+ if self.config.ldscore_save_format == "zarr":
525
+ chrom_snp_length_dict = get_snp_counts(self.config)
295
526
  self.chrom_snp_start_point = chrom_snp_length_dict["chrom_snp_start_point"]
296
527
 
297
- zarr_path = Path(config.ldscore_save_dir) / f"{config.sample_name}.ldscore.zarr"
528
+ zarr_path = (
529
+ Path(self.config.ldscore_save_dir) / f"{self.config.sample_name}.ldscore.zarr"
530
+ )
531
+
298
532
  if not zarr_path.exists():
299
533
  self.zarr_file = zarr.open(
300
534
  zarr_path.as_posix(),
301
535
  mode="a",
302
536
  dtype=np.float16,
303
- chunks=config.zarr_chunk_size,
537
+ chunks=self.config.zarr_chunk_size,
304
538
  shape=(chrom_snp_length_dict["total"], self.mk_score_common.shape[1]),
305
539
  )
306
- zarr_path.mkdir(parents=True, exist_ok=True)
307
- # save spot names
540
+ zarr_path.parent.mkdir(parents=True, exist_ok=True)
541
+
542
+ # Save metadata
308
543
  self.zarr_file.attrs["spot_names"] = self.mk_score_common.columns.to_list()
309
- # save chrom_snp_length_dict
310
544
  self.zarr_file.attrs["chrom_snp_start_point"] = self.chrom_snp_start_point
545
+
311
546
  else:
312
547
  self.zarr_file = zarr.open(zarr_path.as_posix(), mode="a")
313
548
 
314
549
  def process_chromosome(self, chrom: int):
550
+ """
551
+ Process a single chromosome to calculate LD scores.
552
+
553
+ Parameters
554
+ ----------
555
+ chrom : int
556
+ Chromosome number
557
+ """
558
+ logger.info(f"Processing chromosome {chrom}")
559
+
560
+ # Get SNPs passing MAF filter
315
561
  self.snp_pass_maf = get_snp_pass_maf(self.config.bfile_root, chrom, maf_min=0.05)
316
562
 
317
- # Get SNP-Gene dummy pairs
318
- self.snp_gene_pair_dummy = self.get_snp_gene_dummy(
563
+ # Get SNP-gene dummy pairs
564
+ self.snp_gene_pair_dummy = self._get_snp_gene_dummy(chrom)
565
+
566
+ # Apply SNP filter if provided
567
+ self._apply_snp_filter(chrom)
568
+
569
+ # Process additional baseline annotations if provided
570
+ if self.config.additional_baseline_annotation:
571
+ self._process_additional_baseline(chrom)
572
+ else:
573
+ # Calculate SNP-gene weight matrix
574
+ self.snp_gene_weight_matrix = calculate_ldscore_from_annotation(
575
+ self.snp_gene_pair_dummy,
576
+ chrom,
577
+ self.config.bfile_root,
578
+ ld_wind=self.config.ld_wind,
579
+ ld_unit=self.config.ld_unit,
580
+ )
581
+
582
+ # Apply SNP filter if needed
583
+ if self.keep_snp_mask is not None:
584
+ self.snp_gene_weight_matrix = self.snp_gene_weight_matrix[self.keep_snp_mask]
585
+
586
+ # Generate w_ld file if keep_snp_root is provided
587
+ if self.config.keep_snp_root:
588
+ self._generate_w_ld(chrom)
589
+
590
+ # Save pre-calculated SNP-gene weight matrix if requested
591
+ self._save_snp_gene_weight_matrix_if_needed(chrom)
592
+
593
+ # Convert to sparse matrix for memory efficiency
594
+ self.snp_gene_weight_matrix = csr_matrix(self.snp_gene_weight_matrix)
595
+ logger.info(f"SNP-gene weight matrix shape: {self.snp_gene_weight_matrix.shape}")
596
+
597
+ # Calculate baseline LD scores
598
+ logger.info(f"Calculating baseline LD scores for chr{chrom}")
599
+ self._calculate_baseline_ldscores(chrom)
600
+
601
+ # Calculate LD scores for annotation
602
+ logger.info(f"Calculating annotation LD scores for chr{chrom}")
603
+ self._calculate_annotation_ldscores(chrom)
604
+
605
+ # Clear memory
606
+ self._clear_memory()
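
`process_chromosome` is the per-chromosome driver: MAF filtering, SNP-gene dummy construction, optional additional baseline handling, w_ld generation, and the chunked annotation LD scores. A hedged sketch of driving the class directly, using only config field names this module reads; the real entry point is the gsMap CLI, which is not part of this diff.

```python
# Sketch only: paths are placeholders and only a subset of config fields is shown;
# the remaining options are assumed to take their defaults.
from gsMap.config import GenerateLDScoreConfig
from gsMap.generate_ldscore import LDScoreCalculator

config = GenerateLDScoreConfig(
    sample_name="sample",
    mkscore_feather_path="sample.mkscore.feather",
    gtf_annotation_file="gencode.v46.annotation.gtf",
    bfile_root="1000G.EUR.QC",      # expects <bfile_root>.<chrom>.{bed,bim,fam}
    keep_snp_root="hm3_snps/hm",    # optional: expects <keep_snp_root>.<chrom>.snp
    ldscore_save_dir="ldscore_output",
    ldscore_save_format="feather",
)

calculator = LDScoreCalculator(config)
for chrom in range(1, 23):
    calculator.process_chromosome(chrom)
```
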
607
+
608
+ def _generate_w_ld(self, chrom: int):
609
+ """
610
+ Generate w_ld file for the chromosome using filtered SNPs.
611
+
612
+ Parameters
613
+ ----------
614
+ chrom : int
615
+ Chromosome number
616
+ """
617
+ if not self.config.keep_snp_root:
618
+ logger.info(
619
+ f"Skipping w_ld generation for chr{chrom} as keep_snp_root is not provided"
620
+ )
621
+ return
622
+
623
+ logger.info(f"Generating w_ld for chr{chrom}")
624
+
625
+ # Get the indices of SNPs to keep based on the keep_snp_mask
626
+ keep_snps_index = np.nonzero(self.keep_snp_mask)[0]
627
+
628
+ # Create a simple unit annotation (all ones) for the filtered SNPs
629
+ unit_annotation = np.ones((len(keep_snps_index), 1))
630
+
631
+ # Calculate LD scores using the filtered SNPs
632
+ w_ld_scores = get_ldscore(
633
+ self.config.bfile_root,
319
634
  chrom,
635
+ unit_annotation,
636
+ ld_wind=self.config.ld_wind,
637
+ ld_unit=self.config.ld_unit,
638
+ keep_snps_index=keep_snps_index.tolist(),
639
+ )
640
+
641
+ # Load the BIM file to get SNP information
642
+ bim_data = pd.read_csv(
643
+ f"{self.config.bfile_root}.{chrom}.bim",
644
+ sep="\t",
645
+ header=None,
646
+ names=["CHR", "SNP", "CM", "BP", "A1", "A2"],
647
+ )
648
+
649
+ # Get SNP names for the kept indices
650
+ kept_snp_names = bim_data.iloc[keep_snps_index].SNP.tolist()
651
+
652
+ # Create the w_ld DataFrame
653
+ w_ld_df = pd.DataFrame(
654
+ {
655
+ "SNP": kept_snp_names,
656
+ "L2": w_ld_scores.values.flatten(),
657
+ "CHR": bim_data.iloc[keep_snps_index].CHR.values,
658
+ "BP": bim_data.iloc[keep_snps_index].BP.values,
659
+ "CM": bim_data.iloc[keep_snps_index].CM.values,
660
+ }
320
661
  )
321
662
 
663
+ # Reorder columns
664
+ w_ld_df = w_ld_df[["CHR", "SNP", "BP", "CM", "L2"]]
665
+
666
+ # Save to feather format
667
+ w_ld_dir = Path(self.config.ldscore_save_dir) / "w_ld"
668
+ w_ld_dir.mkdir(parents=True, exist_ok=True)
669
+ w_ld_file = w_ld_dir / f"weights.{chrom}.l2.ldscore.gz"
670
+ w_ld_df.to_csv(w_ld_file, sep="\t", index=False, compression="gzip")
671
+
672
+ logger.info(f"Saved w_ld for chr{chrom} to {w_ld_file}")
673
+
674
+ def _apply_snp_filter(self, chrom: int):
675
+ """
676
+ Apply SNP filter based on keep_snp_root.
677
+
678
+ Parameters
679
+ ----------
680
+ chrom : int
681
+ Chromosome number
682
+ """
322
683
  if self.config.keep_snp_root is not None:
323
- keep_snp = pd.read_csv(f"{self.config.keep_snp_root}.{chrom}.snp", header=None)[
324
- 0
325
- ].to_list()
684
+ keep_snp_file = f"{self.config.keep_snp_root}.{chrom}.snp"
685
+ keep_snp = pd.read_csv(keep_snp_file, header=None)[0].to_list()
326
686
  self.keep_snp_mask = self.snp_gene_pair_dummy.index.isin(keep_snp)
327
- # the SNP name of keeped
328
687
  self.snp_name = self.snp_gene_pair_dummy.index[self.keep_snp_mask].to_list()
688
+ logger.info(f"Kept {len(self.snp_name)} SNPs after filtering with {keep_snp_file}")
689
+ logger.info("These filtered SNPs will be used to calculate w_ld")
329
690
  else:
330
691
  self.keep_snp_mask = None
331
692
  self.snp_name = self.snp_gene_pair_dummy.index.to_list()
693
+ logger.info(f"Using all {len(self.snp_name)} SNPs (no filter applied)")
694
+ logger.warning("No keep_snp_root provided, all SNPs will be used to calculate w_ld.")
332
695
 
333
- if self.config.additional_baseline_annotation is not None:
334
- additional_baseline_annotation = Path(self.config.additional_baseline_annotation)
335
- additional_baseline_annotation_file_path = (
336
- additional_baseline_annotation / f"baseline.{chrom}.annot.gz"
337
- )
338
- assert additional_baseline_annotation_file_path.exists(), (
339
- f"additional_baseline_annotation_file_path not exists: {additional_baseline_annotation_file_path}"
340
- )
341
- additional_baseline_annotation_df = pd.read_csv(
342
- additional_baseline_annotation_file_path, sep="\t"
343
- )
344
- additional_baseline_annotation_df.set_index("SNP", inplace=True)
345
-
346
- # drop these columns if exists CHR BP CM]
347
- additional_baseline_annotation_df.drop(
348
- ["CHR", "BP", "CM"], axis=1, inplace=True, errors="ignore"
349
- )
350
-
351
- # reindex, for those SNPs not in additional_baseline_annotation_df, set to 0
352
- num_of_not_exist_snp = (
353
- ~self.snp_gene_pair_dummy.index.isin(additional_baseline_annotation_df.index)
354
- ).sum()
355
- if num_of_not_exist_snp > 0:
356
- logger.warning(
357
- f"{num_of_not_exist_snp} SNPs not in additional_baseline_annotation_df but in the reference panel, so the additional baseline annotation of these SNP will set to 0"
358
- )
359
- additional_baseline_annotation_df = additional_baseline_annotation_df.reindex(
360
- self.snp_gene_pair_dummy.index, fill_value=0
361
- )
362
- else:
363
- additional_baseline_annotation_df = additional_baseline_annotation_df.reindex(
364
- self.snp_gene_pair_dummy.index
365
- )
696
+ def _process_additional_baseline(self, chrom: int):
697
+ """
698
+ Process additional baseline annotations.
366
699
 
367
- # do this for saving the cpu time, only calculate r2 once
368
- self.snp_gene_weight_matrix, additional_baseline_annotation_ldscore = (
369
- calculate_ldscore_from_multiple_annotation(
370
- [self.snp_gene_pair_dummy, additional_baseline_annotation_df],
371
- chrom,
372
- self.config.bfile_root,
373
- ld_wind=self.config.ld_wind,
374
- ld_unit=self.config.ld_unit,
375
- )
700
+ Parameters
701
+ ----------
702
+ chrom : int
703
+ Chromosome number
704
+ """
705
+ # Load additional baseline annotations
706
+ additional_baseline_path = Path(self.config.additional_baseline_annotation)
707
+ annot_file_path = additional_baseline_path / f"baseline.{chrom}.annot.gz"
708
+
709
+ # Verify file existence
710
+ if not annot_file_path.exists():
711
+ raise FileNotFoundError(
712
+ f"Additional baseline annotation file not found: {annot_file_path}"
376
713
  )
377
714
 
378
- additional_baseline_annotation_ldscore = additional_baseline_annotation_ldscore.loc[
379
- self.snp_name
380
- ]
381
- # print(additional_baseline_annotation_ldscore.index.to_list()==self.snp_name)
715
+ # Load annotations
716
+ additional_baseline_df = pd.read_csv(annot_file_path, sep="\t")
717
+ additional_baseline_df.set_index("SNP", inplace=True)
382
718
 
383
- ld_score_file = f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.ldscore.feather"
384
- M_file_path = (
385
- f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.M"
386
- )
387
- M_5_file_path = (
388
- f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.M_5_50"
389
- )
719
+ # Drop unnecessary columns
720
+ for col in ["CHR", "BP", "CM"]:
721
+ if col in additional_baseline_df.columns:
722
+ additional_baseline_df.drop(col, axis=1, inplace=True)
390
723
 
391
- # save additional baseline annotation ldscore
392
- self.save_ldscore_to_feather(
393
- additional_baseline_annotation_ldscore.values,
394
- column_names=additional_baseline_annotation_ldscore.columns,
395
- save_file_name=ld_score_file,
396
- )
724
+ # Check for SNPs not in the additional baseline
725
+ missing_snps = ~self.snp_gene_pair_dummy.index.isin(additional_baseline_df.index)
726
+ missing_count = missing_snps.sum()
397
727
 
398
- # caculate the M and save
399
- save_dir = Path(M_file_path).parent
400
- save_dir.mkdir(parents=True, exist_ok=True)
401
- M_chr_chunk = additional_baseline_annotation_df.values.sum(axis=0, keepdims=True)
402
- M_5_chr_chunk = additional_baseline_annotation_df.loc[self.snp_pass_maf].values.sum(
403
- axis=0, keepdims=True
404
- )
405
- np.savetxt(
406
- M_file_path,
407
- M_chr_chunk,
408
- delimiter="\t",
728
+ if missing_count > 0:
729
+ logger.warning(
730
+ f"{missing_count} SNPs not found in additional baseline annotations. "
731
+ "Setting their values to 0."
409
732
  )
410
- np.savetxt(
411
- M_5_file_path,
412
- M_5_chr_chunk,
413
- delimiter="\t",
733
+ additional_baseline_df = additional_baseline_df.reindex(
734
+ self.snp_gene_pair_dummy.index, fill_value=0
414
735
  )
415
-
416
736
  else:
417
- # Calculate SNP-Gene weight matrix
418
- self.snp_gene_weight_matrix = calculate_ldscore_from_annotation(
419
- self.snp_gene_pair_dummy,
737
+ additional_baseline_df = additional_baseline_df.reindex(self.snp_gene_pair_dummy.index)
738
+
739
+ # Calculate LD scores for both annotation sets together
740
+ self.snp_gene_weight_matrix, additional_ldscore = (
741
+ calculate_ldscore_from_multiple_annotation(
742
+ [self.snp_gene_pair_dummy, additional_baseline_df],
420
743
  chrom,
421
744
  self.config.bfile_root,
422
745
  ld_wind=self.config.ld_wind,
423
746
  ld_unit=self.config.ld_unit,
424
747
  )
425
- # only keep the snp in keep_snp_root
426
- if self.keep_snp_mask is not None:
427
- self.snp_gene_weight_matrix = self.snp_gene_weight_matrix[self.keep_snp_mask]
748
+ )
428
749
 
429
- if self.config.save_pre_calculate_snp_gene_weight_matrix:
430
- snp_gene_weight_matrix_save_dir = (
431
- Path(self.config.ldscore_save_dir) / "snp_gene_weight_matrix"
432
- )
433
- snp_gene_weight_matrix_save_dir.mkdir(parents=True, exist_ok=True)
434
- logger.info(f"Saving snp_gene_weight_matrix for chr{chrom}...")
435
- self.snp_gene_weight_matrix.reset_index().to_feather(
436
- snp_gene_weight_matrix_save_dir / f"{chrom}.snp_gene_weight_matrix.feather"
437
- )
750
+ # Filter additional ldscore
751
+ additional_ldscore = additional_ldscore.loc[self.snp_name]
438
752
 
439
- # convert to sparse
440
- self.snp_gene_weight_matrix = csr_matrix(self.snp_gene_weight_matrix)
441
- logger.info(
442
- f"Compute snp_gene_weight_matrix finished. shape: {self.snp_gene_weight_matrix.shape}"
753
+ # Save additional baseline LD scores
754
+ ld_score_file = f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.ldscore.feather"
755
+ m_file_path = f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.M"
756
+ m_5_file_path = (
757
+ f"{self.config.ldscore_save_dir}/additional_baseline/baseline.{chrom}.l2.M_5_50"
443
758
  )
759
+ Path(m_file_path).parent.mkdir(parents=True, exist_ok=True)
444
760
 
445
- # calculate baseline ld score
446
- logger.info(f"Calculating baseline ld score for chr{chrom}...")
447
- self.calculate_ldscore_for_base_line(
448
- chrom, self.config.sample_name, self.config.ldscore_save_dir
761
+ # Save LD scores
762
+ self._save_ldscore_to_feather(
763
+ additional_ldscore.values,
764
+ column_names=additional_ldscore.columns,
765
+ save_file_name=ld_score_file,
449
766
  )
450
767
 
451
- # calculate ld score for annotation
452
- logger.info(f"Calculating ld score for annotation for chr{chrom}...")
453
- self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chr(
454
- self.mk_score_common.loc[self.snp_gene_pair_dummy.columns[:-1]],
455
- chrom,
456
- self.config.sample_name,
457
- self.config.ldscore_save_dir,
768
+ # Calculate and save M values
769
+ m_chr_chunk = additional_baseline_df.values.sum(axis=0, keepdims=True)
770
+ m_5_chr_chunk = additional_baseline_df.loc[self.snp_pass_maf].values.sum(
771
+ axis=0, keepdims=True
458
772
  )
459
773
 
460
- def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
461
- self,
462
- mk_score_chunk,
463
- drop_dummy_na=True,
464
- ):
465
- if drop_dummy_na:
466
- ldscore_chr_chunk = self.snp_gene_weight_matrix[:, :-1] @ mk_score_chunk
467
- else:
468
- ldscore_chr_chunk = self.snp_gene_weight_matrix @ mk_score_chunk
774
+ # Save M statistics
775
+ np.savetxt(m_file_path, m_chr_chunk, delimiter="\t")
776
+ np.savetxt(m_5_file_path, m_5_chr_chunk, delimiter="\t")
469
777
 
470
- return ldscore_chr_chunk
778
+ def _save_snp_gene_weight_matrix_if_needed(self, chrom: int):
779
+ """
780
+ Save pre-calculated SNP-gene weight matrix if requested.
471
781
 
472
- def save_ldscore_to_feather(self, ldscore_chr_chunk: np.ndarray, column_names, save_file_name):
473
- save_dir = Path(save_file_name).parent
474
- save_dir.mkdir(parents=True, exist_ok=True)
782
+ Parameters
783
+ ----------
784
+ chrom : int
785
+ Chromosome number
786
+ """
787
+ if self.config.save_pre_calculate_snp_gene_weight_matrix:
788
+ save_dir = Path(self.config.ldscore_save_dir) / "snp_gene_weight_matrix"
789
+ save_dir.mkdir(parents=True, exist_ok=True)
475
790
 
476
- ldscore_chr_chunk = ldscore_chr_chunk.astype(np.float16, copy=False)
477
- # avoid overflow of float16, if inf, set to max of float16
478
- ldscore_chr_chunk[np.isinf(ldscore_chr_chunk)] = np.finfo(np.float16).max
479
- # ldscore_chr_chunk = ldscore_chr_chunk if self.config.keep_snp_root is None else ldscore_chr_chunk[
480
- # self.keep_snp_mask]
791
+ logger.info(f"Saving SNP-gene weight matrix for chr{chrom}")
481
792
 
482
- # save for each chunk
483
- df = pd.DataFrame(
484
- ldscore_chr_chunk,
485
- index=self.snp_name,
486
- columns=column_names,
793
+ save_path = save_dir / f"{chrom}.snp_gene_weight_matrix.feather"
794
+ self.snp_gene_weight_matrix.reset_index().to_feather(save_path)
795
+
796
+ def _calculate_baseline_ldscores(self, chrom: int):
797
+ """
798
+ Calculate and save baseline LD scores.
799
+
800
+ Parameters
801
+ ----------
802
+ chrom : int
803
+ Chromosome number
804
+ """
805
+ # Create baseline scores
806
+ baseline_mk_score = np.ones((self.snp_gene_pair_dummy.shape[1], 2))
807
+ baseline_mk_score[-1, 0] = 0 # all_gene column
808
+
809
+ baseline_df = pd.DataFrame(
810
+ baseline_mk_score, index=self.snp_gene_pair_dummy.columns, columns=["all_gene", "base"]
487
811
  )
488
- df.index.name = "SNP"
489
- df.reset_index().to_feather(save_file_name)
490
812
 
491
- def save_ldscore_chunk_to_zarr(
492
- self,
493
- ldscore_chr_chunk: np.ndarray,
494
- chrom: int,
495
- start_col_index,
496
- ):
497
- ldscore_chr_chunk = ldscore_chr_chunk.astype(np.float16, copy=False)
498
- # avoid overflow of float16, if inf, set to max of float16
499
- ldscore_chr_chunk[np.isinf(ldscore_chr_chunk)] = np.finfo(np.float16).max
813
+ # Define file paths
814
+ ld_score_file = (
815
+ f"{self.config.ldscore_save_dir}/baseline/baseline.{chrom}.l2.ldscore.feather"
816
+ )
817
+ m_file = f"{self.config.ldscore_save_dir}/baseline/baseline.{chrom}.l2.M"
818
+ m_5_file = f"{self.config.ldscore_save_dir}/baseline/baseline.{chrom}.l2.M_5_50"
500
819
 
501
- # save for each chunk
502
- chrom_snp_start_point = self.chrom_snp_start_point[chrom - 1]
503
- chrom_snp_end_point = self.chrom_snp_start_point[chrom]
820
+ # Calculate LD scores
821
+ ldscore_chunk = self._calculate_ldscore_from_weights(baseline_df, drop_dummy_na=False)
504
822
 
505
- self.zarr_file[
506
- chrom_snp_start_point:chrom_snp_end_point,
507
- start_col_index : start_col_index + ldscore_chr_chunk.shape[1],
508
- ] = ldscore_chr_chunk
823
+ # Save LD scores and M values
824
+ self._save_ldscore_to_feather(
825
+ ldscore_chunk,
826
+ column_names=baseline_df.columns,
827
+ save_file_name=ld_score_file,
828
+ )
509
829
 
510
- def calculate_M_use_SNP_gene_pair_dummy_by_chunk(
511
- self,
512
- mk_score_chunk,
513
- M_file_path,
514
- M_5_file_path,
515
- drop_dummy_na=True,
516
- ):
830
+ self._calculate_and_save_m_values(
831
+ baseline_df,
832
+ m_file,
833
+ m_5_file,
834
+ drop_dummy_na=False,
835
+ )
836
+
837
+ # If keep_snp_root is not provided, use the first column of baseline ldscore as w_ld
838
+ if not self.config.keep_snp_root:
839
+ self._save_baseline_as_w_ld(chrom, ldscore_chunk)
840
+
841
+ def _save_baseline_as_w_ld(self, chrom: int, ldscore_chunk: np.ndarray):
517
842
  """
518
- Calculate M use SNP_gene_pair_dummy_sumed_along_snp_axis and mk_score_chunk
843
+ Save the first column of baseline ldscore as w_ld.
844
+
845
+ Parameters
846
+ ----------
847
+ chrom : int
848
+ Chromosome number
849
+ ldscore_chunk : np.ndarray
850
+ Array with baseline LD scores
519
851
  """
520
- SNP_gene_pair_dummy_sumed_along_snp_axis = self.snp_gene_pair_dummy.values.sum(
521
- axis=0, keepdims=True
522
- )
523
- SNP_gene_pair_dummy_sumed_along_snp_axis_pass_maf = self.snp_gene_pair_dummy.loc[
524
- self.snp_pass_maf
525
- ].values.sum(axis=0, keepdims=True)
526
- if drop_dummy_na:
527
- SNP_gene_pair_dummy_sumed_along_snp_axis = SNP_gene_pair_dummy_sumed_along_snp_axis[
528
- :, :-1
529
- ]
530
- SNP_gene_pair_dummy_sumed_along_snp_axis_pass_maf = (
531
- SNP_gene_pair_dummy_sumed_along_snp_axis_pass_maf[:, :-1]
532
- )
533
- save_dir = Path(M_file_path).parent
534
- save_dir.mkdir(parents=True, exist_ok=True)
535
- M_chr_chunk = SNP_gene_pair_dummy_sumed_along_snp_axis @ mk_score_chunk
536
- M_5_chr_chunk = SNP_gene_pair_dummy_sumed_along_snp_axis_pass_maf @ mk_score_chunk
537
- np.savetxt(
538
- M_file_path,
539
- M_chr_chunk,
540
- delimiter="\t",
852
+ logger.info(f"Using first column of baseline ldscore as w_ld for chr{chrom}")
853
+
854
+ # Create w_ld directory
855
+ w_ld_dir = Path(self.config.ldscore_save_dir) / "w_ld"
856
+ w_ld_dir.mkdir(parents=True, exist_ok=True)
857
+
858
+ # Define file path
859
+ w_ld_file = w_ld_dir / f"weights.{chrom}.l2.ldscore.gz"
860
+
861
+ # Extract the first column
862
+ w_ld_values = ldscore_chunk[:, 0]
863
+
864
+ # Create a DataFrame
865
+ bim_data = pd.read_csv(
866
+ f"{self.config.bfile_root}.{chrom}.bim",
867
+ sep="\t",
868
+ header=None,
869
+ names=["CHR", "SNP", "CM", "BP", "A1", "A2"],
541
870
  )
542
- np.savetxt(
543
- M_5_file_path,
544
- M_5_chr_chunk,
545
- delimiter="\t",
871
+ w_ld_df = pd.DataFrame(
872
+ {
873
+ "SNP": self.snp_name,
874
+ "L2": w_ld_values,
875
+ }
546
876
  )
547
877
 
548
- def calculate_ldscore_use_SNP_Gene_weight_matrix_by_chr(
549
- self, mk_score_common, chrom, sample_name, save_dir
550
- ):
878
+ # Add CHR, BP, and CM information
879
+ w_ld_df = w_ld_df.merge(bim_data[["SNP", "CHR", "BP", "CM"]], on="SNP", how="left")
880
+
881
+ # Reorder columns
882
+ w_ld_df = w_ld_df[["CHR", "SNP", "BP", "CM", "L2"]]
883
+
884
+ w_ld_df.to_csv(w_ld_file, sep="\t", index=False, compression="gzip")
885
+
886
+ logger.info(f"Saved w_ld for chr{chrom} to {w_ld_file}")
887
+
888
+ def _calculate_annotation_ldscores(self, chrom: int):
551
889
  """
552
- Calculate the LD score using the SNP-gene weight matrix.
553
- :param sample_name:
890
+ Calculate and save LD scores for spatial annotations.
891
+
892
+ Parameters
893
+ ----------
894
+ chrom : int
895
+ Chromosome number
554
896
  """
555
- # Calculate the LD score
897
+ # Get marker scores for gene columns (excluding dummy NA column)
898
+ mk_scores = self.mk_score_common.loc[self.snp_gene_pair_dummy.columns[:-1]]
899
+
900
+ # Process in chunks
556
901
  chunk_index = 1
557
902
  for i in trange(
558
903
  0,
559
- mk_score_common.shape[1],
904
+ mk_scores.shape[1],
560
905
  self.config.spots_per_chunk,
561
- desc=f"Calculating LD score by chunk for chr{chrom}",
906
+ desc=f"Calculating LD scores for chr{chrom}",
562
907
  ):
563
- mk_score_chunk = mk_score_common.iloc[:, i : i + self.config.spots_per_chunk]
908
+ # Get marker scores for current chunk
909
+ mk_score_chunk = mk_scores.iloc[:, i : i + self.config.spots_per_chunk]
564
910
 
565
- ld_score_file = f"{save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.ldscore.feather"
566
- M_file = f"{save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.M"
567
- M_5_file = (
568
- f"{save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.M_5_50"
569
- )
911
+ # Define file paths
912
+ sample_name = self.config.sample_name
913
+ ld_score_file = f"{self.config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.ldscore.feather"
914
+ m_file = f"{self.config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.M"
915
+ m_5_file = f"{self.config.ldscore_save_dir}/{sample_name}_chunk{chunk_index}/{sample_name}.{chrom}.l2.M_5_50"
570
916
 
571
- ldscore_chr_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
572
- mk_score_chunk,
573
- drop_dummy_na=True,
574
- )
917
+ # Calculate LD scores
918
+ ldscore_chunk = self._calculate_ldscore_from_weights(mk_score_chunk)
919
+
920
+ # Save LD scores based on format
575
921
  if self.config.ldscore_save_format == "feather":
576
- self.save_ldscore_to_feather(
577
- ldscore_chr_chunk,
922
+ self._save_ldscore_to_feather(
923
+ ldscore_chunk,
578
924
  column_names=mk_score_chunk.columns,
579
925
  save_file_name=ld_score_file,
580
926
  )
581
927
  elif self.config.ldscore_save_format == "zarr":
582
- self.save_ldscore_chunk_to_zarr(
583
- ldscore_chr_chunk,
928
+ self._save_ldscore_chunk_to_zarr(
929
+ ldscore_chunk,
584
930
  chrom=chrom,
585
931
  start_col_index=i,
586
932
  )
587
933
  else:
588
934
  raise ValueError(f"Invalid ldscore_save_format: {self.config.ldscore_save_format}")
589
935
 
590
- self.calculate_M_use_SNP_gene_pair_dummy_by_chunk(
936
+ # Calculate and save M values
937
+ self._calculate_and_save_m_values(
591
938
  mk_score_chunk,
592
- M_file,
593
- M_5_file,
939
+ m_file,
940
+ m_5_file,
594
941
  drop_dummy_na=True,
595
942
  )
596
943
 
597
944
  chunk_index += 1
598
945
 
599
- def calculate_ldscore_for_base_line(self, chrom, sample_name, save_dir):
600
- # save baseline ld score
601
- baseline_mk_score = np.ones((self.snp_gene_pair_dummy.shape[1], 2))
602
- baseline_mk_score[-1, 0] = 0 # all_gene
603
- baseline_mk_score_df = pd.DataFrame(
604
- baseline_mk_score, index=self.snp_gene_pair_dummy.columns, columns=["all_gene", "base"]
605
- )
606
- ld_score_file = f"{save_dir}/baseline/baseline.{chrom}.l2.ldscore.feather"
607
- M_file = f"{save_dir}/baseline/baseline.{chrom}.l2.M"
608
- M_5_file = f"{save_dir}/baseline/baseline.{chrom}.l2.M_5_50"
946
+ # Clear memory
947
+ del ldscore_chunk
948
+ gc.collect()
609
949
 
610
- ldscore_chr_chunk = self.calculate_ldscore_use_SNP_Gene_weight_matrix_by_chunk(
611
- baseline_mk_score_df,
612
- drop_dummy_na=False,
613
- )
950
+ def _calculate_ldscore_from_weights(
951
+ self, marker_scores: pd.DataFrame, drop_dummy_na: bool = True
952
+ ) -> np.ndarray:
953
+ """
954
+ Calculate LD scores using SNP-gene weight matrix.
955
+
956
+ Parameters
957
+ ----------
958
+ marker_scores : pd.DataFrame
959
+ DataFrame with marker scores
960
+ drop_dummy_na : bool, optional
961
+ Whether to drop the dummy NA column, by default True
962
+
963
+ Returns
964
+ -------
965
+ np.ndarray
966
+ Array with calculated LD scores
967
+ """
968
+ weight_matrix = self.snp_gene_weight_matrix
614
969
 
615
- self.save_ldscore_to_feather(
616
- ldscore_chr_chunk,
617
- column_names=baseline_mk_score_df.columns,
618
- save_file_name=ld_score_file,
619
- )
620
- # save baseline M
621
- self.calculate_M_use_SNP_gene_pair_dummy_by_chunk(
622
- baseline_mk_score_df,
623
- M_file,
624
- M_5_file,
625
- drop_dummy_na=False,
970
+ if drop_dummy_na:
971
+ # Use all columns except the last one (dummy NA)
972
+ ldscore = weight_matrix[:, :-1] @ marker_scores
973
+ else:
974
+ ldscore = weight_matrix @ marker_scores
975
+
976
+ return ldscore
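
The chunked LD score is just a sparse-dense product: the CSR SNP-gene weight matrix times a genes-by-spots block of marker scores, with the trailing dummy-NA column dropped when `drop_dummy_na=True`. A self-contained toy version (using a plain NumPy array for the chunk so the product stays a NumPy array):

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy weights: 4 SNPs x (2 genes + 1 dummy-NA column), as produced upstream.
weights = csr_matrix(np.array([
    [0.8, 0.0, 0.0],
    [0.0, 1.2, 0.0],
    [0.5, 0.3, 0.0],
    [0.0, 0.0, 1.0],   # SNP assigned to no gene -> only the dummy-NA column
]))
mk_score_chunk = np.random.rand(2, 5)  # 2 genes x 5 spots

ldscore_chunk = weights[:, :-1] @ mk_score_chunk  # drop_dummy_na=True path
print(ldscore_chunk.shape)  # (4 SNPs, 5 spots)
```
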
977
+
978
+ def _save_ldscore_to_feather(
979
+ self, ldscore_data: np.ndarray, column_names: list[str], save_file_name: str
980
+ ):
981
+ """
982
+ Save LD scores to a feather file.
983
+
984
+ Parameters
985
+ ----------
986
+ ldscore_data : np.ndarray
987
+ Array with LD scores
988
+ column_names : list
989
+ List of column names
990
+ save_file_name : str
991
+ Path to save the feather file
992
+ """
993
+ # Create directory if needed
994
+ save_dir = Path(save_file_name).parent
995
+ save_dir.mkdir(parents=True, exist_ok=True)
996
+
997
+ # Convert to float16 for storage efficiency
998
+ ldscore_data = ldscore_data.astype(np.float16, copy=False)
999
+
1000
+ # Handle numerical overflow
1001
+ ldscore_data[np.isinf(ldscore_data)] = np.finfo(np.float16).max
1002
+
1003
+ # Create DataFrame and save
1004
+ df = pd.DataFrame(
1005
+ ldscore_data,
1006
+ index=self.snp_name,
1007
+ columns=column_names,
626
1008
  )
1009
+ df.index.name = "SNP"
1010
+ df.reset_index().to_feather(save_file_name)
627
1011
 
628
- def get_snp_gene_dummy(
1012
+ def _save_ldscore_chunk_to_zarr(
1013
+ self, ldscore_data: np.ndarray, chrom: int, start_col_index: int
1014
+ ):
1015
+ """
1016
+ Save LD scores to a zarr array.
1017
+
1018
+ Parameters
1019
+ ----------
1020
+ ldscore_data : np.ndarray
1021
+ Array with LD scores
1022
+ chrom : int
1023
+ Chromosome number
1024
+ start_col_index : int
1025
+ Starting column index in the zarr array
1026
+ """
1027
+ # Convert to float16 for storage efficiency
1028
+ ldscore_data = ldscore_data.astype(np.float16, copy=False)
1029
+
1030
+ # Handle numerical overflow
1031
+ ldscore_data[np.isinf(ldscore_data)] = np.finfo(np.float16).max
1032
+
1033
+ # Get start and end indices for this chromosome
1034
+ chrom_start = self.chrom_snp_start_point[chrom - 1]
1035
+ chrom_end = self.chrom_snp_start_point[chrom]
1036
+
1037
+ # Save to zarr array
1038
+ self.zarr_file[
1039
+ chrom_start:chrom_end,
1040
+ start_col_index : start_col_index + ldscore_data.shape[1],
1041
+ ] = ldscore_data
1042
+
1043
+ def _calculate_and_save_m_values(
629
1044
  self,
630
- chrom,
1045
+ marker_scores: pd.DataFrame,
1046
+ m_file_path: str,
1047
+ m_5_file_path: str,
1048
+ drop_dummy_na: bool = True,
631
1049
  ):
632
1050
  """
633
- Get the dummy matrix of SNP-gene pairs.
1051
+ Calculate and save M statistics.
1052
+
1053
+ Parameters
1054
+ ----------
1055
+ marker_scores : pd.DataFrame
1056
+ DataFrame with marker scores
1057
+ m_file_path : str
1058
+ Path to save M values
1059
+ m_5_file_path : str
1060
+ Path to save M_5_50 values
1061
+ drop_dummy_na : bool, optional
1062
+ Whether to drop the dummy NA column, by default True
1063
+ """
1064
+ # Create directory if needed
1065
+ save_dir = Path(m_file_path).parent
1066
+ save_dir.mkdir(parents=True, exist_ok=True)
1067
+
1068
+ # Get sum of SNP-gene pairs
1069
+ snp_gene_sum = self.snp_gene_pair_dummy.values.sum(axis=0, keepdims=True)
1070
+ snp_gene_sum_maf = self.snp_gene_pair_dummy.loc[self.snp_pass_maf].values.sum(
1071
+ axis=0, keepdims=True
1072
+ )
1073
+
1074
+ # Drop dummy NA column if requested
1075
+ if drop_dummy_na:
1076
+ snp_gene_sum = snp_gene_sum[:, :-1]
1077
+ snp_gene_sum_maf = snp_gene_sum_maf[:, :-1]
1078
+
1079
+ # Calculate M values
1080
+ m_values = snp_gene_sum @ marker_scores
1081
+ m_5_values = snp_gene_sum_maf @ marker_scores
1082
+
1083
+ # Save M values
1084
+ np.savetxt(m_file_path, m_values, delimiter="\t")
1085
+ np.savetxt(m_5_file_path, m_5_values, delimiter="\t")
1086
+
1087
+ def _get_snp_gene_dummy(self, chrom: int) -> pd.DataFrame:
1088
+ """
1089
+ Get dummy matrix for SNP-gene pairs.
1090
+
1091
+ Parameters
1092
+ ----------
1093
+ chrom : int
1094
+ Chromosome number
1095
+
1096
+ Returns
1097
+ -------
1098
+ pd.DataFrame
1099
+ DataFrame with dummy variables for SNP-gene pairs
634
1100
  """
635
- # Load the bim file
636
- print("Loading bim data")
1101
+ logger.info(f"Creating SNP-gene mappings for chromosome {chrom}")
1102
+
1103
+ # Load BIM file
637
1104
  bim, bim_pr = load_bim(self.config.bfile_root, chrom)
638
1105
 
1106
+ # Determine mapping strategy
639
1107
  if self.config.gene_window_enhancer_priority in ["gene_window_first", "enhancer_first"]:
640
- SNP_gene_pair_gtf = self.get_SNP_gene_pair_from_gtf(
641
- bim,
642
- bim_pr,
643
- )
644
- SNP_gene_pair_enhancer = self.get_SNP_gene_pair_from_enhancer(
645
- bim,
646
- bim_pr,
1108
+ # Use both gene window and enhancer
1109
+ snp_gene_pair = self._combine_gtf_and_enhancer_mappings(bim, bim_pr)
1110
+
1111
+ elif self.config.gene_window_enhancer_priority is None:
1112
+ # Use only gene window
1113
+ snp_gene_pair = self._get_snp_gene_pair_from_gtf(bim, bim_pr)
1114
+
1115
+ elif self.config.gene_window_enhancer_priority == "enhancer_only":
1116
+ # Use only enhancer
1117
+ snp_gene_pair = self._get_snp_gene_pair_from_enhancer(bim, bim_pr)
1118
+
1119
+ else:
1120
+ raise ValueError(
1121
+ f"Invalid gene_window_enhancer_priority: {self.config.gene_window_enhancer_priority}"
647
1122
  )
648
- # total_SNP_gene_pair = SNP_gene_pair_gtf.join(SNP_gene_pair_enhancer, how='outer', lsuffix='_gtf', )
649
-
650
- mask_of_nan_gtf = SNP_gene_pair_gtf.gene_name.isna()
651
- mask_of_nan_enhancer = SNP_gene_pair_enhancer.gene_name.isna()
652
-
653
- if self.config.gene_window_enhancer_priority == "gene_window_first":
654
- SNP_gene_pair = SNP_gene_pair_gtf
655
- SNP_gene_pair.loc[mask_of_nan_gtf, "gene_name"] = SNP_gene_pair_enhancer.loc[
656
- mask_of_nan_gtf, "gene_name"
657
- ]
658
- elif self.config.gene_window_enhancer_priority == "enhancer_first":
659
- SNP_gene_pair = SNP_gene_pair_enhancer
660
- SNP_gene_pair.loc[mask_of_nan_enhancer, "gene_name"] = SNP_gene_pair_gtf.loc[
661
- mask_of_nan_enhancer, "gene_name"
662
- ]
663
- else:
664
- raise ValueError(
665
- f"Invalid self.config.gene_window_enhancer_priority: {self.config.gene_window_enhancer_priority}"
666
- )
667
1123
 
668
- elif self.config.gene_window_enhancer_priority is None: # use gtf only
669
- SNP_gene_pair_gtf = self.get_SNP_gene_pair_from_gtf(
670
- bim,
671
- bim_pr,
1124
+ # Save SNP-gene pair mapping
1125
+ self._save_snp_gene_pair_mapping(snp_gene_pair, chrom)
1126
+
1127
+ # Create dummy variables
1128
+ snp_gene_dummy = pd.get_dummies(snp_gene_pair["gene_name"], dummy_na=True)
1129
+
1130
+ return snp_gene_dummy
1131
+
1132
+ def _combine_gtf_and_enhancer_mappings(
1133
+ self, bim: pd.DataFrame, bim_pr: pr.PyRanges
1134
+ ) -> pd.DataFrame:
1135
+ """
1136
+ Combine gene window and enhancer mappings.
1137
+
1138
+ Parameters
1139
+ ----------
1140
+ bim : pd.DataFrame
1141
+ BIM DataFrame
1142
+ bim_pr : pr.PyRanges
1143
+ BIM PyRanges object
1144
+
1145
+ Returns
1146
+ -------
1147
+ pd.DataFrame
1148
+ Combined SNP-gene pair mapping
1149
+ """
1150
+ # Get mappings from both sources
1151
+ gtf_mapping = self._get_snp_gene_pair_from_gtf(bim, bim_pr)
1152
+ enhancer_mapping = self._get_snp_gene_pair_from_enhancer(bim, bim_pr)
1153
+
1154
+ # Find SNPs with missing mappings in each source
1155
+ mask_of_nan_gtf = gtf_mapping.gene_name.isna()
1156
+ mask_of_nan_enhancer = enhancer_mapping.gene_name.isna()
1157
+
1158
+ # Combine based on priority
1159
+ if self.config.gene_window_enhancer_priority == "gene_window_first":
1160
+ # Use gene window mappings first, fill missing with enhancer mappings
1161
+ combined_mapping = gtf_mapping.copy()
1162
+ combined_mapping.loc[mask_of_nan_gtf, "gene_name"] = enhancer_mapping.loc[
1163
+ mask_of_nan_gtf, "gene_name"
1164
+ ]
1165
+ logger.info(
1166
+ f"Filled {mask_of_nan_gtf.sum()} SNPs with no GTF mapping using enhancer mappings"
672
1167
  )
673
- SNP_gene_pair = SNP_gene_pair_gtf
674
1168
 
675
- elif self.config.gene_window_enhancer_priority == "enhancer_only":
676
- SNP_gene_pair_enhancer = self.get_SNP_gene_pair_from_enhancer(
677
- bim,
678
- bim_pr,
1169
+ elif self.config.gene_window_enhancer_priority == "enhancer_first":
1170
+ # Use enhancer mappings first, fill missing with gene window mappings
1171
+ combined_mapping = enhancer_mapping.copy()
1172
+ combined_mapping.loc[mask_of_nan_enhancer, "gene_name"] = gtf_mapping.loc[
1173
+ mask_of_nan_enhancer, "gene_name"
1174
+ ]
1175
+ logger.info(
1176
+ f"Filled {mask_of_nan_enhancer.sum()} SNPs with no enhancer mapping using GTF mappings"
679
1177
  )
680
- SNP_gene_pair = SNP_gene_pair_enhancer
681
- else:
682
- raise ValueError("gtf_pr and enhancer_pr cannot be None at the same time")
683
1178
 
684
- # save the SNP_gene_pair to feather
685
- SNP_gene_pair_save_path = (
686
- Path(self.config.ldscore_save_dir) / f"SNP_gene_pair/SNP_gene_pair_chr{chrom}.feather"
687
- )
688
- SNP_gene_pair_save_path.parent.mkdir(parents=True, exist_ok=True)
689
- SNP_gene_pair.reset_index().to_feather(SNP_gene_pair_save_path)
1179
+ else:
1180
+ raise ValueError(
1181
+ f"Invalid gene_window_enhancer_priority for combining: {self.config.gene_window_enhancer_priority}"
1182
+ )
690
1183
 
691
- # Get the dummy matrix
692
- SNP_gene_pair_dummy = pd.get_dummies(SNP_gene_pair["gene_name"], dummy_na=True)
693
- return SNP_gene_pair_dummy
1184
+ return combined_mapping
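The branching above reduces to a mask-and-fill over two mappings that share the same SNP index. A standalone sketch of the "gene_window_first" case (all SNP and gene names are fabricated for illustration; swapping the roles of the two frames gives the "enhancer_first" behaviour):

    import numpy as np
    import pandas as pd

    snps = pd.Index(["rs1", "rs2", "rs3"], name="SNP")
    # Hypothetical per-source assignments: rs2 has no gene-window hit, rs3 has no enhancer hit.
    gtf_mapping = pd.DataFrame({"gene_name": ["GeneA", np.nan, "GeneC"]}, index=snps)
    enhancer_mapping = pd.DataFrame({"gene_name": ["GeneX", "GeneB", np.nan]}, index=snps)

    # gene_window_first: keep window hits, fill only the gaps from the enhancer map.
    mask_of_nan_gtf = gtf_mapping.gene_name.isna()
    combined = gtf_mapping.copy()
    combined.loc[mask_of_nan_gtf, "gene_name"] = enhancer_mapping.loc[mask_of_nan_gtf, "gene_name"]

    print(combined["gene_name"].tolist())  # ['GeneA', 'GeneB', 'GeneC'] - rs1 keeps GeneA, not GeneX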
694
1185
 
695
- def get_SNP_gene_pair_from_gtf(self, bim, bim_pr):
1186
+ def _get_snp_gene_pair_from_gtf(self, bim: pd.DataFrame, bim_pr: pr.PyRanges) -> pd.DataFrame:
1187
+ """
1188
+ Get SNP-gene pairs based on GTF annotations.
1189
+
1190
+ Parameters
1191
+ ----------
1192
+ bim : pd.DataFrame
1193
+ BIM DataFrame
1194
+ bim_pr : pr.PyRanges
1195
+ BIM PyRanges object
1196
+
1197
+ Returns
1198
+ -------
1199
+ pd.DataFrame
1200
+ SNP-gene pairs based on GTF
1201
+ """
696
1202
  logger.info(
697
- "Get SNP-gene pair from gtf, if a SNP is in multiple genes, it will be assigned to the most nearby gene (TSS)"
1203
+ "Getting SNP-gene pairs from GTF. SNPs in multiple genes will be assigned to the nearest gene (by TSS)"
698
1204
  )
699
- overlaps_small = Overlaps_gtf_bim(self.gtf_pr, bim_pr)
700
- # Get the SNP-gene pair
1205
+
1206
+ # Find overlaps between SNPs and gene windows
1207
+ overlaps = overlaps_gtf_bim(self.gtf_pr, bim_pr)
1208
+
1209
+ # Get SNP information
701
1210
  annot = bim[["CHR", "BP", "SNP", "CM"]]
702
- SNP_gene_pair = (
703
- overlaps_small[["SNP", "gene_name"]]
1211
+
1212
+ # Create SNP-gene pairs DataFrame
1213
+ snp_gene_pair = (
1214
+ overlaps[["SNP", "gene_name"]]
704
1215
  .set_index("SNP")
705
1216
  .join(annot.set_index("SNP"), how="right")
706
1217
  )
707
- return SNP_gene_pair
708
1218
 
709
- def get_SNP_gene_pair_from_enhancer(
710
- self,
711
- bim,
712
- bim_pr,
713
- ):
714
- logger.info(
715
- "Get SNP-gene pair from enhancer, if a SNP is in multiple genes, it will be assigned to the gene with highest marker score"
716
- )
717
- # Get the SNP-gene pair
718
- overlaps_small = self.enhancer_pr.join(bim_pr).df
1219
+ logger.info(f"Found {overlaps.shape[0]} SNP-gene pairs from GTF")
1220
+
1221
+ return snp_gene_pair
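A detail worth calling out in the join above: how="right" keeps every SNP from the .bim table, so SNPs outside all gene windows survive with gene_name = NaN and later end up in the dummy_na column. A pandas-only sketch of that step (the overlap result is assumed to exist already; all values are invented):

    import pandas as pd

    # Hypothetical overlap hits: only SNPs that landed inside some gene window.
    overlaps = pd.DataFrame({"SNP": ["rs1", "rs2"], "gene_name": ["GeneA", "GeneB"]})

    # Hypothetical BIM-style table covering every SNP on the chromosome.
    annot = pd.DataFrame(
        {"CHR": [1, 1, 1], "BP": [100, 200, 300], "SNP": ["rs1", "rs2", "rs3"], "CM": [0.0, 0.0, 0.0]}
    )

    # Right join on SNP: rs3 is retained with a missing gene_name.
    snp_gene_pair = (
        overlaps[["SNP", "gene_name"]]
        .set_index("SNP")
        .join(annot.set_index("SNP"), how="right")
    )
    print(snp_gene_pair["gene_name"].isna().sum())  # 1 -> only rs3 is unmapped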
1222
+
1223
+ def _get_snp_gene_pair_from_enhancer(
1224
+ self, bim: pd.DataFrame, bim_pr: pr.PyRanges
1225
+ ) -> pd.DataFrame:
1226
+ """
1227
+ Get SNP-gene pairs based on enhancer annotations.
1228
+
1229
+ Parameters
1230
+ ----------
1231
+ bim : pd.DataFrame
1232
+ BIM DataFrame
1233
+ bim_pr : pr.PyRanges
1234
+ BIM PyRanges object
1235
+
1236
+ Returns
1237
+ -------
1238
+ pd.DataFrame
1239
+ SNP-gene pairs based on enhancer
1240
+ """
1241
+ if self.enhancer_pr is None:
1242
+ raise ValueError("Enhancer annotation file is required but not provided")
1243
+
1244
+ # Find overlaps between SNPs and enhancers
1245
+ overlaps = self.enhancer_pr.join(bim_pr).df
1246
+
1247
+ # Get SNP information
719
1248
  annot = bim[["CHR", "BP", "SNP", "CM"]]
1249
+
720
1250
  if self.config.snp_multiple_enhancer_strategy == "max_mkscore":
721
- logger.debug("select the gene with highest marker score")
722
- overlaps_small = overlaps_small.loc[overlaps_small.groupby("SNP").avg_mkscore.idxmax()]
1251
+ logger.info(
1252
+ "SNPs in multiple enhancers will be assigned to the gene with highest marker score"
1253
+ )
1254
+ overlaps = overlaps.loc[overlaps.groupby("SNP").avg_mkscore.idxmax()]
723
1255
 
724
1256
  elif self.config.snp_multiple_enhancer_strategy == "nearest_TSS":
725
- logger.debug("select the gene with nearest TSS")
726
- overlaps_small["Distance"] = np.abs(overlaps_small["Start_b"] - overlaps_small["TSS"])
727
- overlaps_small = overlaps_small.loc[overlaps_small.groupby("SNP").Distance.idxmin()]
1257
+ logger.info("SNPs in multiple enhancers will be assigned to the gene with nearest TSS")
1258
+ overlaps["Distance"] = np.abs(overlaps["Start_b"] - overlaps["TSS"])
1259
+ overlaps = overlaps.loc[overlaps.groupby("SNP").Distance.idxmin()]
728
1260
 
729
- SNP_gene_pair = (
730
- overlaps_small[["SNP", "gene_name"]]
1261
+ # Create SNP-gene pairs DataFrame
1262
+ snp_gene_pair = (
1263
+ overlaps[["SNP", "gene_name"]]
731
1264
  .set_index("SNP")
732
1265
  .join(annot.set_index("SNP"), how="right")
733
1266
  )
734
1267
 
735
- return SNP_gene_pair
1268
+ logger.info(f"Found {overlaps.shape[0]} SNP-gene pairs from enhancers")
1269
+
1270
+ return snp_gene_pair
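The two tie-breaking strategies above are both a groupby over SNPs followed by picking one row per group. A self-contained sketch with fabricated overlap rows (the column names avg_mkscore, Start_b and TSS mirror the ones used above; the values are invented):

    import numpy as np
    import pandas as pd

    # rs1 overlaps two enhancers linked to different genes; rs2 overlaps one.
    overlaps = pd.DataFrame(
        {
            "SNP": ["rs1", "rs1", "rs2"],
            "gene_name": ["GeneA", "GeneB", "GeneC"],
            "avg_mkscore": [0.8, 1.5, 0.3],
            "Start_b": [1000, 1000, 5000],  # SNP position carried over from the joined ranges
            "TSS": [900, 4000, 5100],       # TSS of the candidate gene
        }
    )

    # max_mkscore: per SNP, keep the row whose gene has the highest marker score.
    by_score = overlaps.loc[overlaps.groupby("SNP").avg_mkscore.idxmax()]
    print(by_score.gene_name.tolist())  # ['GeneB', 'GeneC']

    # nearest_TSS: per SNP, keep the row whose TSS is closest to the SNP position.
    overlaps["Distance"] = np.abs(overlaps["Start_b"] - overlaps["TSS"])
    by_tss = overlaps.loc[overlaps.groupby("SNP").Distance.idxmin()]
    print(by_tss.gene_name.tolist())  # ['GeneA', 'GeneC']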
1271
+
1272
+ def _save_snp_gene_pair_mapping(self, snp_gene_pair: pd.DataFrame, chrom: int):
1273
+ """
1274
+ Save SNP-gene pair mapping to a feather file.
1275
+
1276
+ Parameters
1277
+ ----------
1278
+ snp_gene_pair : pd.DataFrame
1279
+ SNP-gene pair mapping
1280
+ chrom : int
1281
+ Chromosome number
1282
+ """
1283
+ save_path = (
1284
+ Path(self.config.ldscore_save_dir) / f"SNP_gene_pair/SNP_gene_pair_chr{chrom}.feather"
1285
+ )
1286
+ save_path.parent.mkdir(parents=True, exist_ok=True)
1287
+ snp_gene_pair.reset_index().to_feather(save_path)
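The per-chromosome mapping written above can be read back with pandas' feather support (pyarrow must be installed). A hypothetical round trip with an invented path and invented rows; reset_index() is needed because feather cannot serialize a custom index:

    from pathlib import Path
    import pandas as pd

    save_path = Path("ldscore_output/SNP_gene_pair/SNP_gene_pair_chr1.feather")  # invented path
    save_path.parent.mkdir(parents=True, exist_ok=True)

    snp_gene_pair = pd.DataFrame(
        {"gene_name": ["GeneA", None], "CHR": [1, 1], "BP": [100, 200], "CM": [0.0, 0.0]},
        index=pd.Index(["rs1", "rs2"], name="SNP"),
    )

    # Write with a flat RangeIndex, then restore the SNP index after reading.
    snp_gene_pair.reset_index().to_feather(save_path)
    reloaded = pd.read_feather(save_path).set_index("SNP")
    print(reloaded.shape)  # (2, 4)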
1288
+
1289
+ def _clear_memory(self):
1290
+ """Clear memory to prevent leaks."""
1291
+ gc.collect()
736
1292
 
737
1293
 
738
1294
  def run_generate_ldscore(config: GenerateLDScoreConfig):
1295
+ """
1296
+ Main function to run the LD score generation.
1297
+
1298
+ Parameters
1299
+ ----------
1300
+ config : GenerateLDScoreConfig
1301
+ Configuration object
1302
+ """
1303
+ # Create output directory
1304
+ Path(config.ldscore_save_dir).mkdir(parents=True, exist_ok=True)
1305
+
739
1306
  if config.ldscore_save_format == "quick_mode":
740
1307
  logger.info(
741
1308
  "Running in quick_mode. Skip the process of generating ldscore. Using the pre-calculated ldscore."
742
1309
  )
743
- ldscore_save_dir = config.ldscore_save_dir
1310
+ ldscore_save_dir = Path(config.ldscore_save_dir)
744
1311
 
745
- # link the baseline annotation
746
- baseline_annotation_dir = Path(config.baseline_annotation_dir)
747
- (ldscore_save_dir / "baseline").symlink_to(
748
- baseline_annotation_dir, target_is_directory=True
749
- )
1312
+ # Set up symbolic links
1313
+ baseline_dir = ldscore_save_dir / "baseline"
1314
+ baseline_dir.parent.mkdir(parents=True, exist_ok=True)
1315
+ if not baseline_dir.exists():
1316
+ baseline_dir.symlink_to(config.baseline_annotation_dir, target_is_directory=True)
1317
+
1318
+ snp_gene_pair_dir = ldscore_save_dir / "SNP_gene_pair"
1319
+ snp_gene_pair_dir.parent.mkdir(parents=True, exist_ok=True)
1320
+ if not snp_gene_pair_dir.exists():
1321
+ snp_gene_pair_dir.symlink_to(config.SNP_gene_pair_dir, target_is_directory=True)
1322
+
1323
+ # Create a done file to mark completion
1324
+ done_file = ldscore_save_dir / f"{config.sample_name}_generate_ldscore.done"
1325
+ done_file.touch()
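The quick_mode wiring above is plain pathlib. A standalone sketch of the same pattern with invented directory names; the exists() guard mirrors the checks above so a re-run does not hit FileExistsError:

    from pathlib import Path

    ldscore_save_dir = Path("ldscore_output")                  # invented
    baseline_annotation_dir = Path("ref/baseline_annotation")  # invented
    snp_gene_pair_source = Path("ref/SNP_gene_pair")           # invented

    ldscore_save_dir.mkdir(parents=True, exist_ok=True)

    for name, target in [("baseline", baseline_annotation_dir), ("SNP_gene_pair", snp_gene_pair_source)]:
        link = ldscore_save_dir / name
        if not link.exists():
            # Symlink instead of copying, so pre-computed annotations are reused in place.
            link.symlink_to(target, target_is_directory=True)

    # Downstream steps can test for this marker instead of re-checking every output file.
    (ldscore_save_dir / "sample1_generate_ldscore.done").touch()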
750
1326
 
751
- # link the SNP_gene_pair
752
- SNP_gene_pair_dir = Path(config.SNP_gene_pair_dir)
753
- (ldscore_save_dir / "SNP_gene_pair").symlink_to(
754
- SNP_gene_pair_dir, target_is_directory=True
755
- )
756
1327
  return
757
- s_ldsc_boost = S_LDSC_Boost(config)
1328
+
1329
+ # Initialize calculator
1330
+ calculator = LDScoreCalculator(config)
1331
+
1332
+ # Process chromosomes
758
1333
  if config.chrom == "all":
1334
+ # Process all chromosomes
759
1335
  for chrom in range(1, 23):
760
- s_ldsc_boost.process_chromosome(chrom)
1336
+ try:
1337
+ calculator.process_chromosome(chrom)
1338
+ except Exception as e:
1339
+ logger.error(f"Error processing chromosome {chrom}: {e}")
1340
+ raise
761
1341
  else:
762
- s_ldsc_boost.process_chromosome(config.chrom)
1342
+ # Process one chromosome
1343
+ try:
1344
+ chrom = int(config.chrom)
1345
+ except ValueError:
1346
+ logger.error(f"Invalid chromosome: {config.chrom}")
1347
+ raise ValueError(
1348
+ f"Invalid chromosome: {config.chrom}. Must be an integer between 1-22 or 'all'"
1349
+ ) from None
1350
+ else:
1351
+ calculator.process_chromosome(chrom)
1352
+
1353
+ # Create a done file to mark completion
1354
+ done_file = Path(config.ldscore_save_dir) / f"{config.sample_name}_generate_ldscore.done"
1355
+ done_file.touch()
1356
+
1357
+ logger.info(f"LD score generation completed for {config.sample_name}")