gsMap 1.65__py3-none-any.whl → 1.67__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN_VAE/adjacency_matrix.py +48 -68
- gsMap/GNN_VAE/model.py +68 -66
- gsMap/GNN_VAE/train.py +50 -61
- gsMap/__init__.py +1 -1
- gsMap/config.py +4 -4
- gsMap/find_latent_representation.py +103 -103
- gsMap/format_sumstats.py +20 -20
- gsMap/latent_to_gene.py +125 -109
- gsMap/spatial_ldsc_multiple_sumstats.py +0 -2
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/METADATA +2 -2
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/RECORD +14 -14
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/LICENSE +0 -0
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/WHEEL +0 -0
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/entry_points.txt +0 -0
@@ -1,145 +1,145 @@
|
|
1
1
|
import logging
|
2
2
|
import random
|
3
|
-
|
4
3
|
import numpy as np
|
5
|
-
import pandas as pd
|
6
4
|
import scanpy as sc
|
7
5
|
import torch
|
8
|
-
from sklearn import
|
9
|
-
|
10
|
-
from gsMap.GNN_VAE.adjacency_matrix import
|
11
|
-
from gsMap.GNN_VAE.train import
|
6
|
+
from sklearn.decomposition import PCA
|
7
|
+
from sklearn.preprocessing import LabelEncoder
|
8
|
+
from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix
|
9
|
+
from gsMap.GNN_VAE.train import ModelTrainer
|
12
10
|
from gsMap.config import FindLatentRepresentationsConfig
|
13
11
|
|
14
12
|
logger = logging.getLogger(__name__)
|
15
13
|
|
14
|
+
|
16
15
|
def set_seed(seed_value):
|
17
16
|
"""
|
18
|
-
Set seed for reproducibility in PyTorch.
|
17
|
+
Set seed for reproducibility in PyTorch and other libraries.
|
19
18
|
"""
|
20
|
-
torch.manual_seed(seed_value)
|
21
|
-
np.random.seed(seed_value)
|
22
|
-
random.seed(seed_value)
|
19
|
+
torch.manual_seed(seed_value)
|
20
|
+
np.random.seed(seed_value)
|
21
|
+
random.seed(seed_value)
|
23
22
|
if torch.cuda.is_available():
|
24
|
-
logger.info('
|
25
|
-
torch.cuda.manual_seed(seed_value)
|
26
|
-
torch.cuda.manual_seed_all(seed_value)
|
23
|
+
logger.info('Using GPU for computations.')
|
24
|
+
torch.cuda.manual_seed(seed_value)
|
25
|
+
torch.cuda.manual_seed_all(seed_value)
|
27
26
|
else:
|
28
|
-
logger.info('
|
27
|
+
logger.info('Using CPU for computations.')
|
29
28
|
|
29
|
+
def preprocess_data(adata, params):
|
30
|
+
"""
|
31
|
+
Preprocess the AnnData
|
32
|
+
"""
|
33
|
+
logger.info('Preprocessing data...')
|
34
|
+
adata.var_names_make_unique()
|
35
|
+
|
36
|
+
sc.pp.filter_genes(adata, min_cells=30)
|
37
|
+
if params.data_layer in adata.layers.keys():
|
38
|
+
adata.X = adata.layers[params.data_layer]
|
39
|
+
else:
|
40
|
+
raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
|
30
41
|
|
31
|
-
|
32
|
-
class Latent_Representation_Finder:
|
42
|
+
if params.data_layer in ['count', 'counts']:
|
33
43
|
|
34
|
-
|
35
|
-
|
36
|
-
self.Params = args
|
44
|
+
sc.pp.normalize_total(adata, target_sum=1e4)
|
45
|
+
sc.pp.log1p(adata)
|
37
46
|
|
38
|
-
#
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
sc.pp.scale(self.adata)
|
45
|
-
else:
|
46
|
-
if self.Params.data_layer != 'X':
|
47
|
-
self.adata.X = self.adata.layers[self.Params.data_layer]
|
48
|
-
sc.pp.highly_variable_genes(self.adata, n_top_genes=self.Params.feat_cell)
|
47
|
+
# Identify highly variable genes
|
48
|
+
sc.pp.highly_variable_genes(
|
49
|
+
adata,
|
50
|
+
flavor="seurat_v3",
|
51
|
+
n_top_genes=params.feat_cell,
|
52
|
+
)
|
49
53
|
|
50
|
-
|
54
|
+
elif params.data_layer in adata.layers.keys():
|
55
|
+
logger.info(f'Using {params.data_layer} data...')
|
56
|
+
sc.pp.highly_variable_genes(
|
57
|
+
adata,
|
58
|
+
flavor="seurat",
|
59
|
+
n_top_genes=params.feat_cell,
|
60
|
+
)
|
51
61
|
|
52
|
-
|
53
|
-
graph_dict = Construct_Adjacency_Matrix(self.adata, self.Params)
|
62
|
+
return adata
|
54
63
|
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
64
|
+
|
65
|
+
class LatentRepresentationFinder:
|
66
|
+
def __init__(self, adata, args: FindLatentRepresentationsConfig):
|
67
|
+
self.params = args
|
68
|
+
|
69
|
+
self.expression_array = adata[:, adata.var.highly_variable].X.copy()
|
70
|
+
|
71
|
+
if self.params.data_layer in ['count', 'counts']:
|
72
|
+
self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
|
73
|
+
|
74
|
+
# Construct the neighboring graph
|
75
|
+
self.graph_dict = construct_adjacency_matrix(adata, self.params)
|
76
|
+
|
77
|
+
def compute_pca(self):
|
78
|
+
self.latent_pca = PCA(n_components=self.params.feat_cell).fit_transform(self.expression_array)
|
79
|
+
return self.latent_pca
|
80
|
+
|
81
|
+
def run_gnn_vae(self, label, verbose='whole ST data'):
|
82
|
+
|
83
|
+
# Use PCA if specified
|
84
|
+
if self.params.input_pca:
|
85
|
+
node_X = self.compute_pca()
|
86
|
+
else:
|
87
|
+
node_X = self.expression_array
|
60
88
|
|
61
89
|
# Update the input shape
|
62
|
-
self.
|
63
|
-
self.
|
90
|
+
self.params.n_nodes = node_X.shape[0]
|
91
|
+
self.params.feat_cell = node_X.shape[1]
|
64
92
|
|
65
|
-
# Run GNN
|
66
|
-
logger.info(f'
|
67
|
-
gvae =
|
93
|
+
# Run GNN
|
94
|
+
logger.info(f'Finding latent representations for {verbose}...')
|
95
|
+
gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
|
68
96
|
gvae.run_train()
|
69
97
|
|
70
|
-
|
98
|
+
del self.graph_dict
|
71
99
|
|
72
|
-
|
73
|
-
sc.tl.pca(self.adata)
|
74
|
-
return self.adata.obsm['X_pca'][:, 0:self.Params.n_comps]
|
100
|
+
return gvae.get_latent()
|
75
101
|
|
76
102
|
|
77
|
-
def run_find_latent_representation(args:FindLatentRepresentationsConfig):
|
103
|
+
def run_find_latent_representation(args: FindLatentRepresentationsConfig):
|
78
104
|
set_seed(2024)
|
79
|
-
|
80
|
-
args.hdf5_with_latent_path.parent.mkdir(parents=True, exist_ok=True,mode=0o755)
|
105
|
+
|
81
106
|
# Load the ST data
|
82
|
-
logger.info(f'
|
83
|
-
adata = sc.read_h5ad(
|
84
|
-
adata.
|
85
|
-
|
86
|
-
logger.info('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))
|
107
|
+
logger.info(f'Loading ST data of {args.sample_name}...')
|
108
|
+
adata = sc.read_h5ad(args.input_hdf5_path)
|
109
|
+
logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')
|
110
|
+
|
87
111
|
# Load the cell type annotation
|
88
|
-
if
|
89
|
-
#
|
90
|
-
adata = adata[~
|
112
|
+
if args.annotation is not None:
|
113
|
+
# Remove cells without enough annotations
|
114
|
+
adata = adata[~adata.obs[args.annotation].isnull()]
|
91
115
|
num = adata.obs[args.annotation].value_counts()
|
92
|
-
|
116
|
+
valid_annotations = num[num >= 30].index.to_list()
|
117
|
+
adata = adata[adata.obs[args.annotation].isin(valid_annotations)]
|
93
118
|
|
94
|
-
le =
|
95
|
-
le.
|
96
|
-
adata.obs['categorical_label']
|
97
|
-
label = adata.obs['categorical_label'].to_list()
|
119
|
+
le = LabelEncoder()
|
120
|
+
adata.obs['categorical_label'] = le.fit_transform(adata.obs[args.annotation])
|
121
|
+
label = adata.obs['categorical_label'].to_numpy()
|
98
122
|
else:
|
99
123
|
label = None
|
100
|
-
# Find latent representations
|
101
|
-
latent_rep = Latent_Representation_Finder(adata, args)
|
102
|
-
latent_GVAE = latent_rep.Run_GNN_VAE(label)
|
103
|
-
latent_PCA = latent_rep.Run_PCA()
|
104
|
-
# Add latent representations to the spe data
|
105
|
-
logger.info(f'------Adding latent representations...')
|
106
|
-
adata.obsm["latent_GVAE"] = latent_GVAE
|
107
|
-
adata.obsm["latent_PCA"] = latent_PCA
|
108
|
-
# Run umap based on latent representations
|
109
|
-
for name in ['latent_GVAE', 'latent_PCA']:
|
110
|
-
sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
|
111
|
-
sc.tl.umap(adata)
|
112
|
-
adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
|
113
124
|
|
114
|
-
|
115
|
-
|
116
|
-
logger.info(f'------Finding latent representations hierarchically...')
|
117
|
-
PCA_all = pd.DataFrame()
|
118
|
-
GVAE_all = pd.DataFrame()
|
125
|
+
# Preprocess data
|
126
|
+
adata = preprocess_data(adata, args)
|
119
127
|
|
120
|
-
|
121
|
-
|
122
|
-
|
128
|
+
latent_rep = LatentRepresentationFinder(adata, args)
|
129
|
+
latent_gvae = latent_rep.run_gnn_vae(label)
|
130
|
+
latent_pca = latent_rep.compute_pca()
|
123
131
|
|
124
|
-
|
125
|
-
|
132
|
+
# Add latent representations to the AnnData object
|
133
|
+
logger.info('Adding latent representations...')
|
134
|
+
adata.obsm["latent_GVAE"] = latent_gvae
|
135
|
+
adata.obsm["latent_PCA"] = latent_pca
|
126
136
|
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
latent_GVAE_part.index = adata_part.obs_names
|
134
|
-
latent_PCA_part.index = adata_part.obs_names
|
135
|
-
|
136
|
-
GVAE_all = pd.concat((GVAE_all, latent_GVAE_part), axis=0)
|
137
|
-
PCA_all = pd.concat((PCA_all, latent_PCA_part), axis=0)
|
138
|
-
|
139
|
-
args.feat_cell = num_features
|
137
|
+
# Run UMAP based on latent representations
|
138
|
+
for name in ['latent_GVAE', 'latent_PCA']:
|
139
|
+
sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
|
140
|
+
sc.tl.umap(adata)
|
141
|
+
adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
|
140
142
|
|
141
|
-
|
142
|
-
|
143
|
-
logger.info(f'------Saving ST data...')
|
143
|
+
# Save the AnnData object
|
144
|
+
logger.info('Saving ST data...')
|
144
145
|
adata.write(args.hdf5_with_latent_path)
|
145
|
-
|
gsMap/format_sumstats.py
CHANGED
@@ -150,10 +150,10 @@ def gwas_checkname(gwas, config):
|
|
150
150
|
'Pos': 'SNP positions.'
|
151
151
|
}
|
152
152
|
|
153
|
-
|
153
|
+
logger.info(f'\nIterpreting column names as follows:')
|
154
154
|
for key, value in interpreting.items():
|
155
155
|
if key in new_name:
|
156
|
-
|
156
|
+
logger.info(f'{name_dict[key]}: {interpreting[key]}')
|
157
157
|
|
158
158
|
return gwas
|
159
159
|
|
@@ -242,7 +242,7 @@ def gwas_qc(gwas, config):
|
|
242
242
|
Filter out SNPs based on INFO, FRQ, MAF, N, and Genotypes.
|
243
243
|
'''
|
244
244
|
old = len(gwas)
|
245
|
-
|
245
|
+
logger.info(f'\nFiltering SNPs as follows:')
|
246
246
|
# filter: SNPs with missing values
|
247
247
|
drops = {'NA': 0, 'P': 0, 'INFO': 0, 'FRQ': 0, 'A': 0, 'SNP': 0, 'Dup': 0, 'N': 0}
|
248
248
|
|
@@ -250,28 +250,28 @@ def gwas_qc(gwas, config):
|
|
250
250
|
lambda x: x != 'INFO', gwas.columns)).reset_index(drop=True)
|
251
251
|
|
252
252
|
drops['NA'] = old - len(gwas)
|
253
|
-
|
253
|
+
logger.info(f'Removed {drops["NA"]} SNPs with missing values.')
|
254
254
|
|
255
255
|
# filter: SNPs with Info < 0.9
|
256
256
|
if 'INFO' in gwas.columns:
|
257
257
|
old = len(gwas)
|
258
258
|
gwas = gwas.loc[filter_info(gwas['INFO'], config)]
|
259
259
|
drops['INFO'] = old - len(gwas)
|
260
|
-
|
260
|
+
logger.info(f'Removed {drops["INFO"]} SNPs with INFO <= 0.9.')
|
261
261
|
|
262
262
|
# filter: SNPs with MAF <= 0.01
|
263
263
|
if 'FRQ' in gwas.columns:
|
264
264
|
old = len(gwas)
|
265
265
|
gwas = gwas.loc[filter_frq(gwas['FRQ'], config)]
|
266
266
|
drops['FRQ'] += old - len(gwas)
|
267
|
-
|
267
|
+
logger.info(f'Removed {drops["FRQ"]} SNPs with MAF <= 0.01.')
|
268
268
|
|
269
269
|
# filter: P-value that out-of-bounds [0,1]
|
270
270
|
if 'P' in gwas.columns:
|
271
271
|
old = len(gwas)
|
272
272
|
gwas = gwas.loc[filter_pvals(gwas['P'], config)]
|
273
273
|
drops['P'] += old - len(gwas)
|
274
|
-
|
274
|
+
logger.info(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')
|
275
275
|
|
276
276
|
# filter: Variants that are strand-ambiguous
|
277
277
|
if 'A1' in gwas.columns and 'A2' in gwas.columns:
|
@@ -279,21 +279,21 @@ def gwas_qc(gwas, config):
|
|
279
279
|
gwas.A2 = gwas.A2.str.upper()
|
280
280
|
gwas = gwas.loc[filter_alleles(gwas.A1 + gwas.A2)]
|
281
281
|
drops['A'] += old - len(gwas)
|
282
|
-
|
282
|
+
logger.info(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')
|
283
283
|
|
284
284
|
# filter: Duplicated rs numbers
|
285
285
|
if 'SNP' in gwas.columns:
|
286
286
|
old = len(gwas)
|
287
287
|
gwas = gwas.drop_duplicates(subset='SNP').reset_index(drop=True)
|
288
288
|
drops['Dup'] += old - len(gwas)
|
289
|
-
|
289
|
+
logger.info(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')
|
290
290
|
|
291
291
|
# filter:Sample size
|
292
292
|
n_min = gwas.N.quantile(0.9) / 1.5
|
293
293
|
old = len(gwas)
|
294
294
|
gwas = gwas[gwas.N >= n_min].reset_index(drop=True)
|
295
295
|
drops['N'] += old - len(gwas)
|
296
|
-
|
296
|
+
logger.info(f'Removed {drops["N"]} SNPs with N < {n_min}.')
|
297
297
|
|
298
298
|
return gwas
|
299
299
|
|
@@ -302,7 +302,7 @@ def variant_to_rsid(gwas, config):
|
|
302
302
|
'''
|
303
303
|
Convert variant id (Chr, Pos) to rsid
|
304
304
|
'''
|
305
|
-
|
305
|
+
logger.info("\nConverting the SNP position to rsid. This process may take some time.")
|
306
306
|
unique_ids = set(gwas['id'])
|
307
307
|
chr_format = gwas['Chr'].unique().astype(str)
|
308
308
|
chr_format = [re.sub(r'\d+', '', value) for value in chr_format][1]
|
@@ -347,7 +347,7 @@ def clean_SNP_id(gwas, config):
|
|
347
347
|
gwas = gwas.loc[matching_id.id]
|
348
348
|
gwas['SNP'] = matching_id.dbsnp
|
349
349
|
num_fail = old - len(gwas)
|
350
|
-
|
350
|
+
logger.info(f'Removed {num_fail} SNPs that did not convert to rsid.')
|
351
351
|
|
352
352
|
return gwas
|
353
353
|
|
@@ -356,27 +356,27 @@ def gwas_metadata(gwas, config):
|
|
356
356
|
'''
|
357
357
|
Report key features of GWAS data
|
358
358
|
'''
|
359
|
-
|
359
|
+
logger.info('\nSummary of GWAS data:')
|
360
360
|
CHISQ = (gwas.Z ** 2)
|
361
361
|
mean_chisq = CHISQ.mean()
|
362
|
-
|
362
|
+
logger.info('Mean chi^2 = ' + str(round(mean_chisq, 3)))
|
363
363
|
if mean_chisq < 1.02:
|
364
364
|
logger.warning("Mean chi^2 may be too small.")
|
365
365
|
|
366
|
-
|
367
|
-
|
368
|
-
|
366
|
+
logger.info('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
|
367
|
+
logger.info('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
|
368
|
+
logger.info('{N} Genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))
|
369
369
|
|
370
370
|
|
371
371
|
def gwas_format(config: FormatSumstatsConfig):
|
372
372
|
'''
|
373
373
|
Format GWAS data
|
374
374
|
'''
|
375
|
-
|
375
|
+
logger.info(f'------Formating gwas data for {config.sumstats}...')
|
376
376
|
compression_type = get_compression(config.sumstats)
|
377
377
|
gwas = pd.read_csv(config.sumstats, delim_whitespace=True, header=0, compression=compression_type,
|
378
378
|
na_values=['.', 'NA'])
|
379
|
-
|
379
|
+
logger.info(f'Read {len(gwas)} SNPs from {config.sumstats}.')
|
380
380
|
|
381
381
|
# Check name and format
|
382
382
|
gwas = gwas_checkname(gwas, config)
|
@@ -402,6 +402,6 @@ def gwas_format(config: FormatSumstatsConfig):
|
|
402
402
|
gwas = gwas[keep]
|
403
403
|
out_name = config.out + appendix + '.gz'
|
404
404
|
|
405
|
-
|
405
|
+
logger.info(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
|
406
406
|
gwas.to_csv(out_name, sep="\t", index=False,
|
407
407
|
float_format='%.3f', compression='gzip')
|