gsMap 1.71.1__py3-none-any.whl → 1.72.3__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
gsMap/config.py CHANGED
@@ -1,805 +1,1312 @@
1
- import sys
2
- import argparse
3
- import logging
4
- from collections import OrderedDict, namedtuple
5
- from dataclasses import dataclass
6
- from pathlib import Path
7
- from pprint import pprint
8
- from typing import Callable
9
- from typing import Union, Literal, Tuple, Optional, List
10
- from functools import wraps
11
- import pyfiglet
12
-
13
- from gsMap.__init__ import __version__
14
-
15
- # Global registry to hold functions
16
- cli_function_registry = OrderedDict()
17
- subcommand = namedtuple('subcommand', ['name', 'func', 'add_args_function', 'description'])
18
-
19
-
20
- def get_gsMap_logger(logger_name):
21
- logger = logging.getLogger(logger_name)
22
- logger.setLevel(logging.DEBUG)
23
- handler = logging.StreamHandler()
24
- handler.setFormatter(logging.Formatter(
25
- '[{asctime}] {levelname:.5s} | {name} - {message}', style='{'))
26
- logger.addHandler(handler)
27
- return logger
28
-
29
- logger = get_gsMap_logger('gsMap')
30
-
31
- # Decorator to register functions for cli parsing
32
- def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
33
- def decorator(func: Callable) -> Callable:
34
- def wrapper(*args, **kwargs):
35
- name.replace('_', ' ')
36
- gsMap_main_logo = pyfiglet.figlet_format("gsMap", font='doom', width=80, justify='center', ).rstrip()
37
- print(gsMap_main_logo, flush=True)
38
- version_number = 'Version: ' + __version__
39
- print(version_number.center(80), flush=True)
40
- print('=' * 80, flush=True)
41
- logger.info(f"Running {name}...")
42
- func(*args, **kwargs)
43
- logger.info(f"Finished running {name}.")
44
-
45
- cli_function_registry[name] = subcommand(name=name, func=wrapper, add_args_function=add_args_function,
46
- description=description)
47
- return wrapper
48
-
49
- return decorator
50
-
51
- def add_shared_args(parser):
52
- parser.add_argument('--workdir', type=str, required=True, help='Path to the working directory.')
53
- parser.add_argument('--sample_name', type=str, required=True, help='Name of the sample.')
54
-
55
- def add_find_latent_representations_args(parser):
56
- add_shared_args(parser)
57
- parser.add_argument('--input_hdf5_path', required=True, type=str, help='Path to the input HDF5 file.')
58
- parser.add_argument('--annotation', required=True, type=str, help='Name of the annotation in adata.obs to use.')
59
- parser.add_argument('--data_layer', type=str, default='counts', required=True,
60
- help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
61
- parser.add_argument('--epochs', type=int, default=300, help='Number of training epochs.')
62
- parser.add_argument('--feat_hidden1', type=int, default=256, help='Neurons in the first hidden layer.')
63
- parser.add_argument('--feat_hidden2', type=int, default=128, help='Neurons in the second hidden layer.')
64
- parser.add_argument('--gat_hidden1', type=int, default=64, help='Units in the first GAT hidden layer.')
65
- parser.add_argument('--gat_hidden2', type=int, default=30, help='Units in the second GAT hidden layer.')
66
- parser.add_argument('--p_drop', type=float, default=0.1, help='Dropout rate.')
67
- parser.add_argument('--gat_lr', type=float, default=0.001, help='Learning rate for the GAT.')
68
- parser.add_argument('--n_neighbors', type=int, default=11, help='Number of neighbors for GAT.')
69
- parser.add_argument('--n_comps', type=int, default=300, help='Number of principal components for PCA.')
70
- parser.add_argument('--weighted_adj', action='store_true', help='Use weighted adjacency in GAT.')
71
- parser.add_argument('--convergence_threshold', type=float, default=1e-4, help='Threshold for convergence.')
72
- parser.add_argument('--hierarchically', action='store_true', help='Enable hierarchical latent representation finding.')
73
-
74
-
75
- def chrom_choice(value):
76
- if value.isdigit():
77
- ivalue = int(value)
78
- if 1 <= ivalue <= 22:
79
- return ivalue
80
- elif value.lower() == 'all':
81
- return value
82
- else:
83
- raise argparse.ArgumentTypeError(f"'{value}' is an invalid chromosome choice. Choose from 1-22 or 'all'.")
84
-
85
-
86
- def filter_args_for_dataclass(args_dict, data_class: dataclass):
87
- return {k: v for k, v in args_dict.items() if k in data_class.__dataclass_fields__}
88
-
89
-
90
- def get_dataclass_from_parser(args: argparse.Namespace, data_class: dataclass):
91
- remain_kwargs = filter_args_for_dataclass(vars(args), data_class)
92
- print(f'Using the following arguments for {data_class.__name__}:', flush=True)
93
- pprint(remain_kwargs, indent=4)
94
- sys.stdout.flush()
95
- return data_class(**remain_kwargs)
96
-
97
-
98
- def add_latent_to_gene_args(parser):
99
- add_shared_args(parser)
100
- parser.add_argument('--annotation', type=str, help='Name of the annotation in adata.obs to use. (optional).')
101
- parser.add_argument('--no_expression_fraction', action='store_true', help='Skip expression fraction filtering.')
102
- parser.add_argument('--latent_representation', type=str, choices=['latent_GVAE', 'latent_PCA'], default='latent_GVAE',
103
- help='Type of latent representation.')
104
- parser.add_argument('--num_neighbour', type=int, default=21, help='Number of neighbors.')
105
- parser.add_argument('--num_neighbour_spatial', type=int, default=101, help='Number of spatial neighbors.')
106
- # parser.add_argument('--species', type=str, help='Species name for homolog gene mapping (optional).')
107
- parser.add_argument('--homolog_file', type=str, help='Path to homologous gene conversion file (optional).')
108
-
109
-
110
- def add_generate_ldscore_args(parser):
111
- add_shared_args(parser)
112
- parser.add_argument('--chrom', type=str, required=True, help='Chromosome id (1-22) or "all".')
113
- parser.add_argument('--bfile_root', type=str, required=True, help='Root path for genotype plink bfiles (.bim, .bed, .fam).')
114
- parser.add_argument('--keep_snp_root', type=str, required=True, help='Root path for SNP files.')
115
- parser.add_argument('--gtf_annotation_file', type=str, required=True, help='Path to GTF annotation file.')
116
- parser.add_argument('--gene_window_size', type=int, default=50000, help='Gene window size in base pairs.')
117
- parser.add_argument('--enhancer_annotation_file', type=str, help='Path to enhancer annotation file (optional).')
118
- parser.add_argument('--snp_multiple_enhancer_strategy', type=str, choices=['max_mkscore', 'nearest_TSS'], default='max_mkscore',
119
- help='Strategy for handling multiple enhancers per SNP.')
120
- parser.add_argument('--gene_window_enhancer_priority', type=str, choices=['gene_window_first', 'enhancer_first', 'enhancer_only'],
121
- help='Priority between gene window and enhancer annotations.')
122
- parser.add_argument('--spots_per_chunk', type=int, default=1000, help='Number of spots per chunk.')
123
- parser.add_argument('--ld_wind', type=int, default=1, help='LD window size.')
124
- parser.add_argument('--ld_unit', type=str, choices=['SNP', 'KB', 'CM'], default='CM', help='Unit for LD window.')
125
- parser.add_argument('--additional_baseline_annotation', type=str, default=None, help='Path of additional baseline annotations')
126
-
127
-
128
- def add_latent_to_gene_args(parser):
129
- add_shared_args(parser)
130
- parser.add_argument('--annotation', type=str, required=True, help='Name of the annotation layer.')
131
- parser.add_argument('--no_expression_fraction', action='store_true', help='Skip expression fraction filtering.')
132
- parser.add_argument('--latent_representation', type=str, choices=['latent_GVAE', 'latent_PCA'], default='latent_GVAE',
133
- help='Type of latent representation.')
134
- parser.add_argument('--num_neighbour', type=int, default=21, help='Number of neighbors.')
135
- parser.add_argument('--num_neighbour_spatial', type=int, default=101, help='Number of spatial neighbors.')
136
- # parser.add_argument('--species', type=str, help='Species name for homolog gene mapping (optional).')
137
- parser.add_argument('--homolog_file', type=str, help='Path to homologous gene conversion file (optional).')
138
-
139
-
140
- def add_spatial_ldsc_args(parser):
141
- add_shared_args(parser)
142
- parser.add_argument('--sumstats_file', type=str, required=True, help='Path to GWAS summary statistics file.')
143
- parser.add_argument('--w_file', type=str, required=True, help='Path to regression weight file.')
144
- parser.add_argument('--trait_name', type=str, required=True, help='Name of the trait being analyzed.')
145
- parser.add_argument('--n_blocks', type=int, default=200, help='Number of blocks for jackknife resampling.')
146
- parser.add_argument('--chisq_max', type=int, help='Maximum chi-square value for filtering SNPs.')
147
- parser.add_argument('--num_processes', type=int, default=4, help='Number of processes for parallel computing.')
148
- parser.add_argument('--use_additional_baseline_annotation', type=bool, nargs='?', const=True, default=True, help='Use additional baseline annotations when provided')
149
-
150
-
151
- def add_Cauchy_combination_args(parser):
152
- add_shared_args(parser)
153
- parser.add_argument('--trait_name', type=str, required=True, help='Name of the trait being analyzed.')
154
- parser.add_argument('--annotation', type=str, required=True, help='Name of the annotation in adata.obs to use.')
155
- parser.add_argument('--meta', type=str, help='Optional meta information.')
156
- parser.add_argument('--slide', type=str, help='Optional slide information.')
157
-
158
-
159
- def add_report_args(parser):
160
- add_shared_args(parser)
161
- parser.add_argument('--trait_name', type=str, required=True, help='Name of the trait to generate the report for.')
162
- parser.add_argument('--annotation', type=str, required=True, help='Annotation layer name.')
163
- # parser.add_argument('--plot_type', type=str, choices=['manhattan', 'GSS', 'gsMap', 'all'], default='all',
164
- # help="Type of diagnostic plot to generate. Choose from 'manhattan', 'GSS', 'gsMap', or 'all'.")
165
- parser.add_argument('--top_corr_genes', type=int, default=50,
166
- help='Number of top correlated genes to display.')
167
- parser.add_argument('--selected_genes', type=str, nargs='*',
168
- help='List of specific genes to include in the report (optional).')
169
- parser.add_argument('--sumstats_file', type=str, required=True, help='Path to GWAS summary statistics file.')
170
-
171
- # Optional arguments for customization
172
- parser.add_argument('--fig_width', type=int, default=None, help='Width of the generated figures in pixels.')
173
- parser.add_argument('--fig_height', type=int, default=None, help='Height of the generated figures in pixels.')
174
- parser.add_argument('--point_size', type=int, default=None, help='Point size for the figures.')
175
- parser.add_argument('--fig_style', type=str, default='light', choices=['dark', 'light'],
176
- help='Style of the generated figures.')
177
-
178
- def add_format_sumstats_args(parser):
179
- # Required arguments
180
- parser.add_argument('--sumstats', required=True, type=str,
181
- help='Path to gwas summary data')
182
- parser.add_argument('--out', required=True, type=str,
183
- help='Path to save the formatted gwas data')
184
-
185
- # Arguments for specifying column names
186
- parser.add_argument('--snp', default=None, type=str,
187
- help="Name of snp column (if not a name that gsMap understands)")
188
- parser.add_argument('--a1', default=None, type=str,
189
- help="Name of effect allele column (if not a name that gsMap understands)")
190
- parser.add_argument('--a2', default=None, type=str,
191
- help="Name of none-effect allele column (if not a name that gsMap understands)")
192
- parser.add_argument('--info', default=None, type=str,
193
- help="Name of info column (if not a name that gsMap understands)")
194
- parser.add_argument('--beta', default=None, type=str,
195
- help="Name of gwas beta column (if not a name that gsMap understands).")
196
- parser.add_argument('--se', default=None, type=str,
197
- help="Name of gwas standar error of beta column (if not a name that gsMap understands)")
198
- parser.add_argument('--p', default=None, type=str,
199
- help="Name of p-value column (if not a name that gsMap understands)")
200
- parser.add_argument('--frq', default=None, type=str,
201
- help="Name of A1 ferquency column (if not a name that gsMap understands)")
202
- parser.add_argument('--n', default=None, type=str,
203
- help="Name of sample size column (if not a name that gsMap understands)")
204
- parser.add_argument('--z', default=None, type=str,
205
- help="Name of gwas Z-statistics column (if not a name that gsMap understands)")
206
- parser.add_argument('--OR', default=None, type=str,
207
- help="Name of gwas OR column (if not a name that gsMap understands)")
208
- parser.add_argument('--se_OR', default=None, type=str,
209
- help="Name of standar error of OR column (if not a name that gsMap understands)")
210
-
211
- # Arguments for converting SNP (chr, pos) to rsid
212
- parser.add_argument('--chr', default="Chr", type=str,
213
- help="Name of SNP chromosome column (if not a name that gsMap understands)")
214
- parser.add_argument('--pos', default="Pos", type=str,
215
- help="Name of SNP positions column (if not a name that gsMap understands)")
216
- parser.add_argument('--dbsnp', default=None, type=str,
217
- help='Path to reference dbsnp file')
218
- parser.add_argument('--chunksize', default=1e+6, type=int,
219
- help='Chunk size for loading dbsnp file')
220
-
221
- # Arguments for output format and quality
222
- parser.add_argument('--format', default='gsMap', type=str,
223
- help='Format of output data', choices=['gsMap', 'COJO'])
224
- parser.add_argument('--info_min', default=0.9, type=float,
225
- help='Minimum INFO score.')
226
- parser.add_argument('--maf_min', default=0.01, type=float,
227
- help='Minimum MAF.')
228
- parser.add_argument('--keep_chr_pos', action='store_true', default=False,
229
- help='Keep SNP chromosome and position columns in the output data')
230
-
231
- def add_run_all_mode_args(parser):
232
- add_shared_args(parser)
233
-
234
- # Required paths and configurations
235
- parser.add_argument('--gsMap_resource_dir', type=str, required=True,
236
- help='Directory containing gsMap resources (e.g., genome annotations, LD reference panel, etc.).')
237
- parser.add_argument('--hdf5_path', type=str, required=True,
238
- help='Path to the input spatial transcriptomics data (H5AD format).')
239
- parser.add_argument('--annotation', type=str, required=True,
240
- help='Name of the annotation in adata.obs to use.')
241
- parser.add_argument('--data_layer', type=str, default='counts', required=True,
242
- help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
243
-
244
- # GWAS Data Parameters
245
- parser.add_argument('--trait_name', type=str, help='Name of the trait for GWAS analysis (required if sumstats_file is provided).')
246
- parser.add_argument('--sumstats_file', type=str,
247
- help='Path to GWAS summary statistics file. Either sumstats_file or sumstats_config_file is required.')
248
- parser.add_argument('--sumstats_config_file', type=str,
249
- help='Path to GWAS summary statistics config file. Either sumstats_file or sumstats_config_file is required.')
250
-
251
- # Homolog Data Parameters
252
- parser.add_argument('--homolog_file', type=str,
253
- help='Path to homologous gene for converting gene names from different species to human (optional, used for cross-species analysis).')
254
-
255
- # Maximum number of processes
256
- parser.add_argument('--max_processes', type=int, default=10,
257
- help='Maximum number of processes for parallel execution.')
258
-
259
- # # Optional paths for customization
260
- # parser.add_argument('--bfile_root', type=str,
261
- # help='Root path to PLINK bfiles (LD reference panel). If not provided, it will use the default in gsMap_resource_dir.')
262
- # parser.add_argument('--keep_snp_root', type=str,
263
- # help='Root path for SNP filtering. If not provided, it will use the default in gsMap_resource_dir.')
264
- # parser.add_argument('--w_file', type=str,
265
- # help='Path to the regression weight file. If not provided, it will use the default in gsMap_resource_dir.')
266
- # parser.add_argument('--snp_gene_weight_adata_path', type=str,
267
- # help='Path to the SNP-gene weight matrix file. If not provided, it will use the default in gsMap_resource_dir.')
268
- # parser.add_argument('--baseline_annotation_dir', type=str,
269
- # help='Directory containing the baseline annotations for quick mode. If not provided, it will use the default in gsMap_resource_dir.')
270
- # parser.add_argument('--SNP_gene_pair_dir', type=str,
271
- # help='Directory for SNP-gene pair data. If not provided, it will use the default in gsMap_resource_dir.')
272
-
273
-
274
- def ensure_path_exists(func):
275
- @wraps(func)
276
- def wrapper(*args, **kwargs):
277
- result = func(*args, **kwargs)
278
- if isinstance(result, Path):
279
- if result.suffix:
280
- result.parent.mkdir(parents=True, exist_ok=True, mode=0o755)
281
- else: # It's a directory path
282
- result.mkdir(parents=True, exist_ok=True, mode=0o755)
283
- return result
284
-
285
- return wrapper
286
-
287
-
288
- @dataclass
289
- class ConfigWithAutoPaths:
290
- workdir: str
291
- sample_name: str
292
-
293
- def __post_init__(self):
294
- if self.workdir is None:
295
- raise ValueError('workdir must be provided.')
296
-
297
- @property
298
- @ensure_path_exists
299
- def hdf5_with_latent_path(self) -> Path:
300
- return Path(f'{self.workdir}/{self.sample_name}/find_latent_representations/{self.sample_name}_add_latent.h5ad')
301
-
302
- @property
303
- @ensure_path_exists
304
- def mkscore_feather_path(self) -> Path:
305
- return Path(f'{self.workdir}/{self.sample_name}/latent_to_gene/{self.sample_name}_gene_marker_score.feather')
306
-
307
- @property
308
- @ensure_path_exists
309
- def ldscore_save_dir(self) -> Path:
310
- return Path(f'{self.workdir}/{self.sample_name}/generate_ldscore')
311
-
312
- @property
313
- @ensure_path_exists
314
- def ldsc_save_dir(self) -> Path:
315
- return Path(f'{self.workdir}/{self.sample_name}/spatial_ldsc')
316
-
317
- @property
318
- @ensure_path_exists
319
- def cauchy_save_dir(self) -> Path:
320
- return Path(f'{self.workdir}/{self.sample_name}/cauchy_combination')
321
-
322
- @ensure_path_exists
323
- def get_report_dir(self, trait_name: str) -> Path:
324
- return Path(f'{self.workdir}/{self.sample_name}/report/{trait_name}')
325
-
326
- def get_gsMap_report_file(self, trait_name: str) -> Path:
327
- return self.get_report_dir(trait_name) / f'{self.sample_name}_{trait_name}_gsMap_Report.html'
328
-
329
- @ensure_path_exists
330
- def get_manhattan_html_plot_path(self, trait_name: str) -> Path:
331
- return Path(
332
- f'{self.workdir}/{self.sample_name}/report/{trait_name}/manhattan_plot/{self.sample_name}_{trait_name}_Diagnostic_Manhattan_Plot.html')
333
-
334
- @ensure_path_exists
335
- def get_GSS_plot_dir(self, trait_name: str) -> Path:
336
- return Path(f'{self.workdir}/{self.sample_name}/report/{trait_name}/GSS_plot')
337
-
338
- def get_GSS_plot_select_gene_file(self, trait_name: str) -> Path:
339
- return self.get_GSS_plot_dir(trait_name) / 'plot_genes.csv'
340
-
341
- @ensure_path_exists
342
- def get_ldsc_result_file(self, trait_name: str) -> Path:
343
- return Path(f'{self.ldsc_save_dir}/{self.sample_name}_{trait_name}.csv.gz')
344
-
345
- @ensure_path_exists
346
- def get_cauchy_result_file(self, trait_name: str) -> Path:
347
- return Path(f'{self.cauchy_save_dir}/{self.sample_name}_{trait_name}.Cauchy.csv.gz')
348
-
349
- @ensure_path_exists
350
- def get_gene_diagnostic_info_save_path(self, trait_name: str) -> Path:
351
- return Path(
352
- f'{self.workdir}/{self.sample_name}/report/{trait_name}/{self.sample_name}_{trait_name}_Gene_Diagnostic_Info.csv')
353
-
354
- @ensure_path_exists
355
- def get_gsMap_plot_save_dir(self, trait_name: str) -> Path:
356
- return Path(f'{self.workdir}/{self.sample_name}/report/{trait_name}/gsMap_plot')
357
-
358
- def get_gsMap_html_plot_save_path(self, trait_name: str) -> Path:
359
- return self.get_gsMap_plot_save_dir(trait_name) / f'{self.sample_name}_{trait_name}_gsMap_plot.html'
360
-
361
- @dataclass
362
- class FindLatentRepresentationsConfig(ConfigWithAutoPaths):
363
- input_hdf5_path: str
364
- # output_hdf5_path: str
365
- annotation: str = None
366
- data_layer: str = None
367
-
368
- epochs: int = 300
369
- feat_hidden1: int = 256
370
- feat_hidden2: int = 128
371
- feat_cell: int = 3000
372
- gat_hidden1: int = 64
373
- gat_hidden2: int = 30
374
- p_drop: float = 0.1
375
- gat_lr: float = 0.001
376
- gcn_decay: float = 0.01
377
- n_neighbors: int = 11
378
- label_w: float = 1
379
- rec_w: float = 1
380
- input_pca: bool = True
381
- n_comps: int = 300
382
- weighted_adj: bool = False
383
- nheads: int = 3
384
- var: bool = False
385
- convergence_threshold: float = 1e-4
386
- hierarchically: bool = False
387
-
388
- def __post_init__(self):
389
- # self.output_hdf5_path = self.hdf5_with_latent_path
390
- if self.hierarchically:
391
- if self.annotation is None:
392
- raise ValueError('annotation must be provided if hierarchically is True.')
393
- logger.info(
394
- f'------Hierarchical mode is enabled. This will find the latent representations within each annotation.')
395
-
396
- # remind for not providing annotation
397
- if self.annotation is None:
398
- logger.warning(
399
- 'annotation is not provided. This will find the latent representations for the whole dataset.')
400
- else:
401
- logger.info(f'------Find latent representations for {self.annotation}...')
402
-
403
-
404
- @dataclass
405
- class LatentToGeneConfig(ConfigWithAutoPaths):
406
- # input_hdf5_with_latent_path: str
407
- # output_feather_path: str
408
- no_expression_fraction: bool = False
409
- latent_representation: str = 'latent_GVAE'
410
- num_neighbour: int = 21
411
- num_neighbour_spatial: int = 101
412
- homolog_file: str = None
413
- gM_slices: str = None
414
- annotation: str = None
415
-
416
- def __post_init__(self):
417
- if self.homolog_file is not None:
418
- logger.info(f"User provided homolog file to map gene names to human: {self.homolog_file}")
419
- # check the format of the homolog file
420
- with open(self.homolog_file, 'r') as f:
421
- first_line = f.readline().strip()
422
- _n_col = len(first_line.split())
423
- if _n_col != 2:
424
- raise ValueError(
425
- f"Invalid homolog file format. Expected 2 columns, first column should be other species gene name, second column should be human gene name. "
426
- f"Got {_n_col} columns in the first line.")
427
- else:
428
- first_col_name, second_col_name = first_line.split()
429
- self.species = first_col_name
430
- logger.info(
431
- f"Homolog file provided and will map gene name from column1:{first_col_name} to column2:{second_col_name}")
432
- else:
433
- logger.info("No homolog file provided. Run in human mode.")
434
-
435
-
436
- @dataclass
437
- class GenerateLDScoreConfig(ConfigWithAutoPaths):
438
- chrom: Union[int, str]
439
-
440
- bfile_root: str
441
- keep_snp_root: Optional[str]
442
-
443
- # annotation by gene distance
444
- gtf_annotation_file: str
445
- gene_window_size: int = 50000
446
-
447
- # annotation by enhancer
448
- enhancer_annotation_file: str = None
449
- snp_multiple_enhancer_strategy: Literal['max_mkscore', 'nearest_TSS'] = 'max_mkscore'
450
- gene_window_enhancer_priority: Optional[Literal['gene_window_first', 'enhancer_first', 'enhancer_only',]] = None
451
-
452
- # for calculating ld score
453
- additional_baseline_annotation: str = None
454
- spots_per_chunk: int = 1_000
455
- ld_wind: int = 1
456
- ld_unit: str = 'CM'
457
-
458
- # zarr config
459
- ldscore_save_format: Literal['feather', 'zarr', 'quick_mode'] = 'feather'
460
-
461
- zarr_chunk_size: Tuple[int, int] = None
462
-
463
- # for pre calculating the SNP Gene ldscore Weight
464
- save_pre_calculate_snp_gene_weight_matrix: bool = False
465
-
466
- baseline_annotation_dir: Optional[str] = None
467
- SNP_gene_pair_dir: Optional[str] = None
468
- def __post_init__(self):
469
- # if self.mkscore_feather_file is None:
470
- # self.mkscore_feather_file = self._get_mkscore_feather_path()
471
-
472
- if self.enhancer_annotation_file is not None and self.gene_window_enhancer_priority is None:
473
- logger.warning("enhancer_annotation_file is provided but gene_window_enhancer_priority is not provided. "
474
- "by default, gene_window_enhancer_priority is set to 'enhancer_only', when enhancer_annotation_file is provided.")
475
- self.gene_window_enhancer_priority = 'enhancer_only'
476
- if self.enhancer_annotation_file is None and self.gene_window_enhancer_priority is not None:
477
- logger.warning("gene_window_enhancer_priority is provided but enhancer_annotation_file is not provided. "
478
- "by default, gene_window_enhancer_priority is set to None, when enhancer_annotation_file is not provided.")
479
- self.gene_window_enhancer_priority = None
480
- assert self.gene_window_enhancer_priority in [None, 'gene_window_first', 'enhancer_first', 'enhancer_only', ], \
481
- f"gene_window_enhancer_priority must be one of None, 'gene_window_first', 'enhancer_first', 'enhancer_only', but got {self.gene_window_enhancer_priority}."
482
- if self.gene_window_enhancer_priority in ['gene_window_first', 'enhancer_first']:
483
- logger.info(f'Both gene_window and enhancer annotation will be used to calculate LD score. ')
484
- logger.info(
485
- f'SNP within +-{self.gene_window_size} bp of gene body will be used and enhancer annotation will be used to calculate LD score. If a snp maps to multiple enhancers, the strategy to choose by your select strategy: {self.snp_multiple_enhancer_strategy}.')
486
- elif self.gene_window_enhancer_priority == 'enhancer_only':
487
- logger.info(f'Only enhancer annotation will be used to calculate LD score. ')
488
- else:
489
- logger.info(
490
- f'Only gene window annotation will be used to calculate LD score. SNP within +-{self.gene_window_size} bp of gene body will be used. ')
491
-
492
- # remind for baseline annotation
493
- if self.additional_baseline_annotation is None:
494
- logger.info(f'------Baseline annotation is not provided. Default baseline annotation will be used.')
495
- else:
496
- logger.info(
497
- f'------Baseline annotation is provided. Additional baseline annotation will be used with the default baseline annotation.')
498
- logger.info(f'------Baseline annotation directory: {self.additional_baseline_annotation}')
499
- # check the existence of baseline annotation
500
- if self.chrom == 'all':
501
- for chrom in range(1, 23):
502
- chrom = str(chrom)
503
- baseline_annotation_path = Path(
504
- self.additional_baseline_annotation) / f'baseline.{chrom}.annot.gz'
505
- if not baseline_annotation_path.exists():
506
- raise FileNotFoundError(
507
- f'baseline.{chrom}.annot.gz is not found in {self.additional_baseline_annotation}.')
508
- else:
509
- baseline_annotation_path = Path(
510
- self.additional_baseline_annotation) / f'baseline.{self.chrom}.annot.gz'
511
- if not baseline_annotation_path.exists():
512
- raise FileNotFoundError(
513
- f'baseline.{self.chrom}.annot.gz is not found in {self.additional_baseline_annotation}.')
514
-
515
- # set the default zarr chunk size
516
- if self.ldscore_save_format == 'zarr' and self.zarr_chunk_size is None:
517
- self.zarr_chunk_size = (10_000, self.spots_per_chunk)
518
-
519
-
520
- @dataclass
521
- class SpatialLDSCConfig(ConfigWithAutoPaths):
522
- w_file: str
523
- # ldscore_save_dir: str
524
- use_additional_baseline_annotation: bool = True
525
- trait_name: Optional[str] = None
526
- sumstats_file: Optional[str] = None
527
- sumstats_config_file: Optional[str] = None
528
- num_processes: int = 4
529
- not_M_5_50: bool = False
530
- n_blocks: int = 200
531
- chisq_max: Optional[int] = None
532
- all_chunk: Optional[int] = None
533
- chunk_range: Optional[Tuple[int, int]] = None
534
-
535
- ldscore_save_format: Literal['feather', 'zarr', 'quick_mode'] = 'feather'
536
-
537
- spots_per_chunk_quick_mode: int = 1_000
538
- snp_gene_weight_adata_path: Optional[str] = None
539
-
540
- def __post_init__(self):
541
- super().__post_init__()
542
- if self.sumstats_file is None and self.sumstats_config_file is None:
543
- raise ValueError('One of sumstats_file and sumstats_config_file must be provided.')
544
- if self.sumstats_file is not None and self.sumstats_config_file is not None:
545
- raise ValueError('Only one of sumstats_file and sumstats_config_file must be provided.')
546
- if self.sumstats_file is not None and self.trait_name is None:
547
- raise ValueError('trait_name must be provided if sumstats_file is provided.')
548
- if self.sumstats_config_file is not None and self.trait_name is not None:
549
- raise ValueError('trait_name must not be provided if sumstats_config_file is provided.')
550
- self.sumstats_config_dict = {}
551
- # load the sumstats config file
552
- if self.sumstats_config_file is not None:
553
- import yaml
554
- with open(self.sumstats_config_file) as f:
555
- config = yaml.load(f, Loader=yaml.FullLoader)
556
- for trait_name, sumstats_file in config.items():
557
- assert Path(sumstats_file).exists(), f'{sumstats_file} does not exist.'
558
- # load the sumstats file
559
- elif self.sumstats_file is not None:
560
- self.sumstats_config_dict[self.trait_name] = self.sumstats_file
561
- else:
562
- raise ValueError('One of sumstats_file and sumstats_config_file must be provided.')
563
-
564
- for sumstats_file in self.sumstats_config_dict.values():
565
- assert Path(sumstats_file).exists(), f'{sumstats_file} does not exist.'
566
-
567
- # check if additional baseline annotation is exist
568
- # self.use_additional_baseline_annotation = False
569
-
570
- if self.use_additional_baseline_annotation:
571
- self.process_additional_baseline_annotation()
572
-
573
- def process_additional_baseline_annotation(self):
574
- additional_baseline_annotation = Path(self.ldscore_save_dir) / 'additional_baseline'
575
- dir_exists = additional_baseline_annotation.exists()
576
-
577
- if not dir_exists:
578
- self.use_additional_baseline_annotation = False
579
- # if self.use_additional_baseline_annotation:
580
- # logger.warning(f"additional_baseline directory is not found in {self.ldscore_save_dir}.")
581
- # print('''\
582
- # if you want to use additional baseline annotation,
583
- # please provide additional baseline annotation when calculating ld score.
584
- # ''')
585
- # raise FileNotFoundError(
586
- # f'additional_baseline directory is not found.')
587
- # return
588
- # self.use_additional_baseline_annotation = self.use_additional_baseline_annotation or True
589
- else:
590
- logger.info(
591
- f'------Additional baseline annotation is provided. It will be used with the default baseline annotation.')
592
- logger.info(f'------Additional baseline annotation directory: {additional_baseline_annotation}')
593
-
594
- chrom_list = range(1, 23)
595
- for chrom in chrom_list:
596
- baseline_annotation_path = additional_baseline_annotation / f'baseline.{chrom}.l2.ldscore.feather'
597
- if not baseline_annotation_path.exists():
598
- raise FileNotFoundError(
599
- f'baseline.{chrom}.annot.gz is not found in {additional_baseline_annotation}.')
600
- return None
601
-
602
-
603
- @dataclass
604
- class CauchyCombinationConfig(ConfigWithAutoPaths):
605
- trait_name: str
606
- annotation: str
607
- meta: str = None
608
- slide: str = None
609
-
610
-
611
- @dataclass
612
- class VisualizeConfig(ConfigWithAutoPaths):
613
- trait_name: str
614
-
615
- annotation: str = None
616
- fig_title: str = None
617
- fig_height: int = 600
618
- fig_width: int = 800
619
- point_size: int = None
620
- fig_style: Literal['dark', 'light'] = 'light'
621
-
622
-
623
- @dataclass
624
- class DiagnosisConfig(ConfigWithAutoPaths):
625
- annotation: str
626
- # mkscore_feather_file: str
627
-
628
- trait_name: str
629
- sumstats_file: str
630
- plot_type: Literal['manhattan', 'GSS', 'gsMap', 'all'] = 'all'
631
- top_corr_genes: int = 50
632
- selected_genes: Optional[List[str]] = None
633
-
634
- fig_width: Optional[int] = None
635
- fig_height: Optional[int] = None
636
- point_size: Optional[int] = None
637
- fig_style: Literal['dark', 'light'] = 'light'
638
-
639
- def __post_init__(self):
640
- if any([self.fig_width, self.fig_height, self.point_size]):
641
- logger.info('Customizing the figure size and point size.')
642
- assert all([self.fig_width, self.fig_height, self.point_size]), 'All of fig_width, fig_height, and point_size must be provided.'
643
- self.customize_fig = True
644
- else:
645
- self.customize_fig = False
646
- @dataclass
647
- class ReportConfig(DiagnosisConfig):
648
- pass
649
-
650
-
651
- @dataclass
652
- class RunAllModeConfig(ConfigWithAutoPaths):
653
- gsMap_resource_dir: str
654
-
655
- # == ST DATA PARAMETERS ==
656
- hdf5_path: str
657
- annotation: str
658
- data_layer: str = 'X'
659
-
660
- # ==GWAS DATA PARAMETERS==
661
- trait_name: Optional[str] = None
662
- sumstats_file: Optional[str] = None
663
- sumstats_config_file: Optional[str] = None
664
-
665
- # === homolog PARAMETERS ===
666
- homolog_file: Optional[str] = None
667
-
668
- max_processes: int = 10
669
-
670
- def __post_init__(self):
671
- super().__post_init__()
672
- self.gtffile = f"{self.gsMap_resource_dir}/genome_annotation/gtf/gencode.v39lift37.annotation.gtf"
673
- self.bfile_root = f"{self.gsMap_resource_dir}/LD_Reference_Panel/1000G_EUR_Phase3_plink/1000G.EUR.QC"
674
- self.keep_snp_root = f"{self.gsMap_resource_dir}/LDSC_resource/hapmap3_snps/hm"
675
- self.w_file = f"{self.gsMap_resource_dir}/LDSC_resource/weights_hm3_no_hla/weights."
676
- self.snp_gene_weight_adata_path = f"{self.gsMap_resource_dir}/quick_mode/snp_gene_weight_matrix.h5ad"
677
- self.baseline_annotation_dir = Path(f"{self.gsMap_resource_dir}/quick_mode/baseline").resolve()
678
- self.SNP_gene_pair_dir = Path(f"{self.gsMap_resource_dir}/quick_mode/SNP_gene_pair").resolve()
679
- # check the existence of the input files and resources files
680
- for file in [self.hdf5_path, self.gtffile]:
681
- if not Path(file).exists():
682
- raise FileNotFoundError(f"File {file} does not exist.")
683
-
684
- if self.sumstats_file is None and self.sumstats_config_file is None:
685
- raise ValueError('One of sumstats_file and sumstats_config_file must be provided.')
686
- if self.sumstats_file is not None and self.sumstats_config_file is not None:
687
- raise ValueError('Only one of sumstats_file and sumstats_config_file must be provided.')
688
- if self.sumstats_file is not None and self.trait_name is None:
689
- raise ValueError('trait_name must be provided if sumstats_file is provided.')
690
- if self.sumstats_config_file is not None and self.trait_name is not None:
691
- raise ValueError('trait_name must not be provided if sumstats_config_file is provided.')
692
- self.sumstats_config_dict = {}
693
- # load the sumstats config file
694
- if self.sumstats_config_file is not None:
695
- import yaml
696
- with open(self.sumstats_config_file) as f:
697
- config = yaml.load(f, Loader=yaml.FullLoader)
698
- for trait_name, sumstats_file in config.items():
699
- assert Path(sumstats_file).exists(), f'{sumstats_file} does not exist.'
700
- self.sumstats_config_dict[trait_name] = sumstats_file
701
- # load the sumstats file
702
- elif self.sumstats_file is not None and self.trait_name is not None:
703
- self.sumstats_config_dict[self.trait_name] = self.sumstats_file
704
- else:
705
- raise ValueError('One of sumstats_file and sumstats_config_file must be provided.')
706
-
707
- for sumstats_file in self.sumstats_config_dict.values():
708
- assert Path(sumstats_file).exists(), f'{sumstats_file} does not exist.'
709
-
710
-
711
- @dataclass
712
- class FormatSumstatsConfig:
713
- sumstats: str
714
- out: str
715
- dbsnp: str
716
- snp: str = None
717
- a1: str = None
718
- a2: str = None
719
- info: str = None
720
- beta: str = None
721
- se: str = None
722
- p: str = None
723
- frq: str = None
724
- n: str = None
725
- z: str = None
726
- OR: str = None
727
- se_OR: str = None
728
- format: str = None
729
- chr: str = None
730
- pos: str = None
731
- chunksize: int = 1e+7
732
- info_min: float = 0.9
733
- maf_min: float = 0.01
734
- keep_chr_pos: bool = False
735
-
736
-
737
- @register_cli(name='run_find_latent_representations',
738
- description='Run Find_latent_representations \nFind the latent representations of each spot by running GNN-VAE',
739
- add_args_function=add_find_latent_representations_args)
740
- def run_find_latent_representation_from_cli(args: argparse.Namespace):
741
- from gsMap.find_latent_representation import run_find_latent_representation
742
- config = get_dataclass_from_parser(args, FindLatentRepresentationsConfig)
743
- run_find_latent_representation(config)
744
-
745
-
746
- @register_cli(name='run_latent_to_gene',
747
- description='Run Latent_to_gene \nEstimate gene marker gene scores for each spot by using latent representations from nearby spots',
748
- add_args_function=add_latent_to_gene_args)
749
- def run_latent_to_gene_from_cli(args: argparse.Namespace):
750
- from gsMap.latent_to_gene import run_latent_to_gene
751
- config = get_dataclass_from_parser(args, LatentToGeneConfig)
752
- run_latent_to_gene(config)
753
-
754
-
755
- @register_cli(name='run_generate_ldscore',
756
- description='Run Generate_ldscore \nGenerate LD scores for each spot',
757
- add_args_function=add_generate_ldscore_args)
758
- def run_generate_ldscore_from_cli(args: argparse.Namespace):
759
- from gsMap.generate_ldscore import run_generate_ldscore
760
- config = get_dataclass_from_parser(args, GenerateLDScoreConfig)
761
- run_generate_ldscore(config)
762
-
763
-
764
- @register_cli(name='run_spatial_ldsc',
765
- description='Run Spatial_ldsc \nRun spatial LDSC for each spot',
766
- add_args_function=add_spatial_ldsc_args)
767
- def run_spatial_ldsc_from_cli(args: argparse.Namespace):
768
- from gsMap.spatial_ldsc_multiple_sumstats import run_spatial_ldsc
769
- config = get_dataclass_from_parser(args, SpatialLDSCConfig)
770
- run_spatial_ldsc(config)
771
-
772
-
773
- @register_cli(name='run_cauchy_combination',
774
- description='Run Cauchy_combination for each annotation',
775
- add_args_function=add_Cauchy_combination_args)
776
- def run_Cauchy_combination_from_cli(args: argparse.Namespace):
777
- from gsMap.cauchy_combination_test import run_Cauchy_combination
778
- config = get_dataclass_from_parser(args, CauchyCombinationConfig)
779
- run_Cauchy_combination(config)
780
-
781
-
782
- @register_cli(name='run_report',
783
- description='Run Report to generate diagnostic plots and tables',
784
- add_args_function=add_report_args)
785
- def run_Report_from_cli(args: argparse.Namespace):
786
- from gsMap.report import run_report
787
- config = get_dataclass_from_parser(args, ReportConfig)
788
- run_report(config)
789
-
790
-
791
- @register_cli(name='format_sumstats',
792
- description='Format gwas summary statistics',
793
- add_args_function=add_format_sumstats_args)
794
- def gwas_format_from_cli(args: argparse.Namespace):
795
- from gsMap.format_sumstats import gwas_format
796
- config = get_dataclass_from_parser(args, FormatSumstatsConfig)
797
- gwas_format(config)
798
-
799
- @register_cli(name='quick_mode',
800
- description='Run all the gsMap pipeline in quick mode',
801
- add_args_function=add_run_all_mode_args)
802
- def run_all_mode_from_cli(args: argparse.Namespace):
803
- from gsMap.run_all_mode import run_pipeline
804
- config = get_dataclass_from_parser(args, RunAllModeConfig)
805
- run_pipeline(config)
1
+ import argparse
2
+ import dataclasses
3
+ import logging
4
+ import sys
5
+ from collections import OrderedDict, namedtuple
6
+ from collections.abc import Callable
7
+ from dataclasses import dataclass
8
+ from functools import wraps
9
+ from pathlib import Path
10
+ from pprint import pprint
11
+ from typing import Literal
12
+
13
+ import pyfiglet
14
+ import yaml
15
+
16
+ from gsMap.__init__ import __version__
17
+
18
+ # Global registry to hold functions
19
+ cli_function_registry = OrderedDict()
20
+ subcommand = namedtuple("subcommand", ["name", "func", "add_args_function", "description"])
21
+
22
+
23
+ def get_gsMap_logger(logger_name):
24
+ logger = logging.getLogger(logger_name)
25
+ logger.setLevel(logging.DEBUG)
26
+ handler = logging.StreamHandler()
27
+ handler.setFormatter(
28
+ logging.Formatter("[{asctime}] {levelname:.5s} | {name} - {message}", style="{")
29
+ )
30
+ logger.addHandler(handler)
31
+ return logger
32
+
33
+
34
+ logger = get_gsMap_logger("gsMap")
35
+
36
+
37
+ # Decorator to register functions for cli parsing
38
+ def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
39
+ def decorator(func: Callable) -> Callable:
40
+ def wrapper(*args, **kwargs):
41
+ name.replace("_", " ")
42
+ gsMap_main_logo = pyfiglet.figlet_format(
43
+ "gsMap",
44
+ font="doom",
45
+ width=80,
46
+ justify="center",
47
+ ).rstrip()
48
+ print(gsMap_main_logo, flush=True)
49
+ version_number = "Version: " + __version__
50
+ print(version_number.center(80), flush=True)
51
+ print("=" * 80, flush=True)
52
+ logger.info(f"Running {name}...")
53
+ func(*args, **kwargs)
54
+ logger.info(f"Finished running {name}.")
55
+
56
+ cli_function_registry[name] = subcommand(
57
+ name=name, func=wrapper, add_args_function=add_args_function, description=description
58
+ )
59
+ return wrapper
60
+
61
+ return decorator
62
+
63
+
64
+ def add_shared_args(parser):
65
+ parser.add_argument(
66
+ "--workdir", type=str, required=True, help="Path to the working directory."
67
+ )
68
+ parser.add_argument("--sample_name", type=str, required=True, help="Name of the sample.")
69
+
70
+
71
+ def add_find_latent_representations_args(parser):
72
+ add_shared_args(parser)
73
+ parser.add_argument(
74
+ "--input_hdf5_path", required=True, type=str, help="Path to the input HDF5 file."
75
+ )
76
+ parser.add_argument(
77
+ "--annotation", required=True, type=str, help="Name of the annotation in adata.obs to use."
78
+ )
79
+ parser.add_argument(
80
+ "--data_layer",
81
+ type=str,
82
+ default="counts",
83
+ required=True,
84
+ help='Data layer for gene expression (e.g., "count", "counts", "log1p").',
85
+ )
86
+ parser.add_argument("--epochs", type=int, default=300, help="Number of training epochs.")
87
+ parser.add_argument(
88
+ "--feat_hidden1", type=int, default=256, help="Neurons in the first hidden layer."
89
+ )
90
+ parser.add_argument(
91
+ "--feat_hidden2", type=int, default=128, help="Neurons in the second hidden layer."
92
+ )
93
+ parser.add_argument(
94
+ "--gat_hidden1", type=int, default=64, help="Units in the first GAT hidden layer."
95
+ )
96
+ parser.add_argument(
97
+ "--gat_hidden2", type=int, default=30, help="Units in the second GAT hidden layer."
98
+ )
99
+ parser.add_argument("--p_drop", type=float, default=0.1, help="Dropout rate.")
100
+ parser.add_argument("--gat_lr", type=float, default=0.001, help="Learning rate for the GAT.")
101
+ parser.add_argument("--n_neighbors", type=int, default=11, help="Number of neighbors for GAT.")
102
+ parser.add_argument(
103
+ "--n_comps", type=int, default=300, help="Number of principal components for PCA."
104
+ )
105
+ parser.add_argument(
106
+ "--weighted_adj", action="store_true", help="Use weighted adjacency in GAT."
107
+ )
108
+ parser.add_argument(
109
+ "--convergence_threshold", type=float, default=1e-4, help="Threshold for convergence."
110
+ )
111
+ parser.add_argument(
112
+ "--hierarchically",
113
+ action="store_true",
114
+ help="Enable hierarchical latent representation finding.",
115
+ )
116
+
117
+
118
+ def chrom_choice(value):
119
+ if value.isdigit():
120
+ ivalue = int(value)
121
+ if 1 <= ivalue <= 22:
122
+ return ivalue
123
+ elif value.lower() == "all":
124
+ return value
125
+ else:
126
+ raise argparse.ArgumentTypeError(
127
+ f"'{value}' is an invalid chromosome choice. Choose from 1-22 or 'all'."
128
+ )
129
+
130
+
131
+ def filter_args_for_dataclass(args_dict, data_class: dataclass):
132
+ return {k: v for k, v in args_dict.items() if k in data_class.__dataclass_fields__}
133
+
134
+
135
+ def get_dataclass_from_parser(args: argparse.Namespace, data_class: dataclass):
136
+ remain_kwargs = filter_args_for_dataclass(vars(args), data_class)
137
+ print(f"Using the following arguments for {data_class.__name__}:", flush=True)
138
+ pprint(remain_kwargs, indent=4)
139
+ sys.stdout.flush()
140
+ return data_class(**remain_kwargs)
141
+
142
+
143
+ def add_latent_to_gene_args(parser):
144
+ add_shared_args(parser)
145
+
146
+ parser.add_argument(
147
+ "--input_hdf5_path",
148
+ type=str,
149
+ default=None,
150
+ help="Path to the input HDF5 file with latent representations, if --latent_representation is specified.",
151
+ )
152
+ parser.add_argument(
153
+ "--no_expression_fraction", action="store_true", help="Skip expression fraction filtering."
154
+ )
155
+ parser.add_argument(
156
+ "--latent_representation",
157
+ type=str,
158
+ default=None,
159
+ help="Type of latent representation. This should exist in the h5ad obsm.",
160
+ )
161
+ parser.add_argument("--num_neighbour", type=int, default=21, help="Number of neighbors.")
162
+ parser.add_argument(
163
+ "--num_neighbour_spatial", type=int, default=101, help="Number of spatial neighbors."
164
+ )
165
+ parser.add_argument(
166
+ "--homolog_file",
167
+ type=str,
168
+ default=None,
169
+ help="Path to homologous gene conversion file (optional).",
170
+ )
171
+ parser.add_argument(
172
+ "--gM_slices", type=str, default=None, help="Path to the slice mean file (optional)."
173
+ )
174
+ parser.add_argument(
175
+ "--annotation",
176
+ type=str,
177
+ default=None,
178
+ help="Name of the annotation in adata.obs to use (optional).",
179
+ )
180
+
181
+
182
+ def add_generate_ldscore_args(parser):
183
+ add_shared_args(parser)
184
+ parser.add_argument("--chrom", type=str, required=True, help='Chromosome id (1-22) or "all".')
185
+ parser.add_argument(
186
+ "--bfile_root",
187
+ type=str,
188
+ required=True,
189
+ help="Root path for genotype plink bfiles (.bim, .bed, .fam).",
190
+ )
191
+ parser.add_argument(
192
+ "--keep_snp_root", type=str, required=True, help="Root path for SNP files."
193
+ )
194
+ parser.add_argument(
195
+ "--gtf_annotation_file", type=str, required=True, help="Path to GTF annotation file."
196
+ )
197
+ parser.add_argument(
198
+ "--gene_window_size", type=int, default=50000, help="Gene window size in base pairs."
199
+ )
200
+ parser.add_argument(
201
+ "--enhancer_annotation_file", type=str, help="Path to enhancer annotation file (optional)."
202
+ )
203
+ parser.add_argument(
204
+ "--snp_multiple_enhancer_strategy",
205
+ type=str,
206
+ choices=["max_mkscore", "nearest_TSS"],
207
+ default="max_mkscore",
208
+ help="Strategy for handling multiple enhancers per SNP.",
209
+ )
210
+ parser.add_argument(
211
+ "--gene_window_enhancer_priority",
212
+ type=str,
213
+ choices=["gene_window_first", "enhancer_first", "enhancer_only"],
214
+ help="Priority between gene window and enhancer annotations.",
215
+ )
216
+ parser.add_argument(
217
+ "--spots_per_chunk", type=int, default=1000, help="Number of spots per chunk."
218
+ )
219
+ parser.add_argument("--ld_wind", type=int, default=1, help="LD window size.")
220
+ parser.add_argument(
221
+ "--ld_unit",
222
+ type=str,
223
+ choices=["SNP", "KB", "CM"],
224
+ default="CM",
225
+ help="Unit for LD window.",
226
+ )
227
+ parser.add_argument(
228
+ "--additional_baseline_annotation",
229
+ type=str,
230
+ default=None,
231
+ help="Path of additional baseline annotations",
232
+ )
233
+
234
+
235
+ def add_spatial_ldsc_args(parser):
236
+ add_shared_args(parser)
237
+ parser.add_argument(
238
+ "--sumstats_file", type=str, required=True, help="Path to GWAS summary statistics file."
239
+ )
240
+ parser.add_argument(
241
+ "--w_file", type=str, required=True, help="Path to regression weight file."
242
+ )
243
+ parser.add_argument(
244
+ "--trait_name", type=str, required=True, help="Name of the trait being analyzed."
245
+ )
246
+ parser.add_argument(
247
+ "--n_blocks", type=int, default=200, help="Number of blocks for jackknife resampling."
248
+ )
249
+ parser.add_argument(
250
+ "--chisq_max", type=int, help="Maximum chi-square value for filtering SNPs."
251
+ )
252
+ parser.add_argument(
253
+ "--num_processes", type=int, default=4, help="Number of processes for parallel computing."
254
+ )
255
+ parser.add_argument(
256
+ "--use_additional_baseline_annotation",
257
+ type=bool,
258
+ nargs="?",
259
+ const=True,
260
+ default=True,
261
+ help="Use additional baseline annotations when provided",
262
+ )
263
+
264
+
265
+ def add_Cauchy_combination_args(parser):
266
+ parser.add_argument(
267
+ "--workdir", type=str, required=True, help="Path to the working directory."
268
+ )
269
+ parser.add_argument("--sample_name", type=str, required=False, help="Name of the sample.")
270
+
271
+ parser.add_argument(
272
+ "--trait_name", type=str, required=True, help="Name of the trait being analyzed."
273
+ )
274
+ parser.add_argument(
275
+ "--annotation", type=str, required=True, help="Name of the annotation in adata.obs to use."
276
+ )
277
+
278
+ parser.add_argument(
279
+ "--sample_name_list",
280
+ type=str,
281
+ nargs="+",
282
+ required=False,
283
+ help="List of sample names to process. Provide as a space-separated list.",
284
+ )
285
+ parser.add_argument(
286
+ "--output_file",
287
+ type=str,
288
+ required=False,
289
+ help="Path to save the combined Cauchy results. Required when using multiple samples.",
290
+ )
291
+
292
+
293
+ def add_report_args(parser):
294
+ add_shared_args(parser)
295
+ parser.add_argument(
296
+ "--trait_name",
297
+ type=str,
298
+ required=True,
299
+ help="Name of the trait to generate the report for.",
300
+ )
301
+ parser.add_argument("--annotation", type=str, required=True, help="Annotation layer name.")
302
+ # parser.add_argument('--plot_type', type=str, choices=['manhattan', 'GSS', 'gsMap', 'all'], default='all',
303
+ # help="Type of diagnostic plot to generate. Choose from 'manhattan', 'GSS', 'gsMap', or 'all'.")
304
+ parser.add_argument(
305
+ "--top_corr_genes", type=int, default=50, help="Number of top correlated genes to display."
306
+ )
307
+ parser.add_argument(
308
+ "--selected_genes",
309
+ type=str,
310
+ nargs="*",
311
+ help="List of specific genes to include in the report (optional).",
312
+ )
313
+ parser.add_argument(
314
+ "--sumstats_file", type=str, required=True, help="Path to GWAS summary statistics file."
315
+ )
316
+
317
+ # Optional arguments for customization
318
+ parser.add_argument(
319
+ "--fig_width", type=int, default=None, help="Width of the generated figures in pixels."
320
+ )
321
+ parser.add_argument(
322
+ "--fig_height", type=int, default=None, help="Height of the generated figures in pixels."
323
+ )
324
+ parser.add_argument("--point_size", type=int, default=None, help="Point size for the figures.")
325
+ parser.add_argument(
326
+ "--fig_style",
327
+ type=str,
328
+ default="light",
329
+ choices=["dark", "light"],
330
+ help="Style of the generated figures.",
331
+ )
332
+
333
+
334
+ def add_create_slice_mean_args(parser):
335
+ parser.add_argument(
336
+ "--sample_name_list",
337
+ type=str,
338
+ nargs="+",
339
+ required=True,
340
+ help="List of sample names to process. Provide as a space-separated list.",
341
+ )
342
+
343
+ parser.add_argument(
344
+ "--h5ad_list",
345
+ type=str,
346
+ nargs="+",
347
+ help="List of h5ad file paths corresponding to the sample names. Provide as a space-separated list.",
348
+ )
349
+ parser.add_argument(
350
+ "--h5ad_yaml",
351
+ type=str,
352
+ default=None,
353
+ help="Path to the YAML file containing sample names and associated h5ad file paths",
354
+ )
355
+ parser.add_argument(
356
+ "--slice_mean_output_file",
357
+ type=str,
358
+ required=True,
359
+ help="Path to the output file for the slice mean",
360
+ )
361
+ parser.add_argument(
362
+ "--homolog_file", type=str, help="Path to homologous gene conversion file (optional)."
363
+ )
364
+ parser.add_argument(
365
+ "--data_layer",
366
+ type=str,
367
+ default="counts",
368
+ required=True,
369
+ help='Data layer for gene expression (e.g., "count", "counts", "log1p").',
370
+ )
371
+
372
+
373
+ def add_format_sumstats_args(parser):
374
+ # Required arguments
375
+ parser.add_argument("--sumstats", required=True, type=str, help="Path to gwas summary data")
376
+ parser.add_argument(
377
+ "--out", required=True, type=str, help="Path to save the formatted gwas data"
378
+ )
379
+
380
+ # Arguments for specifying column names
381
+ parser.add_argument(
382
+ "--snp",
383
+ default=None,
384
+ type=str,
385
+ help="Name of snp column (if not a name that gsMap understands)",
386
+ )
387
+ parser.add_argument(
388
+ "--a1",
389
+ default=None,
390
+ type=str,
391
+ help="Name of effect allele column (if not a name that gsMap understands)",
392
+ )
393
+ parser.add_argument(
394
+ "--a2",
395
+ default=None,
396
+ type=str,
397
+ help="Name of none-effect allele column (if not a name that gsMap understands)",
398
+ )
399
+ parser.add_argument(
400
+ "--info",
401
+ default=None,
402
+ type=str,
403
+ help="Name of info column (if not a name that gsMap understands)",
404
+ )
405
+ parser.add_argument(
406
+ "--beta",
407
+ default=None,
408
+ type=str,
409
+ help="Name of gwas beta column (if not a name that gsMap understands).",
410
+ )
411
+ parser.add_argument(
412
+ "--se",
413
+ default=None,
414
+ type=str,
415
+ help="Name of gwas standar error of beta column (if not a name that gsMap understands)",
416
+ )
417
+ parser.add_argument(
418
+ "--p",
419
+ default=None,
420
+ type=str,
421
+ help="Name of p-value column (if not a name that gsMap understands)",
422
+ )
423
+ parser.add_argument(
424
+ "--frq",
425
+ default=None,
426
+ type=str,
427
+ help="Name of A1 ferquency column (if not a name that gsMap understands)",
428
+ )
429
+ parser.add_argument(
430
+ "--n",
431
+ default=None,
432
+ type=str,
433
+ help="Name of sample size column (if not a name that gsMap understands)",
434
+ )
435
+ parser.add_argument(
436
+ "--z",
437
+ default=None,
438
+ type=str,
439
+ help="Name of gwas Z-statistics column (if not a name that gsMap understands)",
440
+ )
441
+ parser.add_argument(
442
+ "--OR",
443
+ default=None,
444
+ type=str,
445
+ help="Name of gwas OR column (if not a name that gsMap understands)",
446
+ )
447
+ parser.add_argument(
448
+ "--se_OR",
449
+ default=None,
450
+ type=str,
451
+ help="Name of standar error of OR column (if not a name that gsMap understands)",
452
+ )
453
+
454
+ # Arguments for converting SNP (chr, pos) to rsid
455
+ parser.add_argument(
456
+ "--chr",
457
+ default="Chr",
458
+ type=str,
459
+ help="Name of SNP chromosome column (if not a name that gsMap understands)",
460
+ )
461
+ parser.add_argument(
462
+ "--pos",
463
+ default="Pos",
464
+ type=str,
465
+ help="Name of SNP positions column (if not a name that gsMap understands)",
466
+ )
467
+ parser.add_argument("--dbsnp", default=None, type=str, help="Path to reference dnsnp file")
468
+ parser.add_argument(
469
+ "--chunksize", default=1e6, type=int, help="Chunk size for loading dbsnp file"
470
+ )
471
+
472
+ # Arguments for output format and quality
473
+ parser.add_argument(
474
+ "--format",
475
+ default="gsMap",
476
+ type=str,
477
+ help="Format of output data",
478
+ choices=["gsMap", "COJO"],
479
+ )
480
+ parser.add_argument("--info_min", default=0.9, type=float, help="Minimum INFO score.")
481
+ parser.add_argument("--maf_min", default=0.01, type=float, help="Minimum MAF.")
482
+ parser.add_argument(
483
+ "--keep_chr_pos",
484
+ action="store_true",
485
+ default=False,
486
+ help="Keep SNP chromosome and position columns in the output data",
487
+ )
488
+
489
+
490
+ def add_run_all_mode_args(parser):
491
+ add_shared_args(parser)
492
+
493
+ # Required paths and configurations
494
+ parser.add_argument(
495
+ "--gsMap_resource_dir",
496
+ type=str,
497
+ required=True,
498
+ help="Directory containing gsMap resources (e.g., genome annotations, LD reference panel, etc.).",
499
+ )
500
+ parser.add_argument(
501
+ "--hdf5_path",
502
+ type=str,
503
+ required=True,
504
+ help="Path to the input spatial transcriptomics data (H5AD format).",
505
+ )
506
+ parser.add_argument(
507
+ "--annotation", type=str, required=True, help="Name of the annotation in adata.obs to use."
508
+ )
509
+ parser.add_argument(
510
+ "--data_layer",
511
+ type=str,
512
+ default="counts",
513
+ required=True,
514
+ help='Data layer for gene expression (e.g., "count", "counts", "log1p").',
515
+ )
516
+
517
+ # GWAS Data Parameters
518
+ parser.add_argument(
519
+ "--trait_name",
520
+ type=str,
521
+ help="Name of the trait for GWAS analysis (required if sumstats_file is provided).",
522
+ )
523
+ parser.add_argument(
524
+ "--sumstats_file",
525
+ type=str,
526
+ help="Path to GWAS summary statistics file. Either sumstats_file or sumstats_config_file is required.",
527
+ )
528
+ parser.add_argument(
529
+ "--sumstats_config_file",
530
+ type=str,
531
+ help="Path to GWAS summary statistics config file. Either sumstats_file or sumstats_config_file is required.",
532
+ )
533
+
534
+ # Homolog Data Parameters
535
+ parser.add_argument(
536
+ "--homolog_file",
537
+ type=str,
538
+ help="Path to homologous gene for converting gene names from different species to human (optional, used for cross-species analysis).",
539
+ )
540
+
541
+ # Maximum number of processes
542
+ parser.add_argument(
543
+ "--max_processes",
544
+ type=int,
545
+ default=10,
546
+ help="Maximum number of processes for parallel execution.",
547
+ )
548
+
549
+ parser.add_argument(
550
+ "--latent_representation",
551
+ type=str,
552
+ default=None,
553
+ help="Type of latent representation. This should exist in the h5ad obsm.",
554
+ )
555
+ parser.add_argument("--num_neighbour", type=int, default=21, help="Number of neighbors.")
556
+ parser.add_argument(
557
+ "--num_neighbour_spatial", type=int, default=101, help="Number of spatial neighbors."
558
+ )
559
+ parser.add_argument(
560
+ "--gM_slices", type=str, default=None, help="Path to the slice mean file (optional)."
561
+ )
562
+
563
+
564
+ def ensure_path_exists(func):
565
+ @wraps(func)
566
+ def wrapper(*args, **kwargs):
567
+ result = func(*args, **kwargs)
568
+ if isinstance(result, Path):
569
+ if result.suffix:
570
+ result.parent.mkdir(parents=True, exist_ok=True, mode=0o755)
571
+ else: # It's a directory path
572
+ result.mkdir(parents=True, exist_ok=True, mode=0o755)
573
+ return result
574
+
575
+ return wrapper
576
+
577
+
578
+ @dataclass
579
+ class ConfigWithAutoPaths:
580
+ workdir: str
581
+ sample_name: str | None
582
+
583
+ def __post_init__(self):
584
+ if self.workdir is None:
585
+ raise ValueError("workdir must be provided.")
586
+
587
+ @property
588
+ @ensure_path_exists
589
+ def hdf5_with_latent_path(self) -> Path:
590
+ return Path(
591
+ f"{self.workdir}/{self.sample_name}/find_latent_representations/{self.sample_name}_add_latent.h5ad"
592
+ )
593
+
594
+ @property
595
+ @ensure_path_exists
596
+ def mkscore_feather_path(self) -> Path:
597
+ return Path(
598
+ f"{self.workdir}/{self.sample_name}/latent_to_gene/{self.sample_name}_gene_marker_score.feather"
599
+ )
600
+
601
+ @property
602
+ @ensure_path_exists
603
+ def ldscore_save_dir(self) -> Path:
604
+ return Path(f"{self.workdir}/{self.sample_name}/generate_ldscore")
605
+
606
+ @property
607
+ @ensure_path_exists
608
+ def ldsc_save_dir(self) -> Path:
609
+ return Path(f"{self.workdir}/{self.sample_name}/spatial_ldsc")
610
+
611
+ @property
612
+ @ensure_path_exists
613
+ def cauchy_save_dir(self) -> Path:
614
+ return Path(f"{self.workdir}/{self.sample_name}/cauchy_combination")
615
+
616
+ @ensure_path_exists
617
+ def get_report_dir(self, trait_name: str) -> Path:
618
+ return Path(f"{self.workdir}/{self.sample_name}/report/{trait_name}")
619
+
620
+ def get_gsMap_report_file(self, trait_name: str) -> Path:
621
+ return (
622
+ self.get_report_dir(trait_name) / f"{self.sample_name}_{trait_name}_gsMap_Report.html"
623
+ )
624
+
625
+ @ensure_path_exists
626
+ def get_manhattan_html_plot_path(self, trait_name: str) -> Path:
627
+ return Path(
628
+ f"{self.workdir}/{self.sample_name}/report/{trait_name}/manhattan_plot/{self.sample_name}_{trait_name}_Diagnostic_Manhattan_Plot.html"
629
+ )
630
+
631
+ @ensure_path_exists
632
+ def get_GSS_plot_dir(self, trait_name: str) -> Path:
633
+ return Path(f"{self.workdir}/{self.sample_name}/report/{trait_name}/GSS_plot")
634
+
635
+ def get_GSS_plot_select_gene_file(self, trait_name: str) -> Path:
636
+ return self.get_GSS_plot_dir(trait_name) / "plot_genes.csv"
637
+
638
+ @ensure_path_exists
639
+ def get_ldsc_result_file(self, trait_name: str) -> Path:
640
+ return Path(f"{self.ldsc_save_dir}/{self.sample_name}_{trait_name}.csv.gz")
641
+
642
+ @ensure_path_exists
643
+ def get_cauchy_result_file(self, trait_name: str) -> Path:
644
+ return Path(f"{self.cauchy_save_dir}/{self.sample_name}_{trait_name}.Cauchy.csv.gz")
645
+
646
+ @ensure_path_exists
647
+ def get_gene_diagnostic_info_save_path(self, trait_name: str) -> Path:
648
+ return Path(
649
+ f"{self.workdir}/{self.sample_name}/report/{trait_name}/{self.sample_name}_{trait_name}_Gene_Diagnostic_Info.csv"
650
+ )
651
+
652
+ @ensure_path_exists
653
+ def get_gsMap_plot_save_dir(self, trait_name: str) -> Path:
654
+ return Path(f"{self.workdir}/{self.sample_name}/report/{trait_name}/gsMap_plot")
655
+
656
+ def get_gsMap_html_plot_save_path(self, trait_name: str) -> Path:
657
+ return (
658
+ self.get_gsMap_plot_save_dir(trait_name)
659
+ / f"{self.sample_name}_{trait_name}_gsMap_plot.html"
660
+ )
661
+
662
+
663
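A minimal sketch of how ensure_path_exists and ConfigWithAutoPaths work together: accessing a derived path property returns the path and creates the missing directory on demand. This is illustrative only; the workdir and sample name below are placeholders.

# Illustrative sketch (not gsMap code): derived paths are created lazily.
paths = ConfigWithAutoPaths(workdir="./gsmap_work", sample_name="sampleA")

# The returned path has a suffix (.h5ad), so only its parent directory
# ./gsmap_work/sampleA/find_latent_representations/ is created (mode 0o755).
h5ad_out = paths.hdf5_with_latent_path
print(h5ad_out)

# Directory-style properties (no suffix) are created as directories themselves.
print(paths.ldscore_save_dir)  # ./gsmap_work/sampleA/generate_ldscore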
+ @dataclass
664
+ class CreateSliceMeanConfig:
665
+ slice_mean_output_file: str | Path
666
+ h5ad_yaml: str | dict | None = None
667
+ sample_name_list: list | None = None
668
+ h5ad_list: list | None = None
669
+ homolog_file: str | None = None
670
+ species: str | None = None
671
+ data_layer: str = None
672
+
673
+ def __post_init__(self):
674
+ if self.h5ad_list is None and self.h5ad_yaml is None:
675
+ raise ValueError("At least one of --h5ad_list or --h5ad_yaml must be provided.")
676
+ if self.h5ad_yaml is not None:
677
+ if isinstance(self.h5ad_yaml, str):
678
+ logger.info(f"Reading h5ad yaml file: {self.h5ad_yaml}")
679
+ h5ad_dict = (
680
+ yaml.safe_load(open(self.h5ad_yaml))
681
+ if isinstance(self.h5ad_yaml, str)
682
+ else self.h5ad_yaml
683
+ )
684
+ elif self.sample_name_list and self.h5ad_list:
685
+ logger.info("Reading sample name list and h5ad list")
686
+ h5ad_dict = dict(zip(self.sample_name_list, self.h5ad_list, strict=False))
687
+ else:
688
+ raise ValueError(
689
+ "Please provide either h5ad_yaml or both sample_name_list and h5ad_list."
690
+ )
691
+
692
+ # check that sample names are unique
693
+ assert len(h5ad_dict) == len(set(h5ad_dict)), "Sample names must be unique."
694
+ assert len(h5ad_dict) > 1, "At least two samples are required."
695
+
696
+ logger.info(f"Input h5ad files: {h5ad_dict}")
697
+
698
+ # Check if all files exist
699
+ self.h5ad_dict = {}
700
+ for sample_name, h5ad_file in h5ad_dict.items():
701
+ h5ad_file = Path(h5ad_file)
702
+ if not h5ad_file.exists():
703
+ raise FileNotFoundError(f"{h5ad_file} does not exist.")
704
+ self.h5ad_dict[sample_name] = h5ad_file
705
+
706
+ self.slice_mean_output_file = Path(self.slice_mean_output_file)
707
+ self.slice_mean_output_file.parent.mkdir(parents=True, exist_ok=True)
708
+
709
+ verify_homolog_file_format(self)
710
+
711
+
712
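A minimal sketch of constructing CreateSliceMeanConfig directly in Python. h5ad_yaml may be a YAML path or, as here, an in-memory dict mapping sample names to .h5ad files; all file names below are placeholders and must point to existing files, since __post_init__ validates each one.

# Illustrative sketch (not gsMap code): placeholder paths must exist.
config = CreateSliceMeanConfig(
    slice_mean_output_file="./slice_mean/gM_slices.parquet",  # hypothetical output file
    h5ad_yaml={
        "E16.5_E1S1": "./ST/E16.5_E1S1.h5ad",  # hypothetical input files
        "E16.5_E1S2": "./ST/E16.5_E1S2.h5ad",
    },
    data_layer="count",
)
print(config.h5ad_dict)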
+ @dataclass
713
+ class FindLatentRepresentationsConfig(ConfigWithAutoPaths):
714
+ input_hdf5_path: str
715
+ # output_hdf5_path: str
716
+ annotation: str = None
717
+ data_layer: str = None
718
+
719
+ epochs: int = 300
720
+ feat_hidden1: int = 256
721
+ feat_hidden2: int = 128
722
+ feat_cell: int = 3000
723
+ gat_hidden1: int = 64
724
+ gat_hidden2: int = 30
725
+ p_drop: float = 0.1
726
+ gat_lr: float = 0.001
727
+ gcn_decay: float = 0.01
728
+ n_neighbors: int = 11
729
+ label_w: float = 1
730
+ rec_w: float = 1
731
+ input_pca: bool = True
732
+ n_comps: int = 300
733
+ weighted_adj: bool = False
734
+ nheads: int = 3
735
+ var: bool = False
736
+ convergence_threshold: float = 1e-4
737
+ hierarchically: bool = False
738
+
739
+ def __post_init__(self):
740
+ # self.output_hdf5_path = self.hdf5_with_latent_path
741
+ if self.hierarchically:
742
+ if self.annotation is None:
743
+ raise ValueError("annotation must be provided if hierarchically is True.")
744
+ logger.info(
745
+ "------Hierarchical mode is enabled. This will find the latent representations within each annotation."
746
+ )
747
+
748
+ # remind for not providing annotation
749
+ if self.annotation is None:
750
+ logger.warning(
751
+ "annotation is not provided. This will find the latent representations for the whole dataset."
752
+ )
753
+ else:
754
+ logger.info(f"------Find latent representations for {self.annotation}...")
755
+
756
+
757
+ @dataclass
758
+ class LatentToGeneConfig(ConfigWithAutoPaths):
759
+ # input_hdf5_with_latent_path: str
760
+ # output_feather_path: str
761
+ input_hdf5_path: str | Path = None
762
+ no_expression_fraction: bool = False
763
+ latent_representation: str = None
764
+ num_neighbour: int = 21
765
+ num_neighbour_spatial: int = 101
766
+ homolog_file: str = None
767
+ gM_slices: str = None
768
+ annotation: str = None
769
+ species: str = None
770
+
771
+ def __post_init__(self):
772
+ if self.input_hdf5_path is None:
773
+ self.input_hdf5_path = self.hdf5_with_latent_path
774
+ assert self.input_hdf5_path.exists(), (
775
+ f"{self.input_hdf5_path} does not exist. Please run FindLatentRepresentations first."
776
+ )
777
+ else:
778
+ assert Path(self.input_hdf5_path).exists(), f"{self.input_hdf5_path} does not exist."
779
+ # copy to self.hdf5_with_latent_path
780
+ import shutil
781
+
782
+ shutil.copy2(self.input_hdf5_path, self.hdf5_with_latent_path)
783
+
784
+ if self.latent_representation is not None:
785
+ logger.info(f"Using the provided latent representation: {self.latent_representation}")
786
+ else:
787
+ self.latent_representation = "latent_GVAE"
788
+ logger.info(f"Using default latent representation: {self.latent_representation}")
789
+
790
+ if self.gM_slices is not None:
791
+ assert Path(self.gM_slices).exists(), f"{self.gM_slices} does not exist."
792
+ logger.info(f"Using the provided slice mean file: {self.gM_slices}.")
793
+
794
+ verify_homolog_file_format(self)
795
+
796
+
797
+ def verify_homolog_file_format(config):
798
+ if config.homolog_file is not None:
799
+ logger.info(
800
+ f"User provided homolog file to map gene names to human: {config.homolog_file}"
801
+ )
802
+ # check the format of the homolog file
803
+ with open(config.homolog_file) as f:
804
+ first_line = f.readline().strip()
805
+ _n_col = len(first_line.split())
806
+ if _n_col != 2:
807
+ raise ValueError(
808
+ f"Invalid homolog file format. Expected 2 columns, first column should be other species gene name, second column should be human gene name. "
809
+ f"Got {_n_col} columns in the first line."
810
+ )
811
+ else:
812
+ first_col_name, second_col_name = first_line.split()
813
+ config.species = first_col_name
814
+ logger.info(
815
+ f"Homolog file provided and will map gene name from column1:{first_col_name} to column2:{second_col_name}"
816
+ )
817
+ else:
818
+ logger.info("No homolog file provided. Run in human mode.")
819
+
820
+
821
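For reference, a small sketch of the two-column, whitespace-separated homolog file that verify_homolog_file_format expects. Only the column count of the first line is validated, and the first token of that line is stored as config.species. The file name and gene symbols below are placeholders.

# Illustrative sketch (not gsMap code): write a tiny homolog mapping file.
homolog_lines = [
    "mouse_gene human_gene",  # header row: first token becomes the species label
    "Trp53 TP53",             # hypothetical gene pairs
    "Cdkn2a CDKN2A",
]
with open("mouse_human_homologs.txt", "w") as f:
    f.write("\n".join(homolog_lines) + "\n")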
+ @dataclass
822
+ class GenerateLDScoreConfig(ConfigWithAutoPaths):
823
+ chrom: int | str
824
+
825
+ bfile_root: str
826
+ keep_snp_root: str | None
827
+
828
+ # annotation by gene distance
829
+ gtf_annotation_file: str
830
+ gene_window_size: int = 50000
831
+
832
+ # annotation by enhancer
833
+ enhancer_annotation_file: str = None
834
+ snp_multiple_enhancer_strategy: Literal["max_mkscore", "nearest_TSS"] = "max_mkscore"
835
+ gene_window_enhancer_priority: (
836
+ Literal["gene_window_first", "enhancer_first", "enhancer_only"] | None
837
+ ) = None
838
+
839
+ # for calculating ld score
840
+ additional_baseline_annotation: str = None
841
+ spots_per_chunk: int = 1_000
842
+ ld_wind: int = 1
843
+ ld_unit: str = "CM"
844
+
845
+ # zarr config
846
+ ldscore_save_format: Literal["feather", "zarr", "quick_mode"] = "feather"
847
+
848
+ zarr_chunk_size: tuple[int, int] = None
849
+
850
+ # for pre calculating the SNP Gene ldscore Weight
851
+ save_pre_calculate_snp_gene_weight_matrix: bool = False
852
+
853
+ baseline_annotation_dir: str | None = None
854
+ SNP_gene_pair_dir: str | None = None
855
+
856
+ def __post_init__(self):
857
+ # if self.mkscore_feather_file is None:
858
+ # self.mkscore_feather_file = self._get_mkscore_feather_path()
859
+
860
+ if (
861
+ self.enhancer_annotation_file is not None
862
+ and self.gene_window_enhancer_priority is None
863
+ ):
864
+ logger.warning(
865
+ "enhancer_annotation_file is provided but gene_window_enhancer_priority is not provided. "
866
+ "by default, gene_window_enhancer_priority is set to 'enhancer_only', when enhancer_annotation_file is provided."
867
+ )
868
+ self.gene_window_enhancer_priority = "enhancer_only"
869
+ if (
870
+ self.enhancer_annotation_file is None
871
+ and self.gene_window_enhancer_priority is not None
872
+ ):
873
+ logger.warning(
874
+ "gene_window_enhancer_priority is provided but enhancer_annotation_file is not provided. "
875
+ "by default, gene_window_enhancer_priority is set to None, when enhancer_annotation_file is not provided."
876
+ )
877
+ self.gene_window_enhancer_priority = None
878
+ assert self.gene_window_enhancer_priority in [
879
+ None,
880
+ "gene_window_first",
881
+ "enhancer_first",
882
+ "enhancer_only",
883
+ ], (
884
+ f"gene_window_enhancer_priority must be one of None, 'gene_window_first', 'enhancer_first', 'enhancer_only', but got {self.gene_window_enhancer_priority}."
885
+ )
886
+ if self.gene_window_enhancer_priority in ["gene_window_first", "enhancer_first"]:
887
+ logger.info(
888
+ "Both gene_window and enhancer annotation will be used to calculate LD score. "
889
+ )
890
+ logger.info(
891
+ f"SNP within +-{self.gene_window_size} bp of gene body will be used and enhancer annotation will be used to calculate LD score. If a snp maps to multiple enhancers, the strategy to choose by your select strategy: {self.snp_multiple_enhancer_strategy}."
892
+ )
893
+ elif self.gene_window_enhancer_priority == "enhancer_only":
894
+ logger.info("Only enhancer annotation will be used to calculate LD score. ")
895
+ else:
896
+ logger.info(
897
+ f"Only gene window annotation will be used to calculate LD score. SNP within +-{self.gene_window_size} bp of gene body will be used. "
898
+ )
899
+
900
+ # remind for baseline annotation
901
+ if self.additional_baseline_annotation is None:
902
+ logger.info(
903
+ "------Baseline annotation is not provided. Default baseline annotation will be used."
904
+ )
905
+ else:
906
+ logger.info(
907
+ "------Baseline annotation is provided. Additional baseline annotation will be used with the default baseline annotation."
908
+ )
909
+ logger.info(
910
+ f"------Baseline annotation directory: {self.additional_baseline_annotation}"
911
+ )
912
+ # check the existence of baseline annotation
913
+ if self.chrom == "all":
914
+ for chrom in range(1, 23):
915
+ chrom = str(chrom)
916
+ baseline_annotation_path = (
917
+ Path(self.additional_baseline_annotation) / f"baseline.{chrom}.annot.gz"
918
+ )
919
+ if not baseline_annotation_path.exists():
920
+ raise FileNotFoundError(
921
+ f"baseline.{chrom}.annot.gz is not found in {self.additional_baseline_annotation}."
922
+ )
923
+ else:
924
+ baseline_annotation_path = (
925
+ Path(self.additional_baseline_annotation) / f"baseline.{self.chrom}.annot.gz"
926
+ )
927
+ if not baseline_annotation_path.exists():
928
+ raise FileNotFoundError(
929
+ f"baseline.{self.chrom}.annot.gz is not found in {self.additional_baseline_annotation}."
930
+ )
931
+
932
+ # set the default zarr chunk size
933
+ if self.ldscore_save_format == "zarr" and self.zarr_chunk_size is None:
934
+ self.zarr_chunk_size = (10_000, self.spots_per_chunk)
935
+
936
+
937
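A minimal sketch of constructing GenerateLDScoreConfig for a single chromosome using only the gene-window annotation. This is illustrative only; the resource paths mirror the layout derived in RunAllModeConfig below and are placeholders for your local gsMap resources.

# Illustrative sketch (not gsMap code): no enhancer-related options are set here.
ldscore_config = GenerateLDScoreConfig(
    workdir="./gsmap_work",
    sample_name="E16.5_E1S1",
    chrom=22,
    bfile_root="./gsMap_resource/LD_Reference_Panel/1000G_EUR_Phase3_plink/1000G.EUR.QC",
    keep_snp_root="./gsMap_resource/LDSC_resource/hapmap3_snps/hm",
    gtf_annotation_file="./gsMap_resource/genome_annotation/gtf/gencode.v39lift37.annotation.gtf",
    gene_window_size=50_000,
    spots_per_chunk=1_000,
    ldscore_save_format="feather",
)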
+ @dataclass
938
+ class SpatialLDSCConfig(ConfigWithAutoPaths):
939
+ w_file: str
940
+ # ldscore_save_dir: str
941
+ use_additional_baseline_annotation: bool = True
942
+ trait_name: str | None = None
943
+ sumstats_file: str | None = None
944
+ sumstats_config_file: str | None = None
945
+ num_processes: int = 4
946
+ not_M_5_50: bool = False
947
+ n_blocks: int = 200
948
+ chisq_max: int | None = None
949
+ all_chunk: int | None = None
950
+ chunk_range: tuple[int, int] | None = None
951
+
952
+ ldscore_save_format: Literal["feather", "zarr", "quick_mode"] = "feather"
953
+
954
+ spots_per_chunk_quick_mode: int = 1_000
955
+ snp_gene_weight_adata_path: str | None = None
956
+
957
+ def __post_init__(self):
958
+ super().__post_init__()
959
+ if self.sumstats_file is None and self.sumstats_config_file is None:
960
+ raise ValueError("One of sumstats_file and sumstats_config_file must be provided.")
961
+ if self.sumstats_file is not None and self.sumstats_config_file is not None:
962
+ raise ValueError(
963
+ "Only one of sumstats_file and sumstats_config_file must be provided."
964
+ )
965
+ if self.sumstats_file is not None and self.trait_name is None:
966
+ raise ValueError("trait_name must be provided if sumstats_file is provided.")
967
+ if self.sumstats_config_file is not None and self.trait_name is not None:
968
+ raise ValueError(
969
+ "trait_name must not be provided if sumstats_config_file is provided."
970
+ )
971
+ self.sumstats_config_dict = {}
972
+ # load the sumstats config file
973
+ if self.sumstats_config_file is not None:
974
+ import yaml
975
+
976
+ with open(self.sumstats_config_file) as f:
977
+ config = yaml.load(f, Loader=yaml.FullLoader)
978
+ for trait_name, sumstats_file in config.items():
979
+ assert Path(sumstats_file).exists(), f"{sumstats_file} does not exist."
+ self.sumstats_config_dict[trait_name] = sumstats_file
980
+ # load the sumstats file
981
+ elif self.sumstats_file is not None:
982
+ self.sumstats_config_dict[self.trait_name] = self.sumstats_file
983
+ else:
984
+ raise ValueError("One of sumstats_file and sumstats_config_file must be provided.")
985
+
986
+ for sumstats_file in self.sumstats_config_dict.values():
987
+ assert Path(sumstats_file).exists(), f"{sumstats_file} does not exist."
988
+
989
+ # check if additional baseline annotation is exist
990
+ # self.use_additional_baseline_annotation = False
991
+
992
+ if self.use_additional_baseline_annotation:
993
+ self.process_additional_baseline_annotation()
994
+
995
+ def process_additional_baseline_annotation(self):
996
+ additional_baseline_annotation = Path(self.ldscore_save_dir) / "additional_baseline"
997
+ dir_exists = additional_baseline_annotation.exists()
998
+
999
+ if not dir_exists:
1000
+ self.use_additional_baseline_annotation = False
1001
+ # if self.use_additional_baseline_annotation:
1002
+ # logger.warning(f"additional_baseline directory is not found in {self.ldscore_save_dir}.")
1003
+ # print('''\
1004
+ # if you want to use additional baseline annotation,
1005
+ # please provide additional baseline annotation when calculating ld score.
1006
+ # ''')
1007
+ # raise FileNotFoundError(
1008
+ # f'additional_baseline directory is not found.')
1009
+ # return
1010
+ # self.use_additional_baseline_annotation = self.use_additional_baseline_annotation or True
1011
+ else:
1012
+ logger.info(
1013
+ "------Additional baseline annotation is provided. It will be used with the default baseline annotation."
1014
+ )
1015
+ logger.info(
1016
+ f"------Additional baseline annotation directory: {additional_baseline_annotation}"
1017
+ )
1018
+
1019
+ chrom_list = range(1, 23)
1020
+ for chrom in chrom_list:
1021
+ baseline_annotation_path = (
1022
+ additional_baseline_annotation / f"baseline.{chrom}.l2.ldscore.feather"
1023
+ )
1024
+ if not baseline_annotation_path.exists():
1025
+ raise FileNotFoundError(
1026
+ f"baseline.{chrom}.annot.gz is not found in {additional_baseline_annotation}."
1027
+ )
1028
+ return None
1029
+
1030
+
1031
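For reference, a short sketch of the sumstats_config_file consumed above: a flat YAML mapping from trait name to the path of a formatted summary statistics file. The trait names and paths below are placeholders.

# Illustrative sketch (not gsMap code): write a config the loader above would accept.
import yaml

sumstats_config = {
    "Height": "./GWAS/height_formatted.sumstats.gz",  # hypothetical paths
    "MCHC": "./GWAS/mchc_formatted.sumstats.gz",
}
with open("sumstats_config.yaml", "w") as f:
    yaml.safe_dump(sumstats_config, f)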
+ @dataclass
1032
+ class CauchyCombinationConfig(ConfigWithAutoPaths):
1033
+ trait_name: str
1034
+ annotation: str
1035
+ sample_name_list: list[str] = dataclasses.field(default_factory=list)
1036
+ output_file: str | Path | None = None
1037
+
1038
+ def __post_init__(self):
1039
+ if self.sample_name is not None:
1040
+ if len(self.sample_name_list) > 0:
1041
+ raise ValueError("Only one of sample_name and sample_name_list must be provided.")
1042
+ else:
1043
+ self.sample_name_list = [self.sample_name]
1044
+ self.output_file = (
1045
+ self.get_cauchy_result_file(self.trait_name)
1046
+ if self.output_file is None
1047
+ else self.output_file
1048
+ )
1049
+ else:
1050
+ assert len(self.sample_name_list) > 0, "At least one sample name must be provided."
1051
+ assert self.output_file is not None, (
1052
+ "Output_file must be provided if sample_name_list is provided."
1053
+ )
1054
+
1055
+
1056
+ @dataclass
1057
+ class VisualizeConfig(ConfigWithAutoPaths):
1058
+ trait_name: str
1059
+
1060
+ annotation: str = None
1061
+ fig_title: str = None
1062
+ fig_height: int = 600
1063
+ fig_width: int = 800
1064
+ point_size: int = None
1065
+ fig_style: Literal["dark", "light"] = "light"
1066
+
1067
+
1068
+ @dataclass
1069
+ class DiagnosisConfig(ConfigWithAutoPaths):
1070
+ annotation: str
1071
+ # mkscore_feather_file: str
1072
+
1073
+ trait_name: str
1074
+ sumstats_file: str
1075
+ plot_type: Literal["manhattan", "GSS", "gsMap", "all"] = "all"
1076
+ top_corr_genes: int = 50
1077
+ selected_genes: list[str] | None = None
1078
+
1079
+ fig_width: int | None = None
1080
+ fig_height: int | None = None
1081
+ point_size: int | None = None
1082
+ fig_style: Literal["dark", "light"] = "light"
1083
+
1084
+ def __post_init__(self):
1085
+ if any([self.fig_width, self.fig_height, self.point_size]):
1086
+ logger.info("Customizing the figure size and point size.")
1087
+ assert all([self.fig_width, self.fig_height, self.point_size]), (
1088
+ "All of fig_width, fig_height, and point_size must be provided."
1089
+ )
1090
+ self.customize_fig = True
1091
+ else:
1092
+ self.customize_fig = False
1093
+
1094
+
1095
+ @dataclass
1096
+ class ReportConfig(DiagnosisConfig):
1097
+ pass
1098
+
1099
+
1100
+ @dataclass
1101
+ class RunAllModeConfig(ConfigWithAutoPaths):
1102
+ gsMap_resource_dir: str
1103
+
1104
+ # == ST DATA PARAMETERS ==
1105
+ hdf5_path: str
1106
+ annotation: str
1107
+ data_layer: str = "X"
1108
+
1109
+ # == latent 2 Gene PARAMETERS ==
1110
+ gM_slices: str | None = None
1111
+ latent_representation: str = None
1112
+ num_neighbour: int = 21
1113
+ num_neighbour_spatial: int = 101
1114
+
1115
+ # ==GWAS DATA PARAMETERS==
1116
+ trait_name: str | None = None
1117
+ sumstats_file: str | None = None
1118
+ sumstats_config_file: str | None = None
1119
+
1120
+ # === homolog PARAMETERS ===
1121
+ homolog_file: str | None = None
1122
+
1123
+ max_processes: int = 10
1124
+
1125
+ def __post_init__(self):
1126
+ super().__post_init__()
1127
+ self.gtffile = (
1128
+ f"{self.gsMap_resource_dir}/genome_annotation/gtf/gencode.v39lift37.annotation.gtf"
1129
+ )
1130
+ self.bfile_root = (
1131
+ f"{self.gsMap_resource_dir}/LD_Reference_Panel/1000G_EUR_Phase3_plink/1000G.EUR.QC"
1132
+ )
1133
+ self.keep_snp_root = f"{self.gsMap_resource_dir}/LDSC_resource/hapmap3_snps/hm"
1134
+ self.w_file = f"{self.gsMap_resource_dir}/LDSC_resource/weights_hm3_no_hla/weights."
1135
+ self.snp_gene_weight_adata_path = (
1136
+ f"{self.gsMap_resource_dir}/quick_mode/snp_gene_weight_matrix.h5ad"
1137
+ )
1138
+ self.baseline_annotation_dir = Path(
1139
+ f"{self.gsMap_resource_dir}/quick_mode/baseline"
1140
+ ).resolve()
1141
+ self.SNP_gene_pair_dir = Path(
1142
+ f"{self.gsMap_resource_dir}/quick_mode/SNP_gene_pair"
1143
+ ).resolve()
1144
+ # check the existence of the input files and resources files
1145
+ for file in [self.hdf5_path, self.gtffile]:
1146
+ if not Path(file).exists():
1147
+ raise FileNotFoundError(f"File {file} does not exist.")
1148
+
1149
+ if self.sumstats_file is None and self.sumstats_config_file is None:
1150
+ raise ValueError("One of sumstats_file and sumstats_config_file must be provided.")
1151
+ if self.sumstats_file is not None and self.sumstats_config_file is not None:
1152
+ raise ValueError(
1153
+ "Only one of sumstats_file and sumstats_config_file must be provided."
1154
+ )
1155
+ if self.sumstats_file is not None and self.trait_name is None:
1156
+ raise ValueError("trait_name must be provided if sumstats_file is provided.")
1157
+ if self.sumstats_config_file is not None and self.trait_name is not None:
1158
+ raise ValueError(
1159
+ "trait_name must not be provided if sumstats_config_file is provided."
1160
+ )
1161
+ self.sumstats_config_dict = {}
1162
+ # load the sumstats config file
1163
+ if self.sumstats_config_file is not None:
1164
+ import yaml
1165
+
1166
+ with open(self.sumstats_config_file) as f:
1167
+ config = yaml.load(f, Loader=yaml.FullLoader)
1168
+ for trait_name, sumstats_file in config.items():
1169
+ assert Path(sumstats_file).exists(), f"{sumstats_file} does not exist."
1170
+ self.sumstats_config_dict[trait_name] = sumstats_file
1171
+ # load the sumstats file
1172
+ elif self.sumstats_file is not None and self.trait_name is not None:
1173
+ self.sumstats_config_dict[self.trait_name] = self.sumstats_file
1174
+ else:
1175
+ raise ValueError("One of sumstats_file and sumstats_config_file must be provided.")
1176
+
1177
+ for sumstats_file in self.sumstats_config_dict.values():
1178
+ assert Path(sumstats_file).exists(), f"{sumstats_file} does not exist."
1179
+
1180
+
1181
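A minimal sketch of assembling RunAllModeConfig directly in Python, mirroring the quick_mode options defined by add_run_all_mode_args above. All paths are placeholders and must point to real files and directories, since __post_init__ validates the inputs and derives the resource paths.

# Illustrative sketch (not gsMap code): placeholder paths must exist.
quick_config = RunAllModeConfig(
    workdir="./gsmap_work",
    sample_name="E16.5_E1S1",
    gsMap_resource_dir="./gsMap_resource",
    hdf5_path="./ST/E16.5_E1S1.h5ad",
    annotation="annotation",
    data_layer="count",
    trait_name="Height",
    sumstats_file="./GWAS/height_formatted.sumstats.gz",
    max_processes=10,
)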
+ @dataclass
1182
+ class FormatSumstatsConfig:
1183
+ sumstats: str
1184
+ out: str
1185
+ dbsnp: str
1186
+ snp: str = None
1187
+ a1: str = None
1188
+ a2: str = None
1189
+ info: str = None
1190
+ beta: str = None
1191
+ se: str = None
1192
+ p: str = None
1193
+ frq: str = None
1194
+ n: str = None
1195
+ z: str = None
1196
+ OR: str = None
1197
+ se_OR: str = None
1198
+ format: str = None
1199
+ chr: str = None
1200
+ pos: str = None
1201
+ chunksize: int = 1e7
1202
+ info_min: float = 0.9
1203
+ maf_min: float = 0.01
1204
+ keep_chr_pos: bool = False
1205
+
1206
+
1207
+ @register_cli(
1208
+ name="run_find_latent_representations",
1209
+ description="Run Find_latent_representations \nFind the latent representations of each spot by running GNN-VAE",
1210
+ add_args_function=add_find_latent_representations_args,
1211
+ )
1212
+ def run_find_latent_representation_from_cli(args: argparse.Namespace):
1213
+ from gsMap.find_latent_representation import run_find_latent_representation
1214
+
1215
+ config = get_dataclass_from_parser(args, FindLatentRepresentationsConfig)
1216
+ run_find_latent_representation(config)
1217
+
1218
+
1219
+ @register_cli(
1220
+ name="run_latent_to_gene",
1221
+ description="Run Latent_to_gene \nEstimate gene marker gene scores for each spot by using latent representations from nearby spots",
1222
+ add_args_function=add_latent_to_gene_args,
1223
+ )
1224
+ def run_latent_to_gene_from_cli(args: argparse.Namespace):
1225
+ from gsMap.latent_to_gene import run_latent_to_gene
1226
+
1227
+ config = get_dataclass_from_parser(args, LatentToGeneConfig)
1228
+ run_latent_to_gene(config)
1229
+
1230
+
1231
+ @register_cli(
1232
+ name="run_generate_ldscore",
1233
+ description="Run Generate_ldscore \nGenerate LD scores for each spot",
1234
+ add_args_function=add_generate_ldscore_args,
1235
+ )
1236
+ def run_generate_ldscore_from_cli(args: argparse.Namespace):
1237
+ from gsMap.generate_ldscore import run_generate_ldscore
1238
+
1239
+ config = get_dataclass_from_parser(args, GenerateLDScoreConfig)
1240
+ run_generate_ldscore(config)
1241
+
1242
+
1243
+ @register_cli(
1244
+ name="run_spatial_ldsc",
1245
+ description="Run Spatial_ldsc \nRun spatial LDSC for each spot",
1246
+ add_args_function=add_spatial_ldsc_args,
1247
+ )
1248
+ def run_spatial_ldsc_from_cli(args: argparse.Namespace):
1249
+ from gsMap.spatial_ldsc_multiple_sumstats import run_spatial_ldsc
1250
+
1251
+ config = get_dataclass_from_parser(args, SpatialLDSCConfig)
1252
+ run_spatial_ldsc(config)
1253
+
1254
+
1255
+ @register_cli(
1256
+ name="run_cauchy_combination",
1257
+ description="Run Cauchy_combination for each annotation",
1258
+ add_args_function=add_Cauchy_combination_args,
1259
+ )
1260
+ def run_Cauchy_combination_from_cli(args: argparse.Namespace):
1261
+ from gsMap.cauchy_combination_test import run_Cauchy_combination
1262
+
1263
+ config = get_dataclass_from_parser(args, CauchyCombinationConfig)
1264
+ run_Cauchy_combination(config)
1265
+
1266
+
1267
+ @register_cli(
1268
+ name="run_report",
1269
+ description="Run Report to generate diagnostic plots and tables",
1270
+ add_args_function=add_report_args,
1271
+ )
1272
+ def run_Report_from_cli(args: argparse.Namespace):
1273
+ from gsMap.report import run_report
1274
+
1275
+ config = get_dataclass_from_parser(args, ReportConfig)
1276
+ run_report(config)
1277
+
1278
+
1279
+ @register_cli(
1280
+ name="format_sumstats",
1281
+ description="Format gwas summary statistics",
1282
+ add_args_function=add_format_sumstats_args,
1283
+ )
1284
+ def gwas_format_from_cli(args: argparse.Namespace):
1285
+ from gsMap.format_sumstats import gwas_format
1286
+
1287
+ config = get_dataclass_from_parser(args, FormatSumstatsConfig)
1288
+ gwas_format(config)
1289
+
1290
+
1291
+ @register_cli(
1292
+ name="quick_mode",
1293
+ description="Run all the gsMap pipeline in quick mode",
1294
+ add_args_function=add_run_all_mode_args,
1295
+ )
1296
+ def run_all_mode_from_cli(args: argparse.Namespace):
1297
+ from gsMap.run_all_mode import run_pipeline
1298
+
1299
+ config = get_dataclass_from_parser(args, RunAllModeConfig)
1300
+ run_pipeline(config)
1301
+
1302
+
1303
+ @register_cli(
1304
+ name="create_slice_mean",
1305
+ description="Create slice mean from multiple h5ad files",
1306
+ add_args_function=add_create_slice_mean_args,
1307
+ )
1308
+ def create_slice_mean_from_cli(args: argparse.Namespace):
1309
+ from gsMap.create_slice_mean import run_create_slice_mean
1310
+
1311
+ config = get_dataclass_from_parser(args, CreateSliceMeanConfig)
1312
+ run_create_slice_mean(config)