gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,446 @@
1
+ import logging
2
+ import math
3
+ import re
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from scipy.stats import chi2
8
+
9
+ from gsMap.config import FormatSumstatsConfig
10
+
11
+ VALID_SNPS = {"AC", "AG", "CA", "CT", "GA", "GT", "TC", "TG"}
12
+ logger = logging.getLogger(__name__)
13
+
14
+ default_cnames = {
15
+ # RS NUMBER
16
+ "SNP": "SNP",
17
+ "RS": "SNP",
18
+ "RSID": "SNP",
19
+ "RS_NUMBER": "SNP",
20
+ "RS_NUMBERS": "SNP",
21
+ # P-VALUE
22
+ "P": "P",
23
+ "PVALUE": "P",
24
+ "P_VALUE": "P",
25
+ "PVAL": "P",
26
+ "P_VAL": "P",
27
+ "GC_PVALUE": "P",
28
+ "p": "P",
29
+ # EFFECT_ALLELE (A1)
30
+ "A1": "A1",
31
+ "ALLELE1": "A1",
32
+ "ALLELE_1": "A1",
33
+ "EFFECT_ALLELE": "A1",
34
+ "REFERENCE_ALLELE": "A1",
35
+ "INC_ALLELE": "A1",
36
+ "EA": "A1",
37
+ # NON_EFFECT_ALLELE (A2)
38
+ "A2": "A2",
39
+ "ALLELE2": "A2",
40
+ "ALLELE_2": "A2",
41
+ "OTHER_ALLELE": "A2",
42
+ "NON_EFFECT_ALLELE": "A2",
43
+ "DEC_ALLELE": "A2",
44
+ "NEA": "A2",
45
+ # N
46
+ "N": "N",
47
+ "NCASE": "N_CAS",
48
+ "CASES_N": "N_CAS",
49
+ "N_CASE": "N_CAS",
50
+ "N_CASES": "N_CAS",
51
+ "N_CONTROLS": "N_CON",
52
+ "N_CAS": "N_CAS",
53
+ "N_CON": "N_CON",
54
+ "NCONTROL": "N_CON",
55
+ "CONTROLS_N": "N_CON",
56
+ "N_CONTROL": "N_CON",
57
+ "WEIGHT": "N",
58
+ # SIGNED STATISTICS
59
+ "ZSCORE": "Z",
60
+ "Z-SCORE": "Z",
61
+ "GC_ZSCORE": "Z",
62
+ "Z": "Z",
63
+ "OR": "OR",
64
+ "B": "BETA",
65
+ "BETA": "BETA",
66
+ "LOG_ODDS": "LOG_ODDS",
67
+ "EFFECTS": "BETA",
68
+ "EFFECT": "BETA",
69
+ "b": "BETA",
70
+ "beta": "BETA",
71
+ # SE
72
+ "se": "SE",
73
+ # INFO
74
+ "INFO": "INFO",
75
+ "Info": "INFO",
76
+ # MAF
77
+ "EAF": "FRQ",
78
+ "FRQ": "FRQ",
79
+ "MAF": "FRQ",
80
+ "FRQ_U": "FRQ",
81
+ "F_U": "FRQ",
82
+ "frq_A1": "FRQ",
83
+ "frq": "FRQ",
84
+ "freq": "FRQ",
85
+ }
86
+
87
+
88
def get_compression(fh):
    """Infer a file's compression type from its filename suffix.

    Returns "gzip" for names ending in "gz", "bz2" for names ending in
    "bz2", and None for anything else (treated as uncompressed).
    """
    name = str(fh)
    if name.endswith("gz"):
        return "gzip"
    if name.endswith("bz2"):
        return "bz2"
    return None
101
+
102
+
103
def gwas_checkname(gwas, config):
    """Standardise the column names of a GWAS summary-statistics table.

    Headers are first mapped through ``default_cnames`` (unknown headers are
    kept as-is), then user-supplied overrides from ``config`` are applied.
    When both OR and SE_OR columns are present, BETA/SE columns are derived
    on the log-odds scale.  The interpretation of each recognised column is
    logged against its original header.

    Raises:
        ValueError: if two input columns map to the same canonical name.

    Returns the (modified in place) ``gwas`` DataFrame.
    """
    old_name = gwas.columns
    mapped_cnames = {}
    for col in gwas.columns:
        # Known aliases map to canonical names; unknown headers pass through.
        mapped_cnames[col] = default_cnames.get(col, col)
    gwas.columns = list(mapped_cnames.values())

    # When column names are provided by users
    name_updates = {
        "SNP": config.snp,
        "A1": config.a1,
        "A2": config.a2,
        "INFO": config.info,
        "BETA": config.beta,
        "SE": config.se,
        "P": config.p,
        "FRQ": config.frq,
        "N": config.n,
        "Z": config.z,
        "Chr": config.chr,
        "Pos": config.pos,
        "OR": config.OR,
        "SE_OR": config.se_OR,
    }

    for key, value in name_updates.items():
        if value is not None and value in gwas.columns:
            gwas.rename(columns={value: key}, inplace=True)
    new_name = gwas.columns
    # check the name duplication
    for head in new_name:
        numc = list(new_name).count(head)
        if numc > 1:
            raise ValueError(
                f"Found {numc} different {head} columns, please check your {head} column."
            )

    # Canonical name -> original header; used only for the log report below.
    # Built positionally, which is valid because renames preserve column order.
    name_dict = {new_name[i]: old_name[i] for i in range(len(new_name))}

    # When at OR scale
    # Non-positive OR/SE_OR values become None (log undefined).
    # NOTE(review): log(SE_OR) is not the delta-method conversion of the
    # standard error (that would be SE_OR / OR) — confirm this is intended.
    if "OR" in new_name and "SE_OR" in new_name:
        gwas["BETA"] = gwas.OR.apply(lambda x: math.log(x) if x > 0 else None)
        gwas["SE"] = gwas.SE_OR.apply(lambda x: math.log(x) if x > 0 else None)

    interpreting = {
        "SNP": "Variant ID (e.g., rs number).",
        "A1": "Allele 1, interpreted as the effect allele for signed sumstat.",
        "A2": "Allele 2, interpreted as the non-effect allele for signed sumstat.",
        "BETA": "[linear/logistic] regression coefficient (0 → no effect; above 0 → A1 is trait/risk increasing).",
        "SE": "Standard error of the regression coefficient.",
        "OR": "Odds ratio, will be transferred to linear scale.",
        "SE_OR": "Standard error of the odds ratio, will be transferred to linear scale.",
        "P": "P-Value.",
        "Z": "Z-Value.",
        "N": "Sample size.",
        "INFO": "INFO score (imputation quality; higher → better imputation).",
        "FRQ": "Allele frequency of A1.",
        "Chr": "Chromsome.",
        "Pos": "SNP positions.",
    }

    logger.info("\nIterpreting column names as follows:")
    for key, _value in interpreting.items():
        if key in new_name:
            logger.info(f"{name_dict[key]}: {interpreting[key]}")

    return gwas
173
+
174
+
175
def gwas_checkformat(gwas, config):
    """Validate required columns for the requested output format and derive Z.

    For ``gsMap`` format a Z column is required or derivable (from P plus the
    sign of BETA, or from BETA/SE).  For ``COJO`` format the full
    A1/A2/FRQ/BETA/SE/P/N set is required and Z is derived from P and the
    sign of BETA.

    Raises:
        ValueError: when the columns present are insufficient for the format.

    Returns the (modified in place) ``gwas`` DataFrame.
    """
    if config.format == "gsMap":
        condition1 = np.any(np.isin(["P", "Z"], gwas.columns))
        condition2 = np.all(np.isin(["BETA", "SE"], gwas.columns))
        if not (condition1 or condition2):
            raise ValueError(
                "To munge GWAS data into gsMap format, either P or Z values, or both BETA and SE values, are required."
            )
        else:
            if "Z" in gwas.columns:
                pass
            elif "P" in gwas.columns and "BETA" in gwas.columns:
                # |Z| from the chi-square(1) quantile of P, signed by BETA.
                gwas["Z"] = np.sqrt(chi2.isf(gwas.P, 1)) * np.where(gwas["BETA"] < 0, -1, 1)
            elif "BETA" in gwas.columns and "SE" in gwas.columns:
                gwas["Z"] = gwas.BETA / gwas.SE
            else:
                # Fix: previously a P column without BETA crashed with a
                # KeyError on gwas["BETA"]; fail with a clear message instead.
                raise ValueError(
                    "To derive Z from P values, a BETA column is required to determine the sign."
                )

    elif config.format == "COJO":
        condition = np.all(np.isin(["A1", "A2", "FRQ", "BETA", "SE", "P", "N"], gwas.columns))
        if not condition:
            raise ValueError(
                "To munge GWAS data into COJO format, either A1|A2|FRQ|BETA|SE|P|N, are required."
            )
        else:
            gwas["Z"] = np.sqrt(chi2.isf(gwas.P, 1)) * np.where(gwas["BETA"] < 0, -1, 1)

    return gwas
204
+
205
+
206
def filter_info(info, config):
    """Build a boolean mask keeping SNPs with INFO >= ``config.info_min``.

    Accepts a single INFO column (Series) or several (DataFrame, in which
    case the row-wise sum must reach ``info_min`` times the column count,
    i.e. the mean must pass the threshold).  Logs a warning for INFO values
    outside [0, 2], which suggest a mislabeled column.

    Raises:
        ValueError: if ``info`` is neither a Series nor a DataFrame.
    """
    if isinstance(info, pd.Series):  # one INFO column
        jj = ((info > 2.0) | (info < 0)) & info.notnull()
        ii = info >= config.info_min
    elif isinstance(info, pd.DataFrame):  # several INFO columns
        jj = ((info > 2.0) & info.notnull()).any(axis=1) | ((info < 0) & info.notnull()).any(
            axis=1
        )
        ii = info.sum(axis=1) >= config.info_min * (len(info.columns))
    else:
        raise ValueError("Expected pd.DataFrame or pd.Series.")

    bad_info = jj.sum()
    if bad_info > 0:
        # Fix: the message previously said [0,1.5] while the check above
        # flags values outside [0,2]; the bounds now match the code.
        msg = "WARNING: {N} SNPs had INFO outside of [0,2]. The INFO column may be mislabeled."
        logger.warning(msg.format(N=bad_info))

    return ii
225
+
226
+
227
def filter_frq(frq, config):
    """Mask keeping SNPs whose minor-allele frequency exceeds ``config.maf_min``.

    Frequencies outside [0, 1] are dropped and reported as likely mislabeling.
    """
    out_of_bounds = (frq < 0) | (frq > 1)
    n_bad = out_of_bounds.sum()
    if n_bad > 0:
        warning = "WARNING: {N} SNPs had FRQ outside of [0,1]. The FRQ column may be mislabeled."
        logger.warning(warning.format(N=n_bad))

    # Fold allele frequency of A1 to the minor-allele frequency.
    maf = np.minimum(frq, 1 - frq)
    keep = maf > config.maf_min
    return keep & ~out_of_bounds
240
+
241
+
242
def filter_pvals(P, config):
    """Mask keeping P-values inside the valid interval (0, 1]."""
    in_bounds = (P > 0) & (P <= 1)
    n_bad = (~in_bounds).sum()
    if n_bad > 0:
        warning = "WARNING: {N} SNPs had P outside of (0,1]. The P column may be mislabeled."
        logger.warning(warning.format(N=n_bad))

    return in_bounds
251
+
252
+
253
def filter_alleles(a):
    """Mask keeping strand-unambiguous SNP allele pairs.

    ``a`` is expected to be a Series of concatenated A1+A2 strings (e.g.
    "AC"); pairs not in ``VALID_SNPS`` (A/T, C/G, indels) are dropped.
    """
    return a.isin(VALID_SNPS)
256
+
257
+
258
def gwas_qc(gwas, config):
    """Quality-control filtering of GWAS summary statistics.

    Sequentially removes SNPs with: missing values (INFO may be missing),
    INFO below ``config.info_min``, MAF at or below ``config.maf_min``,
    out-of-bounds P-values, strand-ambiguous alleles, duplicated rs numbers,
    and sample size below the 90th percentile of N divided by 1.5.

    Returns the filtered DataFrame.
    """
    old = len(gwas)
    logger.info("\nFiltering SNPs as follows:")
    # Per-filter counters for the number of SNPs dropped.
    drops = {"NA": 0, "P": 0, "INFO": 0, "FRQ": 0, "A": 0, "SNP": 0, "Dup": 0, "N": 0}

    # filter: SNPs with missing values (every column except INFO must be set)
    gwas = gwas.dropna(
        axis=0, how="any", subset=[col for col in gwas.columns if col != "INFO"]
    ).reset_index(drop=True)

    drops["NA"] = old - len(gwas)
    logger.info(f"Removed {drops['NA']} SNPs with missing values.")

    # filter: SNPs with low imputation quality
    if "INFO" in gwas.columns:
        old = len(gwas)
        gwas = gwas.loc[filter_info(gwas["INFO"], config)]
        drops["INFO"] = old - len(gwas)
        # Message now reports the configured threshold instead of a
        # hardcoded 0.9.
        logger.info(f"Removed {drops['INFO']} SNPs with INFO < {config.info_min}.")

    # filter: SNPs with low minor-allele frequency
    if "FRQ" in gwas.columns:
        old = len(gwas)
        gwas = gwas.loc[filter_frq(gwas["FRQ"], config)]
        drops["FRQ"] += old - len(gwas)
        logger.info(f"Removed {drops['FRQ']} SNPs with MAF <= {config.maf_min}.")

    # filter: P-value that out-of-bounds (0,1]
    if "P" in gwas.columns:
        old = len(gwas)
        gwas = gwas.loc[filter_pvals(gwas["P"], config)]
        drops["P"] += old - len(gwas)
        logger.info(f"Removed {drops['P']} SNPs with out-of-bounds p-values.")

    # filter: Variants that are strand-ambiguous
    if "A1" in gwas.columns and "A2" in gwas.columns:
        # Fix: reset the baseline here; previously `old` was left over from
        # the preceding filter, so drops["A"] double-counted its removals.
        old = len(gwas)
        gwas.A1 = gwas.A1.str.upper()
        gwas.A2 = gwas.A2.str.upper()
        gwas = gwas.loc[filter_alleles(gwas.A1 + gwas.A2)]
        drops["A"] += old - len(gwas)
        logger.info(f"Removed {drops['A']} variants that were not SNPs or were strand-ambiguous.")

    # filter: Duplicated rs numbers
    if "SNP" in gwas.columns:
        old = len(gwas)
        gwas = gwas.drop_duplicates(subset="SNP").reset_index(drop=True)
        drops["Dup"] += old - len(gwas)
        logger.info(f"Removed {drops['Dup']} SNPs with duplicated rs numbers.")

    # filter: Sample size — drop SNPs genotyped on far fewer samples than the bulk
    n_min = gwas.N.quantile(0.9) / 1.5
    old = len(gwas)
    gwas = gwas[gwas.N >= n_min].reset_index(drop=True)
    drops["N"] += old - len(gwas)
    logger.info(f"Removed {drops['N']} SNPs with N < {n_min}.")

    return gwas
318
+
319
+
320
def variant_to_rsid(gwas, config):
    """Map (Chr, Pos) variant ids to rsids using a dbSNP reference table.

    Streams ``config.dbsnp`` (tab-separated columns: chr, pos, ref, alt,
    dbsnp; first row skipped as a header) in chunks of ``config.chunksize``
    rows and collects rows whose "<prefix><chr>_<pos>" id occurs in
    ``gwas["id"]``.

    Returns a DataFrame indexed by id with columns ``dbsnp`` and ``id``,
    de-duplicated on both columns.
    """
    logger.info("\nConverting the SNP position to rsid. This process may take some time.")
    unique_ids = set(gwas["id"])
    # Infer the non-numeric chromosome prefix used by the GWAS (e.g. "chr"
    # for values like "chr1", or "" for plain numeric chromosomes).
    chr_format = gwas["Chr"].unique().astype(str)
    # Fix: use the first unique value; indexing [1] raised IndexError for
    # single-chromosome input and arbitrarily preferred the second otherwise.
    chr_format = [re.sub(r"\d+", "", value) for value in chr_format][0]

    dtype = {"chr": str, "pos": str, "ref": str, "alt": str, "dbsnp": str}
    chunk_iter = pd.read_csv(
        config.dbsnp,
        chunksize=config.chunksize,
        sep="\t",
        skiprows=1,
        dtype=dtype,
        names=["chr", "pos", "ref", "alt", "dbsnp"],
    )

    # Collect matches per chunk and concatenate once at the end — avoids the
    # quadratic cost of re-concatenating a growing DataFrame every iteration.
    matches = []
    for chunk in chunk_iter:
        chunk["id"] = chr_format + chunk["chr"] + "_" + chunk["pos"]
        matches.append(chunk[chunk["id"].isin(unique_ids)][["dbsnp", "id"]])

    if matches:
        matching_id = pd.concat(matches)
    else:
        matching_id = pd.DataFrame(columns=["dbsnp", "id"])

    matching_id = matching_id.drop_duplicates(subset="dbsnp").reset_index(drop=True)
    matching_id = matching_id.drop_duplicates(subset="id").reset_index(drop=True)
    matching_id.index = matching_id.id
    return matching_id
351
+
352
+
353
def clean_SNP_id(gwas, config):
    """Ensure the GWAS table has an rsid SNP column.

    If a SNP column is already present the table is returned unchanged.
    Otherwise Chr/Pos pairs are converted to rsids via the dbSNP reference
    (``config.dbsnp``); rows whose position has no dbSNP match are dropped.

    Raises:
        ValueError: if neither SNP nor Chr+Pos columns exist, or if Chr/Pos
            conversion is needed but ``config.dbsnp`` is not provided.
    """
    old = len(gwas)
    condition1 = "SNP" in gwas.columns
    condition2 = np.all(np.isin(["Chr", "Pos"], gwas.columns))

    if not (condition1 or condition2):
        raise ValueError("Either SNP rsid, or both SNP chromosome and position, are required.")
    elif condition1:
        # rsids already present — nothing to do.
        pass
    elif condition2:
        if config.dbsnp is None:
            raise ValueError("To Convert SNP positions to rsid, dbsnp reference is required.")
        else:
            # Build a "Chr_Pos" key and index by it so rows can be aligned
            # with the dbSNP matches below via .loc.
            gwas["id"] = gwas["Chr"].astype(str) + "_" + gwas["Pos"].astype(str)
            gwas = gwas.drop_duplicates(subset="id").reset_index(drop=True)
            gwas.index = gwas.id

            matching_id = variant_to_rsid(gwas, config)
            # Keep only positions with a dbSNP match (order follows matching_id).
            gwas = gwas.loc[matching_id.id]
            gwas["SNP"] = matching_id.dbsnp
            num_fail = old - len(gwas)
            logger.info(f"Removed {num_fail} SNPs that did not convert to rsid.")

    return gwas
380
+
381
+
382
def gwas_metadata(gwas, config):
    """Log summary features of the formatted GWAS data.

    Reports mean/max chi-square, lambda GC, and the count of genome-wide
    significant SNPs (chi^2 > 29), warning when mean chi^2 looks too small.
    """
    logger.info("\nSummary of GWAS data:")
    chisq = gwas.Z**2
    mean_chisq = chisq.mean()
    logger.info("Mean chi^2 = " + str(round(mean_chisq, 3)))
    if mean_chisq < 1.02:
        logger.warning("Mean chi^2 may be too small.")

    # 0.4549 is the median of the chi-square(1) distribution.
    logger.info("Lambda GC = " + str(round(chisq.median() / 0.4549, 3)))
    logger.info("Max chi^2 = " + str(round(chisq.max(), 3)))
    logger.info(
        f"{(chisq > 29).sum()} Genome-wide significant SNPs (some may have been removed by filtering)."
    )
398
+
399
+
400
def gwas_format(config: FormatSumstatsConfig):
    """Format GWAS summary statistics into gsMap or COJO layout.

    Reads ``config.sumstats`` (optionally gzip/bz2 compressed), standardises
    column names, derives Z where needed, resolves SNP rsids, applies QC
    filters, logs summary metadata, and writes the gzipped result to
    ``config.out`` with a format-specific suffix (".sumstats" or ".cojo").

    Raises:
        ValueError: if ``config.format`` is not "gsMap" or "COJO".
    """
    logger.info(f"------Formating gwas data for {config.sumstats}...")
    compression_type = get_compression(config.sumstats)
    gwas = pd.read_csv(
        config.sumstats,
        sep=r"\s+",
        header=0,
        compression=compression_type,
        na_values=[".", "NA"],
    )

    # A numeric config.n is a constant sample size: materialise it as a
    # column, then point config.n at that column name for gwas_checkname.
    if isinstance(config.n, int | float):
        logger.info(f"Set the sample size of gwas data as {config.n}.")
        gwas["N"] = config.n
        config.n = "N"

    logger.info(f"Read {len(gwas)} SNPs from {config.sumstats}.")

    # Check name and format
    gwas = gwas_checkname(gwas, config)
    gwas = gwas_checkformat(gwas, config)
    # Clean the snp id
    gwas = clean_SNP_id(gwas, config)
    # QC
    gwas = gwas_qc(gwas, config)
    # Meta
    gwas_metadata(gwas, config)

    # Saving the data
    if config.format == "COJO":
        keep = ["SNP", "A1", "A2", "FRQ", "BETA", "SE", "P", "N"]
        appendix = ".cojo"
    elif config.format == "gsMap":
        keep = ["SNP", "A1", "A2", "Z", "N"]
        appendix = ".sumstats"
    else:
        # Fix: an unrecognised format previously fell through to a NameError
        # on `keep`; fail with an explicit message instead.
        raise ValueError(f"Unsupported sumstats format: {config.format}. Expected 'gsMap' or 'COJO'.")

    if "Chr" in gwas.columns and "Pos" in gwas.columns and config.keep_chr_pos is True:
        keep = keep + ["Chr", "Pos"]

    gwas = gwas[keep]
    out_name = config.out + appendix + ".gz"

    logger.info(f"\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.")
    gwas.to_csv(out_name, sep="\t", index=False, float_format="%.3f", compression="gzip")