gsMap-1.60-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,209 @@
+ import argparse
+ import logging
+ import pprint
+ import random
+ import time
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ import scanpy as sc
+ import torch
+ from sklearn import preprocessing
+
+ from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix
+ from gsMap.GNN_VAE.train import Model_Train
+ from gsMap.config import add_find_latent_representations_args, FindLatentRepresentationsConfig
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+ handler = logging.StreamHandler()
+ handler.setFormatter(logging.Formatter(
+     '[{asctime}] {levelname:8s} {filename} {message}', style='{'))
+ logger.addHandler(handler)
+
+
+ def set_seed(seed_value):
+     """
+     Set seeds for reproducibility in PyTorch.
+     """
+     torch.manual_seed(seed_value)  # set the seed for PyTorch (CPU)
+     np.random.seed(seed_value)  # set the seed for NumPy
+     random.seed(seed_value)  # set the seed for the Python random module
+     if torch.cuda.is_available():
+         print('Running on GPU')
+         torch.cuda.manual_seed(seed_value)  # set the seed for the current CUDA device
+         torch.cuda.manual_seed_all(seed_value)  # set the seed for all CUDA devices
+     else:
+         print('Running on CPU')
+
+ # Seed everything for reproducibility
+ set_seed(2024)
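+ # Note: for fully deterministic CUDA runs one would typically also set
+ # torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False;
+ # seeding alone does not guarantee bitwise reproducibility on GPU.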
+
+ # The class for finding latent representations
+ class Latent_Representation_Finder:
+
+     def __init__(self, adata, Params):
+         self.adata = adata.copy()
+         self.Params = Params
+
+         # Standard preprocessing
+         if self.Params.type in ('count', 'counts'):
+             self.adata.X = self.adata.layers[self.Params.type]
+             sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3", n_top_genes=self.Params.feat_cell)
+             sc.pp.normalize_total(self.adata, target_sum=1e4)
+             sc.pp.log1p(self.adata)
+             sc.pp.scale(self.adata)
+         else:
+             self.adata.X = self.adata.layers[self.Params.type]
+             sc.pp.highly_variable_genes(self.adata, n_top_genes=self.Params.feat_cell)
+
+     def Run_GNN_VAE(self, label, verbose='whole ST data'):
+
+         # Construct the neighbouring graph
+         graph_dict = Construct_Adjacency_Matrix(self.adata, self.Params)
+
+         # Process the feature matrix
+         node_X = self.adata[:, self.adata.var.highly_variable].X
+         print(f'The shape of the feature matrix is {node_X.shape}.')
+         if self.Params.input_pca:
+             node_X = sc.pp.pca(node_X, n_comps=self.Params.n_comps)
+
+         # Update the input shape
+         self.Params.n_nodes = node_X.shape[0]
+         self.Params.feat_cell = node_X.shape[1]
+
+         # Run the GNN-VAE
+         print(f'------Finding latent representations for {verbose}...')
+         gvae = Model_Train(node_X, graph_dict, self.Params, label)
+         gvae.run_train()
+
+         return gvae.get_latent()
+
+     def Run_PCA(self):
+         sc.tl.pca(self.adata)
+         return self.adata.obsm['X_pca'][:, 0:self.Params.n_comps]
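+ # Typical usage (illustrative; `adata` is an AnnData object and `args` a
+ # FindLatentRepresentationsConfig, as in run_find_latent_representation below):
+ #     finder = Latent_Representation_Finder(adata, args)
+ #     latent_gvae = finder.Run_GNN_VAE(label=None)
+ #     latent_pca = finder.Run_PCA()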
+
+
+ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
+     num_features = args.feat_cell
+     args.output_dir = Path(args.output_hdf5_path).parent
+     args.output_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
+     # Load the ST data
+     print(f'------Loading ST data of {args.sample_name}...')
+     adata = sc.read_h5ad(f'{args.input_hdf5_path}')
+     adata.var_names_make_unique()
+     print('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))
+     # Load the cell type annotation
+     if args.annotation is not None:
+         # Remove cells whose annotation is missing or too rare (< 30 cells)
+         adata = adata[~pd.isnull(adata.obs[args.annotation]), :]
+         num = adata.obs[args.annotation].value_counts()
+         adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list()), :]
+
+         le = preprocessing.LabelEncoder()
+         le.fit(adata.obs[args.annotation])
+         adata.obs['categorical_label'] = le.transform(adata.obs[args.annotation])
+         label = adata.obs['categorical_label'].to_list()
+     else:
+         label = None
+     # Find latent representations
+     latent_rep = Latent_Representation_Finder(adata, args)
+     latent_GVAE = latent_rep.Run_GNN_VAE(label)
+     latent_PCA = latent_rep.Run_PCA()
+     # Add latent representations to the AnnData object
+     print('------Adding latent representations...')
+     adata.obsm["latent_GVAE"] = latent_GVAE
+     adata.obsm["latent_PCA"] = latent_PCA
+     # Run UMAP on each latent representation
+     # (sc.tl.umap overwrites adata.obsm['X_umap'] on every call, so copy the result out per representation)
+     for name in ['latent_GVAE', 'latent_PCA']:
+         sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
+         sc.tl.umap(adata)
+         adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
+
+     # Find the latent representations hierarchically (optional)
+     if args.annotation is not None and args.hierarchically:
+         print('------Finding latent representations hierarchically...')
+         PCA_all = pd.DataFrame()
+         GVAE_all = pd.DataFrame()
+
+         for ct in adata.obs[args.annotation].unique():
+             adata_part = adata[adata.obs[args.annotation] == ct, :]
+             print(adata_part.shape)
+
+             # Find latent representations for the selected cell type
+             latent_rep = Latent_Representation_Finder(adata_part, args)
+
+             latent_PCA_part = pd.DataFrame(latent_rep.Run_PCA())
+             # Fall back to PCA when the cell type has too few cells for the GNN-VAE
+             if adata_part.shape[0] <= args.n_comps:
+                 latent_GVAE_part = latent_PCA_part
+             else:
+                 latent_GVAE_part = pd.DataFrame(latent_rep.Run_GNN_VAE(label=None, verbose=ct))
+
+             latent_GVAE_part.index = adata_part.obs_names
+             latent_PCA_part.index = adata_part.obs_names
+
+             GVAE_all = pd.concat((GVAE_all, latent_GVAE_part), axis=0)
+             PCA_all = pd.concat((PCA_all, latent_PCA_part), axis=0)
+
+             # Run_GNN_VAE mutates args.feat_cell, so restore it before the next cell type
+             args.feat_cell = num_features
+
+         adata.obsm["latent_GVAE_hierarchy"] = np.array(GVAE_all.loc[adata.obs_names])
+         adata.obsm["latent_PCA_hierarchy"] = np.array(PCA_all.loc[adata.obs_names])
+     print('------Saving ST data...')
+     adata.write(args.output_hdf5_path)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description="This script is designed to find latent representations in spatial transcriptomics data using a Graph Neural Network Variational Autoencoder (GNN-VAE). It processes input data, constructs a neighboring graph, and runs the GNN-VAE to output latent representations.")
+     add_find_latent_representations_args(parser)
+     TEST = True
+     if TEST:
+         test_dir = '/storage/yangjianLab/chenwenhao/projects/202312_gsMap/data/gsMap_test/Nature_Neuroscience_2021'
+         name = 'Cortex_151507'
+
+         args = parser.parse_args(
+             [
+                 '--input_hdf5_path', '/storage/yangjianLab/songliyang/SpatialData/Data/Brain/Human/Nature_Neuroscience_2021/processed/h5ad/Cortex_151507.h5ad',
+                 '--output_hdf5_path', f'{test_dir}/{name}/hdf5/{name}_add_latent.h5ad',
+                 '--sample_name', name,
+                 '--annotation', 'layer_guess',
+                 '--type', 'count',
+             ]
+         )
+     else:
+         args = parser.parse_args()
+
+     config = FindLatentRepresentationsConfig(**vars(args))
+     start_time = time.time()
+     logger.info(f'Finding latent representations for {config.sample_name}...')
+     pprint.pprint(config)
+     run_find_latent_representation(config)
+     end_time = time.time()
+     logger.info(f'Finished finding latent representations for {config.sample_name}. Time spent: {(end_time - start_time) / 60:.2f} min.')
@@ -0,0 +1,410 @@
+ import argparse
+ import logging
+ import math
+ import re
+
+ import numpy as np
+ import pandas as pd
+ from scipy.stats import chi2
+
+ from gsMap.config import FormatSumstatsConfig, add_format_sumstats_args
+
+
+ VALID_SNPS = set(['AC', 'AG', 'CA', 'CT', 'GA', 'GT', 'TC', 'TG'])
+ logger = logging.getLogger(__name__)
+
+ default_cnames = {
+     # RS NUMBER
+     'SNP': 'SNP',
+     'RS': 'SNP',
+     'RSID': 'SNP',
+     'RS_NUMBER': 'SNP',
+     'RS_NUMBERS': 'SNP',
+     # P-VALUE
+     'P': 'P',
+     'PVALUE': 'P',
+     'P_VALUE': 'P',
+     'PVAL': 'P',
+     'P_VAL': 'P',
+     'GC_PVALUE': 'P',
+     'p': 'P',
+     # EFFECT ALLELE (A1)
+     'A1': 'A1',
+     'ALLELE1': 'A1',
+     'ALLELE_1': 'A1',
+     'EFFECT_ALLELE': 'A1',
+     'REFERENCE_ALLELE': 'A1',
+     'INC_ALLELE': 'A1',
+     'EA': 'A1',
+     # NON-EFFECT ALLELE (A2)
+     'A2': 'A2',
+     'ALLELE2': 'A2',
+     'ALLELE_2': 'A2',
+     'OTHER_ALLELE': 'A2',
+     'NON_EFFECT_ALLELE': 'A2',
+     'DEC_ALLELE': 'A2',
+     'NEA': 'A2',
+     # N
+     'N': 'N',
+     'NCASE': 'N_CAS',
+     'CASES_N': 'N_CAS',
+     'N_CASE': 'N_CAS',
+     'N_CASES': 'N_CAS',
+     'N_CONTROLS': 'N_CON',
+     'N_CAS': 'N_CAS',
+     'N_CON': 'N_CON',
+     'NCONTROL': 'N_CON',
+     'CONTROLS_N': 'N_CON',
+     'N_CONTROL': 'N_CON',
+     'WEIGHT': 'N',
+     # SIGNED STATISTICS
+     'ZSCORE': 'Z',
+     'Z-SCORE': 'Z',
+     'GC_ZSCORE': 'Z',
+     'Z': 'Z',
+     'OR': 'OR',
+     'B': 'BETA',
+     'BETA': 'BETA',
+     'LOG_ODDS': 'LOG_ODDS',
+     'EFFECTS': 'BETA',
+     'EFFECT': 'BETA',
+     'b': 'BETA',
+     'beta': 'BETA',
+     # SE
+     'se': 'SE',
+     # INFO
+     'INFO': 'INFO',
+     'Info': 'INFO',
+     # MAF
+     'EAF': 'FRQ',
+     'FRQ': 'FRQ',
+     'MAF': 'FRQ',
+     'FRQ_U': 'FRQ',
+     'F_U': 'FRQ',
+     'frq_A1': 'FRQ',
+     'frq': 'FRQ',
+     'freq': 'FRQ'
+ }
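+ # The mapping above canonicalizes the many header spellings found in published GWAS
+ # summary statistics into one internal vocabulary (SNP, P, A1, A2, N, Z, BETA, SE,
+ # INFO, FRQ), in the same spirit as LDSC's munge_sumstats.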
+
+
+ def get_compression(fh):
+     '''
+     Read the filename suffix and determine whether the file is gzipped, bzip2'ed, or uncompressed
+     '''
+     if fh.endswith('gz'):
+         compression = 'gzip'
+     elif fh.endswith('bz2'):
+         compression = 'bz2'
+     else:
+         compression = None
+
+     return compression
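+ # Note: pandas can infer these suffixes itself (compression='infer' is the
+ # read_csv default), so this helper mainly makes the choice explicit.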
+
+
+ def gwas_checkname(gwas, config):
+     '''
+     Interpret the column names of the GWAS data
+     '''
+     old_name = gwas.columns
+     mapped_cnames = {}
+     for col in gwas.columns:
+         mapped_cnames[col] = default_cnames.get(col, col)
+     gwas.columns = list(mapped_cnames.values())
+
+     # When column names are provided by users
+     name_updates = {'SNP': config.snp, 'A1': config.a1, 'A2': config.a2, 'INFO': config.info,
+                     'BETA': config.beta, 'SE': config.se, 'P': config.p, 'FRQ': config.frq, 'N': config.n,
+                     'Z': config.z, 'Chr': config.chr, 'Pos': config.pos, 'OR': config.OR, 'SE_OR': config.se_OR}
+
+     for key, value in name_updates.items():
+         if value is not None and value in gwas.columns:
+             gwas.rename(columns={value: key}, inplace=True)
+     new_name = gwas.columns
+     name_dict = {new_name[i]: old_name[i] for i in range(len(new_name))}
+
+     # When effects are reported on the OR scale, convert to the log-odds (BETA) scale
+     if 'OR' in new_name and 'SE_OR' in new_name:
+         gwas['BETA'] = gwas.OR.apply(lambda x: math.log(x) if x > 0 else None)
+         # Delta method: SE[log(OR)] ≈ SE(OR) / OR (not the log of the SE)
+         gwas['SE'] = gwas.SE_OR / gwas.OR
+
+     interpreting = {
+         "SNP": 'Variant ID (e.g., rs number).',
+         "A1": 'Allele 1, interpreted as the effect allele for signed sumstat.',
+         "A2": 'Allele 2, interpreted as the non-effect allele for signed sumstat.',
+         "BETA": '[linear/logistic] regression coefficient (0 → no effect; above 0 → A1 is trait/risk increasing).',
+         "SE": 'Standard error of the regression coefficient.',
+         "OR": 'Odds ratio, will be converted to the log-odds (BETA) scale.',
+         "SE_OR": 'Standard error of the odds ratio, will be converted to the log-odds scale.',
+         "P": 'P-value.',
+         "Z": 'Z-value.',
+         "N": 'Sample size.',
+         "INFO": 'INFO score (imputation quality; higher → better imputation).',
+         "FRQ": 'Allele frequency of A1.',
+         "Chr": 'Chromosome.',
+         'Pos': 'SNP position.'
+     }
+
+     print('\nInterpreting column names as follows:')
+     for key, value in interpreting.items():
+         if key in new_name:
+             print(f'{name_dict[key]}: {interpreting[key]}')
+
+     return gwas
+
+
+ def gwas_checkformat(gwas, config):
+     '''
+     Check that the columns required by the requested output format are present
+     '''
+     if config.format == 'gsMap':
+         condition1 = np.any(np.isin(['P', 'Z'], gwas.columns))
+         condition2 = np.all(np.isin(['BETA', 'SE'], gwas.columns))
+         if not (condition1 or condition2):
+             raise ValueError('To munge GWAS data into gsMap format, either P or Z values, or both BETA and SE values, are required.')
+         else:
+             if 'Z' in gwas.columns:
+                 pass
+             elif 'P' in gwas.columns:
+                 # Recover |Z| from P via the inverse chi-square survival function; sign it with BETA
+                 gwas['Z'] = np.sqrt(chi2.isf(gwas.P, 1)) * np.where(gwas['BETA'] < 0, -1, 1)
+             else:
+                 gwas['Z'] = gwas.BETA / gwas.SE
+
+     elif config.format == 'COJO':
+         condition = np.all(np.isin(['A1', 'A2', 'FRQ', 'BETA', 'SE', 'P', 'N'], gwas.columns))
+         if not condition:
+             raise ValueError('To munge GWAS data into COJO format, the columns A1, A2, FRQ, BETA, SE, P, and N are all required.')
+         else:
+             gwas['Z'] = np.sqrt(chi2.isf(gwas.P, 1)) * np.where(gwas['BETA'] < 0, -1, 1)
+
+     return gwas
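+ # Why the P-to-Z conversion works: for a two-sided test, Z^2 follows a chi-square
+ # distribution with 1 degree of freedom, so |Z| = sqrt(chi2.isf(P, 1)); the sign of
+ # BETA supplies the direction of effect, which the P-value alone does not carry.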
+
+
+ def filter_info(info, config):
+     '''Remove INFO < config.info_min (default 0.9) and complain about out-of-bounds INFO.'''
+     if type(info) is pd.Series:  # one INFO column
+         jj = ((info > 2.0) | (info < 0)) & info.notnull()
+         ii = info >= config.info_min
+     elif type(info) is pd.DataFrame:  # several INFO columns
+         jj = (((info > 2.0) & info.notnull()).any(axis=1) | (
+             (info < 0) & info.notnull()).any(axis=1))
+         ii = (info.sum(axis=1) >= config.info_min * (len(info.columns)))
+     else:
+         raise ValueError('Expected pd.DataFrame or pd.Series.')
+
+     bad_info = jj.sum()
+     if bad_info > 0:
+         msg = 'WARNING: {N} SNPs had INFO outside of [0,2]. The INFO column may be mislabeled.'
+         logger.warning(msg.format(N=bad_info))
+
+     return ii
+
+
+ def filter_frq(frq, config):
+     '''
+     Filter on MAF: remove MAF < config.maf_min and out-of-bounds MAF.
+     '''
+     jj = (frq < 0) | (frq > 1)
+     bad_frq = jj.sum()
+     if bad_frq > 0:
+         msg = 'WARNING: {N} SNPs had FRQ outside of [0,1]. The FRQ column may be mislabeled.'
+         logger.warning(msg.format(N=bad_frq))
+
+     # Fold the allele frequency to the minor-allele side before thresholding
+     frq = np.minimum(frq, 1 - frq)
+     ii = frq > config.maf_min
+     return ii & ~jj
+
+
+ def filter_pvals(P, config):
+     '''Remove out-of-bounds P-values'''
+     ii = (P > 0) & (P <= 1)
+     bad_p = (~ii).sum()
+     if bad_p > 0:
+         msg = 'WARNING: {N} SNPs had P outside of (0,1]. The P column may be mislabeled.'
+         logger.warning(msg.format(N=bad_p))
+
+     return ii
+
+
+ def filter_alleles(a):
+     '''Remove alleles that do not describe strand-unambiguous SNPs'''
+     return a.isin(VALID_SNPS)
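+ # A SNP is strand-ambiguous when its two alleles are complementary (A/T or C/G):
+ # such pairs read the same on both strands, so the alleles cannot be matched to a
+ # reference panel unambiguously. VALID_SNPS enumerates the eight unambiguous
+ # ordered allele pairs.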
+
+
+ def gwas_qc(gwas, config):
+     '''
+     Filter out SNPs based on missing values, INFO, MAF, P-values, allele codes, duplicates, and N.
+     '''
+     old = len(gwas)
+     print('\nFiltering SNPs as follows:')
+     # filter: SNPs with missing values
+     drops = {'NA': 0, 'P': 0, 'INFO': 0, 'FRQ': 0, 'A': 0, 'SNP': 0, 'Dup': 0, 'N': 0}
+
+     gwas = gwas.dropna(axis=0, how="any", subset=[
+         col for col in gwas.columns if col != 'INFO']).reset_index(drop=True)
+
+     drops['NA'] = old - len(gwas)
+     print(f'Removed {drops["NA"]} SNPs with missing values.')
+
+     # filter: SNPs with INFO below the threshold
+     if 'INFO' in gwas.columns:
+         old = len(gwas)
+         gwas = gwas.loc[filter_info(gwas['INFO'], config)]
+         drops['INFO'] = old - len(gwas)
+         print(f'Removed {drops["INFO"]} SNPs with INFO < {config.info_min}.')
+
+     # filter: SNPs with MAF at or below the threshold
+     if 'FRQ' in gwas.columns:
+         old = len(gwas)
+         gwas = gwas.loc[filter_frq(gwas['FRQ'], config)]
+         drops['FRQ'] += old - len(gwas)
+         print(f'Removed {drops["FRQ"]} SNPs with MAF <= {config.maf_min}.')
+
+     # filter: P-values outside (0,1]
+     if 'P' in gwas.columns:
+         old = len(gwas)
+         gwas = gwas.loc[filter_pvals(gwas['P'], config)]
+         drops['P'] += old - len(gwas)
+         print(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')
+
+     # filter: variants that are not SNPs or are strand-ambiguous
+     if 'A1' in gwas.columns and 'A2' in gwas.columns:
+         old = len(gwas)
+         gwas.A1 = gwas.A1.str.upper()
+         gwas.A2 = gwas.A2.str.upper()
+         gwas = gwas.loc[filter_alleles(gwas.A1 + gwas.A2)]
+         drops['A'] += old - len(gwas)
+         print(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')
+
+     # filter: duplicated rs numbers
+     if 'SNP' in gwas.columns:
+         old = len(gwas)
+         gwas = gwas.drop_duplicates(subset='SNP').reset_index(drop=True)
+         drops['Dup'] += old - len(gwas)
+         print(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')
+
+     # filter: sample size
+     n_min = gwas.N.quantile(0.9) / 1.5
+     old = len(gwas)
+     gwas = gwas[gwas.N >= n_min].reset_index(drop=True)
+     drops['N'] += old - len(gwas)
+     print(f'Removed {drops["N"]} SNPs with N < {n_min}.')
+
+     return gwas
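+ # The N filter follows the LDSC munge_sumstats convention: SNPs whose N falls below
+ # the 90th percentile of N divided by 1.5 are dropped, removing variants genotyped
+ # in only a small subset of the cohort.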
+
+
+ def variant_to_rsid(gwas, config):
+     '''
+     Convert variant IDs (Chr, Pos) to rsids
+     '''
+     print("\nConverting the SNP position to rsid. This process may take some time.")
+     unique_ids = set(gwas['id'])
+     # Detect the chromosome naming prefix (e.g., 'chr1' vs. '1') from the data
+     chr_format = gwas['Chr'].unique().astype(str)
+     chr_format = [re.sub(r'\d+', '', value) for value in chr_format][0]
+
+     dtype = {'chr': str, 'pos': str, 'ref': str, 'alt': str, 'dbsnp': str}
+     chunk_iter = pd.read_csv(config.dbsnp, chunksize=config.chunksize, sep="\t", skiprows=1,
+                              dtype=dtype, names=['chr', 'pos', 'ref', 'alt', 'dbsnp'])
+
+     # Iterate over chunks of the dbSNP reference
+     matching_id = pd.DataFrame()
+     for chunk in chunk_iter:
+         chunk['id'] = chr_format + chunk["chr"] + "_" + chunk["pos"]
+         matching_id = pd.concat([matching_id, chunk[chunk['id'].isin(unique_ids)][['dbsnp', 'id']]])
+
+     matching_id = matching_id.drop_duplicates(subset='dbsnp').reset_index(drop=True)
+     matching_id = matching_id.drop_duplicates(subset='id').reset_index(drop=True)
+     matching_id.index = matching_id.id
+     return matching_id
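+ # Reading the dbSNP reference in chunks keeps memory bounded: from each chunk, only
+ # the rows whose Chr_Pos id occurs in the GWAS data are retained.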
+
+
+ def clean_SNP_id(gwas, config):
+     '''
+     Clean the SNP IDs
+     '''
+     old = len(gwas)
+     condition1 = 'SNP' in gwas.columns
+     condition2 = np.all(np.isin(['Chr', 'Pos'], gwas.columns))
+
+     if not (condition1 or condition2):
+         raise ValueError('Either SNP rsid, or both SNP chromosome and position, are required.')
+     elif condition1:
+         pass
+     elif condition2:
+         if config.dbsnp is None:
+             raise ValueError('To convert SNP positions to rsid, a dbSNP reference is required.')
+         else:
+             gwas['id'] = gwas["Chr"].astype(str) + "_" + gwas["Pos"].astype(str)
+             gwas = gwas.drop_duplicates(subset='id').reset_index(drop=True)
+             gwas.index = gwas.id
+
+             matching_id = variant_to_rsid(gwas, config)
+             gwas = gwas.loc[matching_id.id]
+             gwas['SNP'] = matching_id.dbsnp
+             num_fail = old - len(gwas)
+             print(f'Removed {num_fail} SNPs that could not be converted to rsid.')
+
+     return gwas
+
+
+ def gwas_metadata(gwas, config):
+     '''
+     Report key features of the GWAS data
+     '''
+     print('\nMetadata:')
+     CHISQ = (gwas.Z ** 2)
+     mean_chisq = CHISQ.mean()
+     print('Mean chi^2 = ' + str(round(mean_chisq, 3)))
+     if mean_chisq < 1.02:
+         logger.warning("Mean chi^2 may be too small.")
+
+     # Lambda GC: median chi^2 divided by 0.4549, the median of a chi^2(1) distribution
+     print('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
+     print('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
+     # chi^2 > 29 corresponds approximately to P < 5e-8 (genome-wide significance)
+     print('{N} genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))
+
+
+
+ def gwas_format(config: FormatSumstatsConfig):
+     '''
+     Format the GWAS data
+     '''
+     print(f'------Formatting GWAS data for {config.sumstats}...')
+     gwas = pd.read_csv(config.sumstats, delim_whitespace=True,
+                        header=0, compression=get_compression(config.sumstats), na_values=['.', 'NA'])
+     print(f'Read {len(gwas)} SNPs from {config.sumstats}.')
+
+     # Check column names and format
+     gwas = gwas_checkname(gwas, config)
+     gwas = gwas_checkformat(gwas, config)
+     # Clean the SNP ids
+     gwas = clean_SNP_id(gwas, config)
+     # QC
+     gwas = gwas_qc(gwas, config)
+     # Metadata
+     gwas_metadata(gwas, config)
+
+     # Save the data
+     if config.format == 'COJO':
+         keep = ['SNP', 'A1', 'A2', 'FRQ', 'BETA', 'SE', 'P', 'N']
+         appendix = '.cojo'
+     elif config.format == 'gsMap':
+         keep = ["A1", "A2", "Z", "N", "SNP"]
+         appendix = '.sumstats'
+
+     if 'Chr' in gwas.columns and 'Pos' in gwas.columns and config.keep_chr_pos is True:
+         keep = keep + ['Chr', 'Pos']
+
+     gwas = gwas[keep]
+     out_name = config.out + appendix + '.gz'
+
+     print(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
+     gwas.to_csv(out_name, sep="\t", index=False,
+                 float_format='%.3f', compression='gzip')
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description="Format GWAS summary statistics")
+     parser = add_format_sumstats_args(parser)
+     args = parser.parse_args()
+     config = FormatSumstatsConfig(**vars(args))
+     gwas_format(config)