gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,610 @@
1
+ """
2
+ Module for reading and processing PLINK genotype data and calculating LD scores.
3
+
4
+ Note:
5
+ This code is adapted and modified from:
6
+ https://github.com/bulik/ldsc/blob/master/ldsc/ldscore.py
7
+ """
8
+
9
+ import logging
10
+
11
+ import bitarray as ba
12
+ import numba
13
+ import numpy as np
14
+ import pandas as pd
15
+ import pyranges as pr
16
+ import torch
17
+ from tqdm import tqdm
18
+
19
+ from gsMap.utils.torch_utils import torch_device, torch_sync
20
+
21
+ # Configure logger
22
+ logger = logging.getLogger("gsMap.utils.plink_ldscore_tool")
23
+
24
+
25
@numba.njit
def getBlockLefts(coords: np.ndarray, max_dist: float):
    """
    For each SNP, find the index of the leftmost SNP whose coordinate lies
    within ``max_dist`` of it.

    ``coords`` is assumed to be sorted in ascending order, so the left cursor
    only ever moves forward across the whole scan (two-pointer sweep).

    Returns an array ``block_left`` of length ``len(coords)`` where
    ``block_left[i]`` is the index of the leftmost SNP included in SNP i's block.
    """
    total = len(coords)
    block_left = np.zeros(total)
    left = 0  # persistent cursor: never rewinds because coords are sorted
    for i in range(total):
        # advance the cursor past every SNP farther than max_dist from SNP i
        while left < total and abs(coords[left] - coords[i]) > max_dist:
            left += 1
        block_left[i] = left
    return block_left
40
+
41
+
42
@numba.njit
def normalized_snps(X: np.ndarray, b: int, minorRef, freq, currentSNP):
    """
    Normalize the SNPs and impute the missing ones with the mean.

    Parameters
    ----------
    X : np.ndarray
        n x b matrix of raw genotype dosages, with 9 marking a missing call.
        NOTE: missing entries are imputed in place (column views are mutated),
        so the caller's X is modified.
    b : int
        Number of SNP columns to process.
    minorRef : optional
        If not None, flip the sign of any SNP whose allele frequency exceeds
        0.5, so the normalized dosage counts the minor allele.
    freq : np.ndarray
        Per-SNP allele frequencies for the full panel.
    currentSNP : int
        Offset of X's first column within ``freq``.

    Returns
    -------
    np.ndarray
        n x b float32 matrix of mean-imputed, standardized genotypes.
    """
    Y = np.zeros(X.shape, dtype="float32")

    for j in range(0, b):
        newsnp = X[:, j]  # view into X: imputation below writes through to X
        ii = newsnp != 9
        avg = np.mean(newsnp[ii])
        newsnp[np.logical_not(ii)] = avg
        denom = np.std(newsnp)
        if denom == 0:
            # constant column: avoid division by zero, leave values at 0
            denom = 1

        if minorRef is not None and freq[currentSNP + j] > 0.5:
            # flip sign so the minor allele is the reference allele
            denom = denom * -1

        Y[:, j] = (newsnp - avg) / denom
    return Y
73
+
74
+
75
def l2_unbiased(x: torch.Tensor, n: int):
    """
    Unbiased estimator of L2 from a sample correlation:
    ``r^2 - (1 - r^2) / (n - 2)``.

    The denominator falls back to ``n`` itself when ``n <= 2`` so tiny sample
    sizes (used in tests) do not divide by zero or go negative.
    """
    if n > 2:
        divisor = n - 2
    else:
        divisor = n  # allow n < 2 for testing purposes
    squared = torch.square(x)
    return squared - (1 - squared) / divisor
82
+
83
+
84
class PlinkBEDFile:
    """
    Interface for Plink .bed format for reading and processing genotype data.

    Construction loads the .bim/.fam/.bed triple, computes per-SNP QC
    statistics, and drops SNPs that fail the basic validity check from both
    the genotype bitarray and the BIM table.
    """

    def __init__(self, bfile_prefix):
        """
        Initialize the PlinkBEDFile from a PLINK file prefix.

        Parameters
        ----------
        bfile_prefix : str
            PLINK file prefix (without .bed/.bim/.fam extension)
        """
        # Initialize bitarray for bed code mapping.
        # Keys are the dosage codes used elsewhere in this module (9 = missing);
        # values are the 2-bit little-endian patterns found in the .bed stream.
        self._bedcode = {
            2: ba.bitarray("11"),
            9: ba.bitarray("10"),
            1: ba.bitarray("01"),
            0: ba.bitarray("00"),
        }

        # Load BIM file
        self.bim_df = self.load_bim(f"{bfile_prefix}.bim")

        # Load FAM file
        self.fam_df = self.load_fam(f"{bfile_prefix}.fam")

        # Set up initial parameters: SNP and individual counts before QC
        self.m_original = len(self.bim_df)
        self.n_original = len(self.fam_df)

        # Read the bed file (nru = n rounded up to a multiple of 4, the
        # per-SNP storage unit of the SNP-major .bed layout)
        logger.info(f"Loading Plink genotype data from {bfile_prefix}.bed")
        (self.nru_original, self.geno_original) = self._read(
            f"{bfile_prefix}.bed", self.m_original, self.n_original
        )

        # Pre-calculate MAF for all SNPs
        logger.info("Calculating MAF and QC for all SNPs")
        self.all_snp_info = self._calculate_all_snp_info()

        # Filter out invalid SNPs (those failing the het/missing-count check)
        valid_mask = self.all_snp_info["valid_snp"]
        if num_invalid := np.sum(~valid_mask):
            logger.warning(
                f"Filtering out {num_invalid} bad quality SNPs: {self.bim_df.loc[~valid_mask, 'SNP'].tolist()}"
            )
        else:
            logger.info("All SNPs passed the basic quality check")

        # Create new genotype data with only the valid SNPs: each SNP occupies
        # 2 * nru bits, so slice those runs out of the original bitarray
        new_geno = ba.bitarray()
        for j in np.arange(self.m_original)[valid_mask]:
            new_geno += self.geno_original[
                2 * self.nru_original * j : 2 * self.nru_original * (j + 1)
            ]

        # Update original data to only include valid SNPs
        self.geno_original = new_geno

        # Only keep valid SNPs in the BIM table as well
        self.bim_df = self.bim_df.loc[valid_mask].reset_index(drop=True)
        self.m_original = len(self.bim_df)
        self.kept_snps = np.arange(self.m_original)

        # Initialize current (filterable) state variables; apply_filters()
        # resets these back from the *_original copies
        self._currentSNP = 0
        self.m = self.m_original
        self.n = self.n_original
        self.nru = self.nru_original
        self.geno = self.geno_original.copy()

        # Update frequency info based on valid SNPs
        self.freq = self.all_snp_info["freq"][valid_mask]
        self.maf = np.minimum(self.freq, 1 - self.freq)
        self.sqrtpq = np.sqrt(self.freq * (1 - self.freq))

        # Add MAF to the BIM dataframe
        self.bim_df["MAF"] = self.maf

        logger.info(f"Loaded genotype data with {self.m} SNPs and {self.n} individuals")
166
+
167
+ @staticmethod
168
+ def load_bim(bim_file):
169
+ """
170
+ Load a BIM file into a pandas DataFrame.
171
+
172
+ Parameters
173
+ ----------
174
+ bim_file : str
175
+ Path to the BIM file
176
+
177
+ Returns
178
+ -------
179
+ pd.DataFrame
180
+ DataFrame containing BIM data
181
+ """
182
+ df = pd.read_csv(
183
+ bim_file, sep="\t", header=None, names=["CHR", "SNP", "CM", "BP", "A1", "A2"]
184
+ )
185
+ return df
186
+
187
@staticmethod
def convert_bim_to_pyrange(bim_df) -> pr.PyRanges:
    """
    Convert a BIM DataFrame (including the MAF column added at load time)
    into a PyRanges object with chr-prefixed chromosome names.
    """
    ranges_df = bim_df.copy()
    ranges_df.drop(columns=["MAF"], inplace=True)
    ranges_df.columns = ["Chromosome", "SNP", "CM", "Start", "A1", "A2"]
    ranges_df.Chromosome = "chr" + ranges_df["Chromosome"].astype(str)

    # BIM positions are 1-based; PyRanges expects half-open 0-based intervals
    ranges_df["End"] = ranges_df["Start"].copy()
    ranges_df["Start"] = ranges_df["Start"] - 1

    return pr.PyRanges(ranges_df)
201
+
202
+ @staticmethod
203
+ def load_fam(fam_file):
204
+ """
205
+ Load a FAM file into a pandas DataFrame.
206
+
207
+ Parameters
208
+ ----------
209
+ fam_file : str
210
+ Path to the FAM file
211
+
212
+ Returns
213
+ -------
214
+ pd.DataFrame
215
+ DataFrame containing FAM data
216
+ """
217
+ df = pd.read_csv(fam_file, sep=r"\s+", header=None, usecols=[1], names=["IID"])
218
+ return df
219
+
220
def _read(self, fname, m, n):
    """
    Read a PLINK .bed file and return the genotype data.

    Parameters
    ----------
    fname : str
        Path to the .bed file.
    m : int
        Expected number of SNPs.
    n : int
        Number of individuals.

    Returns
    -------
    tuple
        (nru, geno) where nru is n rounded up to a multiple of 4 and geno is
        the little-endian genotype bitarray.

    Raises
    ------
    ValueError
        If the filename does not end in .bed.
    OSError
        If the magic number, mode byte, or file length is wrong.
    """
    if not fname.endswith(".bed"):
        raise ValueError(".bed filename must end in .bed")

    # Use a context manager so the handle is closed even when the
    # magic-number / mode checks below raise (the original leaked it).
    with open(fname, "rb") as fh:
        magicNumber = ba.bitarray(endian="little")
        magicNumber.fromfile(fh, 2)
        bedMode = ba.bitarray(endian="little")
        bedMode.fromfile(fh, 1)
        # each SNP's genotypes are padded to a whole number of bytes
        e = (4 - n % 4) if n % 4 != 0 else 0
        nru = n + e

        # Check magic number
        if magicNumber != ba.bitarray("0011011011011000"):
            raise OSError("Magic number from Plink .bed file not recognized")

        if bedMode != ba.bitarray("10000000"):
            raise OSError("Plink .bed file must be in default SNP-major mode")

        # Read the remainder of the file as genotype bits
        geno = ba.bitarray(endian="little")
        geno.fromfile(fh)

    # Check file length against the expected 2 bits per individual per SNP
    self._test_length(geno, m, nru)
    return (nru, geno)
247
+
248
+ def _test_length(self, geno, m, nru):
249
+ """
250
+ Test if the genotype data has the expected length.
251
+ """
252
+ exp_len = 2 * m * nru
253
+ real_len = len(geno)
254
+ if real_len != exp_len:
255
+ s = "Plink .bed file has {n1} bits, expected {n2}"
256
+ raise OSError(s.format(n1=real_len, n2=exp_len))
257
+
258
def _calculate_all_snp_info(self):
    """
    Pre-calculate MAF and other information for all SNPs.

    Works directly on the 2-bit PLINK genotype stream: for each SNP, the two
    interleaved bit-planes are counted to recover allele frequency, the
    het-or-missing count, and a basic validity flag.

    Returns
    -------
    dict
        Dictionary with per-SNP arrays:
        - "freq": major-allele frequency among non-missing individuals
        - "het_miss_count": count of heterozygous or missing genotypes
        - "valid_snp": True when the SNP is not entirely het/missing
    """
    nru = self.nru_original
    n = self.n_original
    m = self.m_original
    geno = self.geno_original

    snp_info = {
        "freq": np.zeros(m),  # Allele frequencies
        "het_miss_count": np.zeros(m),  # Count of het or missing genotypes
        "valid_snp": np.zeros(m, dtype=bool),  # Whether SNP passes basic criteria
    }

    # For each SNP, calculate statistics from its 2*nru-bit slice
    for j in range(m):
        z = geno[2 * nru * j : 2 * nru * (j + 1)]
        A = z[0::2]  # first bit of each 2-bit genotype code
        a = A.count()
        B = z[1::2]  # second bit of each 2-bit genotype code
        b = B.count()
        c = (A & B).count()  # individuals with both bits set ("11" = hom major)
        major_ct = b + c  # number of copies of the major allele
        n_nomiss = n - a + c  # number of individuals with nonmissing genotypes
        f = major_ct / (2 * n_nomiss) if n_nomiss > 0 else 0
        het_miss_ct = a + b - 2 * c  # count of SNPs that are het or missing

        snp_info["freq"][j] = f
        snp_info["het_miss_count"][j] = het_miss_ct
        snp_info["valid_snp"][j] = het_miss_ct < n  # Basic validity check

    return snp_info
296
+
297
def apply_filters(self, keep_snps=None, keep_indivs=None, mafMin=None):
    """
    Apply filters to the genotype data without reloading the bed file.

    The object is first reset to its post-QC state, then MAF, SNP-index and
    individual-index filters are applied in that order.

    Parameters
    ----------
    keep_snps : array-like, optional
        Indices of SNPs to keep (relative to the post-QC SNP set).
    keep_indivs : array-like, optional
        Indices of individuals to keep.
    mafMin : float, optional
        Minimum minor allele frequency.

    Returns
    -------
    self
        Returns self for method chaining.

    Raises
    ------
    ValueError
        If any index is out of bounds, or no individuals remain after filtering.
    """
    # Reset to original state first
    self.geno = self.geno_original.copy()
    self.m = self.m_original
    self.n = self.n_original
    self.nru = self.nru_original
    self._currentSNP = 0

    # Initialize with all SNPs
    kept_snps = np.arange(self.m_original)

    # Apply MAF filter using pre-calculated values
    if mafMin is not None and mafMin > 0:
        # NOTE(review): self.maf reflects the *previous* filter state; after an
        # earlier apply_filters(keep_snps=...) its length can differ from
        # m_original. Confirm callers only pass mafMin on a fresh/reset object.
        maf_mask = self.maf > mafMin
        kept_snps = kept_snps[maf_mask]
        logger.info(f"After MAF filtering (>{mafMin}), {len(kept_snps)} SNPs remain")

    # Apply SNP filter if specified
    if keep_snps is not None:
        keep_snps = np.array(keep_snps, dtype="int")
        # Valid indices are 0..m_original-1; the original `>` check let
        # index == m_original slip through and be silently dropped below.
        if np.any(keep_snps >= self.m_original):
            raise ValueError("keep_snps indices out of bounds")

        # Intersect with current kept_snps
        kept_snps = np.intersect1d(kept_snps, keep_snps)
        logger.info(f"After keep_snps filtering, {len(kept_snps)} SNPs remain")

    # Filter SNPs in the genotype data
    if len(kept_snps) < self.m_original:
        # Create new genotype data with only the kept SNPs
        new_geno = ba.bitarray()
        for j in kept_snps:
            new_geno += self.geno_original[2 * self.nru * j : 2 * self.nru * (j + 1)]
        self.geno = new_geno
        self.m = len(kept_snps)

    # Filter individuals if specified
    if keep_indivs is not None:
        keep_indivs = np.array(keep_indivs, dtype="int")
        # The original `>` check allowed index == n, which would silently
        # slice past the end of the bitarray and yield zeroed genotypes.
        if np.any(keep_indivs >= self.n):
            raise ValueError("keep_indivs indices out of bounds")

        (self.geno, self.m, self.n) = self._filter_indivs(
            self.geno, keep_indivs, self.m, self.n
        )

        if self.n > 0:
            logger.info(f"After filtering, {self.n} individuals remain")
        else:
            raise ValueError("After filtering, no individuals remain")

    # Update kept_snps and per-SNP frequency attributes to the kept subset
    self.kept_snps = kept_snps
    self.freq = self.all_snp_info["freq"][kept_snps]
    self.maf = np.minimum(self.freq, 1 - self.freq)
    self.sqrtpq = np.sqrt(self.freq * (1 - self.freq))

    return self
373
+
374
def _filter_indivs(self, geno, keep_indivs, m, n):
    """
    Subset the genotype bitarray to the given individual indices.

    Returns (new_geno, m, n_new). NOTE: also updates self.nru to the padded
    width of the new individual count as a side effect.
    """
    n_new = len(keep_indivs)
    # pad the new individual count up to a multiple of 4 (byte alignment)
    pad = (4 - n_new % 4) if n_new % 4 != 0 else 0
    nru_new = n_new + pad
    nru_old = self.nru

    subset = ba.bitarray(m * 2 * nru_new, endian="little")
    subset.setall(0)
    # copy both bits of every kept individual across all SNPs at once,
    # using strided slices (stride = one SNP's storage width)
    for dst, src in enumerate(keep_indivs):
        subset[2 * dst :: 2 * nru_new] = geno[2 * src :: 2 * nru_old]
        subset[2 * dst + 1 :: 2 * nru_new] = geno[2 * src + 1 :: 2 * nru_old]

    self.nru = nru_new
    return (subset, m, n_new)
389
+
390
def get_snps_by_maf(self, mafMin):
    """
    Get the list of SNPs that pass the MAF threshold.

    Parameters
    ----------
    mafMin : float
        Minimum MAF threshold (exclusive).

    Returns
    -------
    list
        List of SNP IDs that pass the MAF threshold
    """
    maf_mask = self.maf > mafMin

    # Get SNP names from the BIM dataframe
    snp_pass_maf = self.bim_df.loc[maf_mask, "SNP"].tolist()

    # Fixed log message: the original read "MAF > f{mafMin}" — the stray
    # literal "f" belonged to a misplaced f-string prefix.
    logger.info(f"{len(snp_pass_maf)} SNPs with MAF > {mafMin}")

    return snp_pass_maf
412
+
413
def get_ldscore(self, annot_matrix=None, ld_wind=1.0, ld_unit="CM", keep_snps_index=None):
    """
    Calculate LD scores using an annotation matrix.

    Parameters
    ----------
    annot_matrix : np.ndarray, optional
        Annotation matrix. If None, uses a matrix of all ones.
    ld_wind : float, optional
        LD window size, by default 1.0
    ld_unit : str, optional
        Unit for the LD window ("SNP", "KB" or "CM"), by default "CM"
    keep_snps_index : list[int], optional
        Indices of SNPs to keep, by default None

    Returns
    -------
    np.ndarray
        Array with calculated LD scores
    """
    # Temporarily narrow the SNP set if requested, remembering how to restore it
    restore_snps = None
    if keep_snps_index is not None:
        restore_snps = self.kept_snps.copy()
        self.apply_filters(keep_snps=keep_snps_index)

    # Translate the window size into coordinates in the requested unit
    if ld_unit == "SNP":
        window = ld_wind
        positions = np.array(range(self.m))
    elif ld_unit == "KB":
        window = ld_wind * 1000
        positions = np.array(self.bim_df.loc[self.kept_snps, "BP"])
    elif ld_unit == "CM":
        window = ld_wind
        positions = np.array(self.bim_df.loc[self.kept_snps, "CM"])
        # Genetic map absent (all-zero CM column): fall back to 1MB physical window
        if np.all(positions == 0):
            logger.warning(
                "All CM values are 0. Using 1MB window size for LD score calculation."
            )
            window = 1_000_000
            positions = np.array(self.bim_df.loc[self.kept_snps, "BP"])
    else:
        raise ValueError(f"Invalid ld_wind_unit: {ld_unit}. Must be one of: SNP, KB, CM")

    # Left boundary of each SNP's LD block
    block_left = getBlockLefts(positions, window)
    assert block_left.sum() > 0, "Invalid window size, please check the ld_wind parameter."

    # Calculate LD scores in variable-size blocks of 100 SNPs
    ld_scores = self.ldScoreVarBlocks(block_left, 100, annot=annot_matrix)

    # Restore original state if filters were applied
    if restore_snps is not None:
        self.apply_filters(keep_snps=restore_snps)

    return ld_scores
470
+
471
def restart(self):
    """
    Reset the current SNP index to 0.

    After this call, nextSNPs() reads again from the first kept SNP.
    """
    self._currentSNP = 0
476
+
477
def nextSNPs(self, b, minorRef=None):
    """
    Unpacks the binary array of genotypes and returns an n x b matrix of floats of
    normalized genotypes for the next b SNPs, advancing the internal cursor.

    Parameters
    ----------
    b : int
        Number of SNPs to read; must be > 0 and no more than the SNPs remaining.
    minorRef : optional
        If not None, flip normalization so the minor allele is the reference.

    Returns
    -------
    np.ndarray
        n x b float32 matrix of normalized genotypes.

    Raises
    ------
    TypeError
        If b is not coercible to int.
    ValueError
        If b <= 0 or more SNPs are requested than remain.
    """
    try:
        b = int(b)
        if b <= 0:
            raise ValueError("b must be > 0")
    except TypeError as e:
        raise TypeError("b must be an integer") from e

    if self._currentSNP + b > self.m:
        s = "{b} SNPs requested, {k} SNPs remain"
        raise ValueError(s.format(b=b, k=(self.m - self._currentSNP)))

    c = self._currentSNP
    n = self.n
    nru = self.nru
    # Renamed from `slice`, which shadowed the builtin of the same name.
    geno_slice = self.geno[2 * c * nru : 2 * (c + b) * nru]
    # decode 2-bit codes to dosages (9 = missing), then trim padding rows
    X = np.array(geno_slice.decode(self._bedcode), dtype="float32").reshape((b, nru)).T
    X = X[0:n, :]
    Y = normalized_snps(X, b, minorRef, self.freq, self._currentSNP)

    self._currentSNP += b
    return Y
503
+
504
def ldScoreVarBlocks(self, block_left: np.ndarray, c, annot=None):
    """
    Computes an unbiased estimate of L2(j) for j=1,..,M.

    Thin wrapper around _corSumVarBlocks: supplies the unbiased-L2 transform
    (bound to this panel's sample size) and the SNP reader.
    """

    def unbiased_l2(x):
        return l2_unbiased(x, self.n)

    return self._corSumVarBlocks(block_left, c, unbiased_l2, self.nextSNPs, annot)
514
+
515
def _corSumVarBlocks(self, block_left, c, func, snp_getter, annot=None):
    """
    Calculate the sum of correlation coefficients.

    Sliding-block algorithm (adapted from ldsc): SNPs are streamed in chunks
    of ``c`` via ``snp_getter``; ``func`` transforms each correlation block
    (here the unbiased-L2 estimator) and the results are accumulated into
    ``cor_sum``, weighted by ``annot``. Matrix products run on the device
    returned by torch_device().

    Parameters
    ----------
    block_left : np.ndarray
        Per-SNP index of the leftmost SNP inside its LD window.
    c : int
        Chunk size (number of SNPs read per step).
    func : callable
        Element-wise transform applied to each correlation block.
    snp_getter : callable
        Function returning the next k normalized SNP columns (n x k).
    annot : np.ndarray, optional
        m x n_a annotation matrix; defaults to a single all-ones annotation.
    """
    m, n = self.m, self.n
    block_sizes = np.array(np.arange(m) - block_left)
    # round each block size up to a multiple of the chunk size
    block_sizes = np.ceil(block_sizes / c) * c
    if annot is None:
        annot = np.ones((m, 1), dtype="float32")
    else:
        # annot = annot.astype("float32") # Ensure annot is float32
        annot_m = annot.shape[0]
        if annot_m != self.m:
            raise ValueError("Incorrect number of SNPs in annot")

    n_a = annot.shape[1]  # number of annotations
    cor_sum = np.zeros((m, n_a), dtype="float32")
    # b = index of first SNP for which SNP 0 is not included in LD Score
    b = np.nonzero(block_left > 0)
    if np.any(b):
        b = b[0][0]
    else:
        b = m
    b = int(np.ceil(b / c) * c)  # round up to a multiple of c
    if b > m:
        c = 1
        b = m

    l_A = 0  # l_A := index of leftmost SNP in matrix A

    device = torch_device()
    A = torch.from_numpy(snp_getter(b)).to(device)  # This now returns float32 data
    cor_sum = torch.from_numpy(cor_sum).to(device)
    annot = torch.from_numpy(annot).to(device)
    rfuncAB = torch.zeros((b, c), dtype=torch.float32, device=device)
    rfuncBB = torch.zeros((c, c), dtype=torch.float32, device=device)

    # chunk inside of block
    for l_B in np.arange(0, b, c):  # l_B := index of leftmost SNP in matrix B
        B = A[:, l_B : l_B + c]
        # ld matrix
        torch.mm(A.T, B / n, out=rfuncAB)
        # ld matrix square
        rfuncAB = func(rfuncAB)
        cor_sum[l_A : l_A + b, :] += torch.mm(rfuncAB, annot[l_B : l_B + c, :].float())

    # chunk to right of block
    b0 = b
    md = int(c * np.floor(m / c))
    end = md + 1 if md != m else md
    for l_B in tqdm(np.arange(b0, end, c), desc="Compute SNP Gene Weight"):
        # check if the annot matrix is all zeros for this block + chunk
        # this happens w/ sparse categories (i.e., pathways)
        # update the block
        old_b = b
        b = int(block_sizes[l_B])
        if l_B > b0 and b > 0:
            # block_size can't increase more than c
            # block_size can't be less than c unless it is zero
            # both of these things make sense
            A = torch.hstack((A[:, old_b - b + c : old_b], B))
            l_A += old_b - b + c
        elif l_B == b0 and b > 0:
            A = A[:, b0 - b : b0]
            l_A = b0 - b
        elif b == 0:  # no SNPs to left in window, e.g., after a sequence gap
            A = torch.zeros((n, 0), dtype=torch.float32, device=device)
            l_A = l_B
        if l_B == md:
            # final, possibly ragged chunk: shrink c to the remainder
            c = m - md
            rfuncAB = torch.zeros((b, c), dtype=torch.float32, device=device)
            rfuncBB = torch.zeros((c, c), dtype=torch.float32, device=device)
        if b != old_b:
            rfuncAB = torch.zeros((b, c), dtype=torch.float32, device=device)

        B = torch.from_numpy(snp_getter(c)).to(device)  # This now returns float32 data

        annot_l_A = annot[l_A : l_A + b, :].float()
        annot_l_B = annot[l_B : l_B + c, :].float()
        p1 = torch.all(annot_l_A == 0)
        p2 = torch.all(annot_l_B == 0)
        if p1 and p2:
            # both annotation slices are zero: this chunk contributes nothing
            continue

        B_n = B / n

        # cross-block contribution (A vs B), accumulated symmetrically
        rfuncAB = func(torch.mm(A.T, B_n))
        cor_sum[l_A : l_A + b, :] += torch.mm(rfuncAB, annot_l_B)
        cor_sum[l_B : l_B + c, :] += torch.mm(annot_l_A.T, rfuncAB).T

        # within-chunk contribution (B vs B)
        rfuncBB = func(torch.mm(B.T, B_n))
        cor_sum[l_B : l_B + c, :] += torch.mm(rfuncBB, annot_l_B)

    torch_sync()

    return cor_sum.cpu().numpy()