gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Module for reading and processing PLINK genotype data and calculating LD scores.
|
|
3
|
+
|
|
4
|
+
Note:
|
|
5
|
+
This code is adapted and modified from:
|
|
6
|
+
https://github.com/bulik/ldsc/blob/master/ldsc/ldscore.py
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
import bitarray as ba
|
|
12
|
+
import numba
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
import pyranges as pr
|
|
16
|
+
import torch
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
|
|
19
|
+
from gsMap.utils.torch_utils import torch_device, torch_sync
|
|
20
|
+
|
|
21
|
+
# Configure logger
|
|
22
|
+
logger = logging.getLogger("gsMap.utils.plink_ldscore_tool")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@numba.njit
def getBlockLefts(coords: np.ndarray, max_dist: float):
    """
    For each SNP, find the index of the leftmost SNP that falls within
    ``max_dist`` of it, defining the left edge of its LD block.

    Parameters
    ----------
    coords : np.ndarray
        Array of SNP coordinates, assumed sorted in ascending order.
    max_dist : float
        Maximum distance (in the same unit as ``coords``) defining the window.

    Returns
    -------
    np.ndarray
        Array of length ``len(coords)`` whose entry ``i`` is the index of the
        leftmost SNP included in SNP ``i``'s block.
    """
    n_snps = len(coords)
    lefts = np.zeros(n_snps)
    left_ptr = 0
    for idx in range(n_snps):
        # Advance the left pointer until it enters the window around idx;
        # coords are sorted, so the pointer only ever moves forward.
        while left_ptr < n_snps and abs(coords[left_ptr] - coords[idx]) > max_dist:
            left_ptr += 1
        lefts[idx] = left_ptr
    return lefts
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@numba.njit
def normalized_snps(X: np.ndarray, b: int, minorRef, freq, currentSNP):
    """
    Normalize a batch of SNP genotype columns and mean-impute missing calls.

    Parameters
    ----------
    X : np.ndarray
        n x b genotype matrix; the value 9 marks a missing genotype.
        NOTE(review): missing entries are overwritten in place with the column
        mean (``newsnp`` is a view into X), so X is mutated by this call.
    b : int
        Number of SNP columns of X to process.
    minorRef : optional
        If not None, columns whose allele frequency exceeds 0.5 have their
        sign flipped so scores are polarized to the minor allele.
    freq : np.ndarray
        Allele frequencies for all SNPs, indexed at ``currentSNP + j``.
    currentSNP : int
        Index of the first SNP of this batch within ``freq``.

    Returns
    -------
    np.ndarray
        float32 matrix of the same shape as X with each column mean-centered
        and divided by its standard deviation (zero std mapped to 1).
    """
    Y = np.zeros(X.shape, dtype="float32")

    for j in range(0, b):
        newsnp = X[:, j]
        ii = newsnp != 9  # mask of observed (non-missing) genotypes
        avg = np.mean(newsnp[ii])
        newsnp[np.logical_not(ii)] = avg  # mean-impute missing calls (mutates X)
        denom = np.std(newsnp)
        if denom == 0:
            denom = 1  # monomorphic column: avoid division by zero

        if minorRef is not None and freq[currentSNP + j] > 0.5:
            denom = denom * -1  # flip sign to polarize to the minor allele

        Y[:, j] = (newsnp - avg) / denom
    return Y
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def l2_unbiased(x: torch.Tensor, n: int):
    """
    Unbiased estimator of the squared correlation (L2), applying the
    finite-sample correction ``r^2 - (1 - r^2) / (n - 2)``.
    """
    # Degrees-of-freedom correction; fall back to n itself when n <= 2
    # (allows tiny sample sizes in tests without a zero/negative divisor).
    if n > 2:
        denom = n - 2
    else:
        denom = n
    sq = x * x
    return sq - (1 - sq) / denom
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class PlinkBEDFile:
    """
    Interface for Plink .bed format for reading and processing genotype data.

    On construction the .bim/.fam/.bed triplet is loaded, per-SNP allele
    frequencies are computed, and SNPs failing the basic QC check are dropped.
    Filters (MAF, SNP subset, individual subset) can then be applied with
    :meth:`apply_filters`, and LD scores computed with :meth:`get_ldscore`.
    """

    def __init__(self, bfile_prefix):
        """
        Initialize the PlinkBEDFile from a PLINK file prefix.

        Parameters
        ----------
        bfile_prefix : str
            PLINK file prefix (without .bed/.bim/.fam extension)
        """
        # Two-bit PLINK genotype encoding -> dosage mapping (9 = missing).
        self._bedcode = {
            2: ba.bitarray("11"),
            9: ba.bitarray("10"),
            1: ba.bitarray("01"),
            0: ba.bitarray("00"),
        }

        # Load BIM (variant) and FAM (sample) metadata.
        self.bim_df = self.load_bim(f"{bfile_prefix}.bim")
        self.fam_df = self.load_fam(f"{bfile_prefix}.fam")

        # Dimensions before any filtering.
        self.m_original = len(self.bim_df)
        self.n_original = len(self.fam_df)

        # Read the raw genotype bitarray.
        logger.info(f"Loading Plink genotype data from {bfile_prefix}.bed")
        (self.nru_original, self.geno_original) = self._read(
            f"{bfile_prefix}.bed", self.m_original, self.n_original
        )

        # Pre-calculate MAF and QC flags for all SNPs.
        logger.info("Calculating MAF and QC for all SNPs")
        self.all_snp_info = self._calculate_all_snp_info()

        # Filter out invalid SNPs (het/missing count >= n).
        valid_mask = self.all_snp_info["valid_snp"]
        if num_invalid := np.sum(~valid_mask):
            logger.warning(
                f"Filtering out {num_invalid} bad quality SNPs: {self.bim_df.loc[~valid_mask, 'SNP'].tolist()}"
            )
        else:
            logger.info("All SNPs passed the basic quality check")

        # Rebuild the genotype bitarray keeping only valid SNPs; each SNP
        # occupies 2 * nru bits (2 bits per individual, padded to nru).
        new_geno = ba.bitarray()
        for j in np.arange(self.m_original)[valid_mask]:
            new_geno += self.geno_original[
                2 * self.nru_original * j : 2 * self.nru_original * (j + 1)
            ]

        # From here on, "original" refers to the post-QC data.
        self.geno_original = new_geno
        self.bim_df = self.bim_df.loc[valid_mask].reset_index(drop=True)
        self.m_original = len(self.bim_df)
        self.kept_snps = np.arange(self.m_original)

        # Mutable state for sequential SNP reading and filtering.
        self._currentSNP = 0
        self.m = self.m_original
        self.n = self.n_original
        self.nru = self.nru_original
        self.geno = self.geno_original.copy()

        # Frequency-derived per-SNP statistics for the valid SNPs.
        self.freq = self.all_snp_info["freq"][valid_mask]
        self.maf = np.minimum(self.freq, 1 - self.freq)
        self.sqrtpq = np.sqrt(self.freq * (1 - self.freq))

        # Expose MAF alongside the BIM metadata.
        self.bim_df["MAF"] = self.maf

        logger.info(f"Loaded genotype data with {self.m} SNPs and {self.n} individuals")

    @staticmethod
    def load_bim(bim_file):
        """
        Load a BIM file into a pandas DataFrame.

        Parameters
        ----------
        bim_file : str
            Path to the BIM file

        Returns
        -------
        pd.DataFrame
            DataFrame with columns CHR, SNP, CM, BP, A1, A2
        """
        df = pd.read_csv(
            bim_file, sep="\t", header=None, names=["CHR", "SNP", "CM", "BP", "A1", "A2"]
        )
        return df

    @staticmethod
    def convert_bim_to_pyrange(bim_df) -> pr.PyRanges:
        """
        Convert a BIM DataFrame (with MAF column) into a PyRanges object.

        Chromosomes are prefixed with "chr" and 1-based BIM positions are
        converted to 0-based half-open [Start, End) intervals.
        """
        bim_pr = bim_df.copy()
        bim_pr.drop(columns=["MAF"], inplace=True)
        bim_pr.columns = ["Chromosome", "SNP", "CM", "Start", "A1", "A2"]
        bim_pr.Chromosome = "chr" + bim_pr["Chromosome"].astype(str)

        # Adjust coordinates (BIM is 1-based, PyRanges uses 0-based)
        bim_pr["End"] = bim_pr["Start"].copy()
        bim_pr["Start"] = bim_pr["Start"] - 1

        bim_pr = pr.PyRanges(bim_pr)

        return bim_pr

    @staticmethod
    def load_fam(fam_file):
        """
        Load a FAM file into a pandas DataFrame.

        Parameters
        ----------
        fam_file : str
            Path to the FAM file

        Returns
        -------
        pd.DataFrame
            DataFrame with a single IID column (individual IDs)
        """
        df = pd.read_csv(fam_file, sep=r"\s+", header=None, usecols=[1], names=["IID"])
        return df

    def _read(self, fname, m, n):
        """
        Read a Plink .bed file and return (nru, genotype bitarray).

        ``nru`` is n rounded up to a multiple of 4, since each SNP's
        genotypes are padded to a whole number of bytes.
        """
        if not fname.endswith(".bed"):
            raise ValueError(".bed filename must end in .bed")

        fh = open(fname, "rb")
        magicNumber = ba.bitarray(endian="little")
        magicNumber.fromfile(fh, 2)
        bedMode = ba.bitarray(endian="little")
        bedMode.fromfile(fh, 1)
        e = (4 - n % 4) if n % 4 != 0 else 0
        nru = n + e

        # Check magic number
        if magicNumber != ba.bitarray("0011011011011000"):
            raise OSError("Magic number from Plink .bed file not recognized")

        if bedMode != ba.bitarray("10000000"):
            raise OSError("Plink .bed file must be in default SNP-major mode")

        # Check file length
        geno = ba.bitarray(endian="little")
        geno.fromfile(fh)
        self._test_length(geno, m, nru)
        return (nru, geno)

    def _test_length(self, geno, m, nru):
        """
        Raise OSError if the genotype bitarray is not exactly 2 * m * nru bits.
        """
        exp_len = 2 * m * nru
        real_len = len(geno)
        if real_len != exp_len:
            s = "Plink .bed file has {n1} bits, expected {n2}"
            raise OSError(s.format(n1=real_len, n2=exp_len))

    def _calculate_all_snp_info(self):
        """
        Pre-calculate MAF and other information for all SNPs.

        Returns
        -------
        dict
            "freq": per-SNP frequency of the major allele,
            "het_miss_count": count of heterozygous or missing genotypes,
            "valid_snp": True where het_miss_count < n (basic QC pass).
        """
        nru = self.nru_original
        n = self.n_original
        m = self.m_original
        geno = self.geno_original

        snp_info = {
            "freq": np.zeros(m),  # Allele frequencies
            "het_miss_count": np.zeros(m),  # Count of het or missing genotypes
            "valid_snp": np.zeros(m, dtype=bool),  # Whether SNP passes basic criteria
        }

        # For each SNP, count bits of the two genotype bit-planes to recover
        # allele counts without decoding the full matrix.
        for j in range(m):
            z = geno[2 * nru * j : 2 * nru * (j + 1)]
            A = z[0::2]
            a = A.count()
            B = z[1::2]
            b = B.count()
            c = (A & B).count()
            major_ct = b + c  # number of copies of the major allele
            n_nomiss = n - a + c  # number of individuals with nonmissing genotypes
            f = major_ct / (2 * n_nomiss) if n_nomiss > 0 else 0
            het_miss_ct = a + b - 2 * c  # count of SNPs that are het or missing

            snp_info["freq"][j] = f
            snp_info["het_miss_count"][j] = het_miss_ct
            snp_info["valid_snp"][j] = het_miss_ct < n  # Basic validity check

        return snp_info

    def apply_filters(self, keep_snps=None, keep_indivs=None, mafMin=None):
        """
        Apply filters to the genotype data without reloading the bed file.

        The object is first reset to its post-QC state, so successive calls
        do not compound.

        Parameters
        ----------
        keep_snps : array-like, optional
            Indices of SNPs to keep.
        keep_indivs : array-like, optional
            Indices of individuals to keep.
        mafMin : float, optional
            Minimum minor allele frequency.

        Returns
        -------
        self
            Returns self for method chaining.
        """
        # Reset to original state first
        self.geno = self.geno_original.copy()
        self.m = self.m_original
        self.n = self.n_original
        self.nru = self.nru_original
        self._currentSNP = 0

        # Initialize with all SNPs
        kept_snps = np.arange(self.m_original)

        # Apply MAF filter using pre-calculated values
        if mafMin is not None and mafMin > 0:
            # valid_snp QC was already applied at load time, so only MAF matters here
            maf_mask = self.maf > mafMin
            kept_snps = kept_snps[maf_mask]
            logger.info(f"After MAF filtering (>{mafMin}), {len(kept_snps)} SNPs remain")

        # Apply SNP filter if specified
        if keep_snps is not None:
            keep_snps = np.array(keep_snps, dtype="int")
            # BUGFIX: valid indices are 0..m_original-1, so index == m_original
            # is also out of bounds (was a strict > comparison).
            if np.any(keep_snps >= self.m_original):
                raise ValueError("keep_snps indices out of bounds")

            # Intersect with current kept_snps
            kept_snps = np.intersect1d(kept_snps, keep_snps)
            logger.info(f"After keep_snps filtering, {len(kept_snps)} SNPs remain")

        # Filter SNPs in the genotype data
        if len(kept_snps) < self.m_original:
            # Create new genotype data with only the kept SNPs
            new_geno = ba.bitarray()
            for j in kept_snps:
                new_geno += self.geno_original[2 * self.nru * j : 2 * self.nru * (j + 1)]
            self.geno = new_geno
            self.m = len(kept_snps)

        # Filter individuals if specified
        if keep_indivs is not None:
            keep_indivs = np.array(keep_indivs, dtype="int")
            # BUGFIX: same off-by-one as keep_snps above (>= instead of >).
            if np.any(keep_indivs >= self.n):
                raise ValueError("keep_indivs indices out of bounds")

            (self.geno, self.m, self.n) = self._filter_indivs(
                self.geno, keep_indivs, self.m, self.n
            )

            if self.n > 0:
                logger.info(f"After filtering, {self.n} individuals remain")
            else:
                raise ValueError("After filtering, no individuals remain")

        # Update kept_snps and frequency-derived attributes
        self.kept_snps = kept_snps
        self.freq = self.all_snp_info["freq"][kept_snps]
        self.maf = np.minimum(self.freq, 1 - self.freq)
        self.sqrtpq = np.sqrt(self.freq * (1 - self.freq))

        return self

    def _filter_indivs(self, geno, keep_indivs, m, n):
        """
        Build a new genotype bitarray containing only ``keep_indivs``.

        Returns (new bitarray, m, new n). Also updates ``self.nru`` to the
        padded stride of the new individual count.
        """
        n_new = len(keep_indivs)
        pad = (4 - n_new % 4) if n_new % 4 != 0 else 0
        nru_new = n_new + pad
        nru = self.nru
        z = ba.bitarray(m * 2 * nru_new, endian="little")
        z.setall(0)
        # Copy both genotype bit-planes of each kept individual; slice
        # strides jump SNP-to-SNP so each assignment copies m values.
        for new_idx, old_idx in enumerate(keep_indivs):
            z[2 * new_idx :: 2 * nru_new] = geno[2 * old_idx :: 2 * nru]
            z[2 * new_idx + 1 :: 2 * nru_new] = geno[2 * old_idx + 1 :: 2 * nru]
        self.nru = nru_new
        return (z, m, n_new)

    def get_snps_by_maf(self, mafMin):
        """
        Get the list of SNPs that pass the MAF threshold.

        Parameters
        ----------
        mafMin : float
            Minimum MAF threshold

        Returns
        -------
        list
            List of SNP IDs that pass the MAF threshold
        """
        maf_mask = self.maf > mafMin

        # Get SNP names from the BIM dataframe
        snp_pass_maf = self.bim_df.loc[maf_mask, "SNP"].tolist()

        # BUGFIX: stray "f" inside the f-string printed e.g. "MAF > f0.05".
        logger.info(f"{len(snp_pass_maf)} SNPs with MAF > {mafMin}")

        return snp_pass_maf

    def get_ldscore(self, annot_matrix=None, ld_wind=1.0, ld_unit="CM", keep_snps_index=None):
        """
        Calculate LD scores using an annotation matrix.

        Parameters
        ----------
        annot_matrix : np.ndarray, optional
            Annotation matrix. If None, uses a matrix of all ones.
        ld_wind : float, optional
            LD window size, by default 1.0
        ld_unit : str, optional
            Unit for the LD window ("SNP", "KB" or "CM"), by default "CM"
        keep_snps_index : list[int], optional
            Indices of SNPs to keep, by default None

        Returns
        -------
        np.ndarray
            Array with calculated LD scores
        """
        # Apply filters if needed
        if keep_snps_index is not None:
            original_kept_snps = self.kept_snps.copy()
            self.apply_filters(keep_snps=keep_snps_index)

        # Configure LD window based on specified unit
        if ld_unit == "SNP":
            max_dist = ld_wind
            coords = np.array(range(self.m))
        elif ld_unit == "KB":
            max_dist = ld_wind * 1000
            coords = np.array(self.bim_df.loc[self.kept_snps, "BP"])
        elif ld_unit == "CM":
            max_dist = ld_wind
            coords = np.array(self.bim_df.loc[self.kept_snps, "CM"])
            # Check if the CM is all 0
            if np.all(coords == 0):
                logger.warning(
                    "All CM values are 0. Using 1MB window size for LD score calculation."
                )
                max_dist = 1_000_000
                coords = np.array(self.bim_df.loc[self.kept_snps, "BP"])
        else:
            raise ValueError(f"Invalid ld_wind_unit: {ld_unit}. Must be one of: SNP, KB, CM")

        # Calculate blocks for LD computation
        block_left = getBlockLefts(coords, max_dist)
        assert block_left.sum() > 0, "Invalid window size, please check the ld_wind parameter."

        # Calculate LD scores
        ld_scores = self.ldScoreVarBlocks(block_left, 100, annot=annot_matrix)

        # Restore original state if filters were applied.
        # NOTE(review): this restores the SNP subset but any prior individual
        # filter is not reinstated by apply_filters — confirm callers rely
        # only on SNP-level state here.
        if keep_snps_index is not None:
            self.apply_filters(keep_snps=original_kept_snps)

        return ld_scores

    def restart(self):
        """
        Reset the current SNP index to 0.
        """
        self._currentSNP = 0

    def nextSNPs(self, b, minorRef=None):
        """
        Unpack the binary genotypes and return an n x b float32 matrix of
        normalized genotypes for the next b SNPs, advancing the read cursor.

        Parameters
        ----------
        b : int
            Number of SNPs to read; must be positive and not exceed the
            number of SNPs remaining.
        minorRef : optional
            If not None, polarize columns to the minor allele
            (see ``normalized_snps``).
        """
        try:
            b = int(b)
            if b <= 0:
                raise ValueError("b must be > 0")
        except TypeError as e:
            raise TypeError("b must be an integer") from e

        if self._currentSNP + b > self.m:
            s = "{b} SNPs requested, {k} SNPs remain"
            raise ValueError(s.format(b=b, k=(self.m - self._currentSNP)))

        c = self._currentSNP
        n = self.n
        nru = self.nru
        # Renamed from `slice` to avoid shadowing the builtin.
        geno_slice = self.geno[2 * c * nru : 2 * (c + b) * nru]
        X = np.array(geno_slice.decode(self._bedcode), dtype="float32").reshape((b, nru)).T
        X = X[0:n, :]  # drop the padding rows (nru >= n)
        Y = normalized_snps(X, b, minorRef, self.freq, self._currentSNP)

        self._currentSNP += b
        return Y

    def ldScoreVarBlocks(self, block_left: np.ndarray, c, annot=None):
        """
        Computes an unbiased estimate of L2(j) for j=1,..,M.

        ``c`` is the chunk size used to stream SNP blocks through memory.
        """

        def func(x):
            return l2_unbiased(x, self.n)

        snp_getter = self.nextSNPs
        return self._corSumVarBlocks(block_left, c, func, snp_getter, annot)

    def _corSumVarBlocks(self, block_left, c, func, snp_getter, annot=None):
        """
        Stream genotype chunks and accumulate func(correlation) sums per SNP.

        Ported from the ldsc block-variable algorithm: an in-window panel A is
        maintained while chunks B of c SNPs are fetched; cross terms A'B and
        the diagonal term B'B are accumulated into cor_sum on the torch device.
        """
        m, n = self.m, self.n
        block_sizes = np.array(np.arange(m) - block_left)
        block_sizes = np.ceil(block_sizes / c) * c
        if annot is None:
            annot = np.ones((m, 1), dtype="float32")
        else:
            # annot = annot.astype("float32") # Ensure annot is float32
            annot_m = annot.shape[0]
            if annot_m != self.m:
                raise ValueError("Incorrect number of SNPs in annot")

        n_a = annot.shape[1]  # number of annotations
        cor_sum = np.zeros((m, n_a), dtype="float32")
        # b = index of first SNP for which SNP 0 is not included in LD Score
        b = np.nonzero(block_left > 0)
        if np.any(b):
            b = b[0][0]
        else:
            b = m
        b = int(np.ceil(b / c) * c)  # round up to a multiple of c
        if b > m:
            c = 1
            b = m

        l_A = 0  # l_A := index of leftmost SNP in matrix A

        device = torch_device()
        A = torch.from_numpy(snp_getter(b)).to(device)  # This now returns float32 data
        cor_sum = torch.from_numpy(cor_sum).to(device)
        annot = torch.from_numpy(annot).to(device)
        rfuncAB = torch.zeros((b, c), dtype=torch.float32, device=device)
        rfuncBB = torch.zeros((c, c), dtype=torch.float32, device=device)

        # chunk inside of block
        for l_B in np.arange(0, b, c):  # l_B := index of leftmost SNP in matrix B
            B = A[:, l_B : l_B + c]
            # ld matrix
            torch.mm(A.T, B / n, out=rfuncAB)
            # ld matrix square
            rfuncAB = func(rfuncAB)
            cor_sum[l_A : l_A + b, :] += torch.mm(rfuncAB, annot[l_B : l_B + c, :].float())

        # chunk to right of block
        b0 = b
        md = int(c * np.floor(m / c))
        end = md + 1 if md != m else md
        for l_B in tqdm(np.arange(b0, end, c), desc="Compute SNP Gene Weight"):
            # check if the annot matrix is all zeros for this block + chunk
            # this happens w/ sparse categories (i.e., pathways)
            # update the block
            old_b = b
            b = int(block_sizes[l_B])
            if l_B > b0 and b > 0:
                # block_size can't increase more than c
                # block_size can't be less than c unless it is zero
                # both of these things make sense
                A = torch.hstack((A[:, old_b - b + c : old_b], B))
                l_A += old_b - b + c
            elif l_B == b0 and b > 0:
                A = A[:, b0 - b : b0]
                l_A = b0 - b
            elif b == 0:  # no SNPs to left in window, e.g., after a sequence gap
                A = torch.zeros((n, 0), dtype=torch.float32, device=device)
                l_A = l_B
            if l_B == md:
                # Final (ragged) chunk: shrink c to the remaining SNP count.
                c = m - md
                rfuncAB = torch.zeros((b, c), dtype=torch.float32, device=device)
                rfuncBB = torch.zeros((c, c), dtype=torch.float32, device=device)
            if b != old_b:
                rfuncAB = torch.zeros((b, c), dtype=torch.float32, device=device)

            B = torch.from_numpy(snp_getter(c)).to(device)  # This now returns float32 data

            annot_l_A = annot[l_A : l_A + b, :].float()
            annot_l_B = annot[l_B : l_B + c, :].float()
            p1 = torch.all(annot_l_A == 0)
            p2 = torch.all(annot_l_B == 0)
            if p1 and p2:
                continue

            B_n = B / n

            rfuncAB = func(torch.mm(A.T, B_n))
            cor_sum[l_A : l_A + b, :] += torch.mm(rfuncAB, annot_l_B)
            cor_sum[l_B : l_B + c, :] += torch.mm(annot_l_A.T, rfuncAB).T

            rfuncBB = func(torch.mm(B.T, B_n))
            cor_sum[l_B : l_B + c, :] += torch.mm(rfuncBB, annot_l_B)

        torch_sync()

        return cor_sum.cpu().numpy()
|