gsMap 1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Jul 4 21:31:27 2023
5
+
6
+ @author: songliyang
7
+ """
8
+ import numpy as np
9
+ import pandas as pd
10
+ import scipy.sparse as sp
11
+ import sklearn.neighbors
12
+ import torch
13
+
14
+
15
def Cal_Spatial_Net(adata, n_neighbors=5, verbose=True):
    """\
    Construct the spatial neighbor network.

    Finds the ``n_neighbors`` nearest cells of every cell using the coordinates in
    ``adata.obsm['spatial']``. The fitted points themselves are queried, so each
    cell's own position is among its neighbors. The function returns one row per
    (cell, neighbor) pair whose distance is strictly positive.

    Parameters
    ----------
    adata : AnnData-like
        Must provide ``obsm['spatial']`` (coordinates) and ``obs.index`` (cell names).
    n_neighbors : int
        Number of neighbors requested per cell, including the query cell itself.
    verbose : bool
        Print a progress message.

    Returns
    -------
    pandas.DataFrame
        Columns ``Cell1`` and ``Cell2`` (cell names taken from ``adata.obs.index``)
        and ``Distance``.
    """
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index

    nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
    distances, indices = nbrs.kneighbors(coor, return_distance=True)

    # Flatten the (n_cells, k) neighbor arrays into an edge list in one vectorised step,
    # instead of building and concatenating one small DataFrame per cell.
    n_cells, k = indices.shape
    KNN_df = pd.DataFrame(
        {
            'Cell1': np.repeat(np.arange(n_cells), k),
            'Cell2': indices.ravel(),
            'Distance': distances.ravel(),
        },
        # Same row labels a per-cell concat would produce: 0..k-1 repeated for every cell.
        index=np.tile(np.arange(k), n_cells),
    )

    # Keep only pairs at positive distance. This drops each cell's self-match and also
    # any other cell lying at exactly the same coordinates.
    Spatial_Net = KNN_df.loc[KNN_df['Distance'] > 0].copy()

    # Translate positional indices back to cell names.
    cell_names = np.asarray(coor.index)
    Spatial_Net['Cell1'] = cell_names[Spatial_Net['Cell1'].to_numpy()]
    Spatial_Net['Cell2'] = cell_names[Spatial_Net['Cell2'].to_numpy()]
    return Spatial_Net
42
+
43
+
44
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

    Parameters
    ----------
    sparse_mx : scipy.sparse matrix (any format)
        Converted to COO and cast to float32 before the conversion.

    Returns
    -------
    torch.Tensor
        A (not necessarily coalesced) sparse COO tensor with the same shape and
        int64 indices.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported API and yields the same float32 tensor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
51
+
52
+
53
def preprocess_graph(adj):
    """Add self-loops to an adjacency matrix and degree-normalise it.

    With ``A' = A + I`` and ``D`` the diagonal matrix of row sums of ``A'``, the
    result is ``(A' D^{-1/2})^T D^{-1/2}``. For a symmetric ``A`` this equals the
    usual ``D^{-1/2} A' D^{-1/2}``.

    Returns the normalised matrix as a torch sparse tensor, produced by
    ``sparse_mx_to_torch_sparse_tensor``.
    """
    adj_coo = sp.coo_matrix(adj)
    with_self_loops = adj_coo + sp.eye(adj_coo.shape[0])

    # Row degrees of A + I (>= 1 for non-negative weights), as a flat vector.
    row_degree = np.asarray(with_self_loops.sum(axis=1)).ravel()
    inv_sqrt_deg = sp.diags(np.power(row_degree, -0.5))

    normalized = ((with_self_loops @ inv_sqrt_deg).T @ inv_sqrt_deg).tocoo()
    return sparse_mx_to_torch_sparse_tensor(normalized)
60
+
61
+
62
+
63
def Construct_Adjacency_Matrix(adata, Params, verbose=True):
    """Build the spatial KNN adjacency matrix and its normalised torch form.

    Parameters
    ----------
    adata : AnnData-like
        Passed to ``Cal_Spatial_Net``; ``obs.index`` and ``n_obs`` are also read here.
    Params : object
        Must provide ``n_neighbors`` (int) and ``weighted_adj`` (truthy/falsy).
    verbose : bool
        Print graph statistics. Also forwarded to ``Cal_Spatial_Net``.

    Returns
    -------
    dict
        ``"adj_org"``: scipy COO matrix of shape (n_obs, n_obs) with edge weights.
        ``"adj_norm"``: output of ``preprocess_graph(adj_org)``.
        ``"norm_value"``: ``n^2 / (2 * (n^2 - sum(adj_org)))``.
    """
    # Construct the neighbor graph; forward ``verbose`` so a quiet call stays quiet.
    Spatial_Net = Cal_Spatial_Net(adata, n_neighbors=Params.n_neighbors, verbose=verbose)
    n_obs = adata.n_obs

    if verbose:
        print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], n_obs))
        print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / n_obs))

    # Map cell names back to row/column positions.
    cells = np.array(adata.obs.index)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
    G_df = Spatial_Net.copy()
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)

    if Params.weighted_adj:
        # Scale distances into (0, 1), then apply a Gaussian kernel exp(-d^2 / 2).
        distance_normalized = G_df['Distance'] / (G_df['Distance'].max() + 1)
        edge_weights = np.exp(-distance_normalized ** 2 / 2)
    else:
        edge_weights = np.ones(G_df.shape[0])
    adj_org = sp.coo_matrix(
        (edge_weights, (G_df['Cell1'], G_df['Cell2'])), shape=(n_obs, n_obs)
    )

    adj_norm = preprocess_graph(adj_org)
    n = adj_org.shape[0]
    norm_value = n * n / float((n * n - adj_org.sum()) * 2)

    return {
        "adj_org": adj_org,
        "adj_norm": adj_norm,
        "norm_value": norm_value,
    }
gsMap/GNN_VAE/model.py ADDED
@@ -0,0 +1,87 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Mon Jul 3 11:42:44 2023
5
+
6
+ @author: songliyang
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch_geometric.nn import GATConv
13
+
14
+
15
def full_block(in_features, out_features, p_drop):
    """Return a Linear -> BatchNorm1d -> ELU -> Dropout stack.

    Parameters
    ----------
    in_features, out_features : int
        Input and output width of the linear layer (and BatchNorm width).
    p_drop : float
        Dropout probability.
    """
    layers = [
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.ELU(),
        nn.Dropout(p=p_drop),
    ]
    return nn.Sequential(*layers)
20
+
21
+
22
class GNN(nn.Module):
    """One graph-attention layer followed by an activation and dropout.

    Parameters
    ----------
    in_features, out_features : int
        Passed positionally to ``GATConv`` together with ``heads``.
    dr : float
        Dropout probability applied after the activation (only while training).
    act : callable
        Activation applied to the GATConv output (default ``F.relu``).
    heads : int
        Number of attention heads given to ``GATConv``.
    """

    def __init__(self, in_features, out_features, dr=0, act=F.relu, heads=1):
        super().__init__()
        self.conv1 = GATConv(in_features, out_features, heads)
        self.act = act
        self.dr = dr

    def forward(self, x, edge_index):
        """Apply attention over ``edge_index``, then the activation and dropout."""
        hidden = self.conv1(x, edge_index)
        return F.dropout(self.act(hidden), p=self.dr, training=self.training)
34
+
35
+
36
+
37
class GNN_VAE_Model(nn.Module):
    """Graph autoencoder: MLP encoder -> GAT layers -> MLP decoder, plus a classification head.

    ``params`` must provide ``var``, ``feat_hidden1``, ``feat_hidden2``, ``gcn_hidden1``,
    ``gcn_hidden2``, ``nheads`` and ``p_drop``.

    If ``params.var`` is truthy, the encoder also produces ``logvar``, and in training
    mode the latent code is sampled around ``mu``. Otherwise the latent code is ``mu``.
    """
    def __init__(self, input_dim,params,num_classes=1):
        super(GNN_VAE_Model, self).__init__()
        self.var = params.var
        self.num_classes = num_classes

        # NOTE: sub-modules are created in the order below, which fixes the order of their
        # random initialisation and the state_dict key layout — keep it stable.

        # Encoder: input_dim -> feat_hidden1 -> feat_hidden2 (Linear/BatchNorm/ELU/Dropout blocks)
        self.encoder = nn.Sequential()
        self.encoder.add_module('encoder_L1', full_block(input_dim, params.feat_hidden1, params.p_drop))
        self.encoder.add_module('encoder_L2', full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop))

        # GNN (GAT)
        # gn1 uses params.nheads attention heads. Presumably GATConv concatenates the heads,
        # which is why gn2/gn3 take gcn_hidden1*nheads inputs (confirm against PyG).
        # gn2 (mu) and gn3 (logvar) use GNN's default single head and an identity activation.
        # gn3 is built, and its weights initialised, even when params.var is false.
        self.gn1 = GNN(params.feat_hidden2, params.gcn_hidden1, params.p_drop, act=F.relu,heads = params.nheads)
        self.gn2 = GNN(params.gcn_hidden1*params.nheads, params.gcn_hidden2, params.p_drop, act=lambda x: x)
        self.gn3 = GNN(params.gcn_hidden1*params.nheads, params.gcn_hidden2, params.p_drop, act=lambda x: x)

        # Decoder: gcn_hidden2 -> feat_hidden2 -> feat_hidden1 -> input_dim (plain linear output)
        self.decoder = nn.Sequential()
        self.decoder.add_module('decoder_L1', full_block(params.gcn_hidden2, params.feat_hidden2, params.p_drop))
        self.decoder.add_module('decoder_L2', full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop))
        self.decoder.add_module('decoder_output', nn.Sequential(nn.Linear(params.feat_hidden1, input_dim)))

        # Cluster: classification head gcn_hidden2 -> feat_hidden2 -> num_classes scores
        self.cluster = nn.Sequential()
        self.cluster.add_module('cluster_L1', full_block(params.gcn_hidden2, params.feat_hidden2, params.p_drop))
        self.cluster.add_module('cluster_output', nn.Linear(params.feat_hidden2, self.num_classes))

    def encode(self, x, adj):
        """Encode node features; return ``(mu, logvar)``, where ``logvar`` is None if ``self.var`` is false.

        ``adj`` is passed to the GAT layers as their ``edge_index`` argument
        (presumably the normalised sparse adjacency built in training — confirm).
        """
        feat_x = self.encoder(x)
        hidden1 = self.gn1(feat_x, adj)
        mu = self.gn2(hidden1, adj)
        if self.var:
            logvar = self.gn3(hidden1, adj)
            return mu, logvar
        else:
            return mu, None

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar)`` while training with a logvar; otherwise return ``mu``."""
        if self.training and logvar is not None:
            # NOTE(review): if ``logvar`` is a log-variance, the standard deviation would be
            # exp(0.5 * logvar); here exp(logvar) is used, so it acts as a log-std. Confirm intent.
            std = torch.exp(logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, adj):
        """Run the model; return ``(pred_label, x_reconstructed, gnn_z, mu, logvar)``.

        ``pred_label`` is the softmax (over dim 1) of the classification head. With
        ``num_classes == 1`` it is identically 1.
        """
        mu, logvar = self.encode(x, adj)
        gnn_z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decoder(gnn_z)
        pred_label = F.softmax(self.cluster(gnn_z),dim=1)
        return pred_label, x_reconstructed, gnn_z, mu, logvar
gsMap/GNN_VAE/train.py ADDED
@@ -0,0 +1,97 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Tue Jul 4 19:58:58 2023
5
+
6
+ @author: songliyang
7
+ """
8
+ import time
9
+
10
+ import torch
11
+ from progress.bar import Bar
12
+
13
+ from gsMap.GNN_VAE.model import GNN_VAE_Model
14
+
15
+
16
def reconstruction_loss(decoded, x):
    """Mean squared error between the decoder output ``decoded`` and the input ``x``."""
    return torch.nn.functional.mse_loss(decoded, x)
20
+
21
+
22
def label_loss(pred_label, true_label):
    """Mean cross-entropy between class scores ``pred_label`` and integer targets ``true_label``.

    ``cross_entropy`` applies log-softmax internally, so it expects unnormalised logits.
    NOTE(review): Model_Train passes GNN_VAE_Model's already-softmaxed output here,
    which applies softmax twice — confirm this is intended.
    """
    return torch.nn.functional.cross_entropy(pred_label, true_label)
26
+
27
+
28
class Model_Train:
    """Fit a GNN_VAE_Model to node features over a fixed normalised adjacency.

    Parameters
    ----------
    node_X : array-like with ``.copy()``
        Node feature matrix, converted to a float32 tensor (presumably a NumPy array of
        shape (n_nodes, params.feat_cell) — TODO confirm).
    graph_dict : dict
        Its ``"adj_norm"`` entry is moved to the device and coalesced, so it must be a
        torch sparse tensor.
    params : object
        Provides ``epochs``, ``feat_cell``, ``gcn_lr``, ``gcn_decay`` and
        ``convergence_threshold``. When labels are given it must also provide ``rec_w``
        and ``label_w``. It must also provide whatever GNN_VAE_Model reads.
    label : optional
        Per-node class labels. When given, the classification head gets
        ``len(unique(label))`` outputs and its loss is added to the reconstruction loss.
    """

    # The convergence test only applies from this epoch index onwards.
    _MIN_EPOCHS_BEFORE_STOP = 200

    def __init__(self, node_X, graph_dict, params, label=None):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device == 'cuda':
            # Release unoccupied cached GPU memory before allocating the model.
            torch.cuda.empty_cache()

        self.params = params
        self.device = device
        self.epochs = params.epochs
        self.node_X = torch.FloatTensor(node_X.copy()).to(device)
        self.adj_norm = graph_dict["adj_norm"].to(device).coalesce()
        self.label = label
        self.num_classes = 1

        if self.label is not None:
            self.label = torch.tensor(self.label).to(self.device)
            self.num_classes = len(self.label.unique())

        # Set Model
        self.model = GNN_VAE_Model(self.params.feat_cell, self.params, self.num_classes).to(device)
        self.optimizer = torch.optim.Adam(
            params=list(self.model.parameters()),
            lr=self.params.gcn_lr,
            weight_decay=self.params.gcn_decay,
        )

    def run_train(self):
        """Optimise the model for up to ``self.epochs`` full-batch epochs.

        Training stops early once the absolute change in loss between consecutive
        epochs is at most ``params.convergence_threshold``. This check only applies
        from epoch index ``_MIN_EPOCHS_BEFORE_STOP`` onwards.
        """
        prev_loss = float('inf')

        bar = Bar('GAT-AE model train:', max=self.epochs)
        bar.check_tty = False
        for epoch in range(self.epochs):
            start_time = time.time()
            self.model.train()
            self.optimizer.zero_grad()
            pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
            loss_rec = reconstruction_loss(de_feat, self.node_X)

            # Check whether annotation was provided
            if self.label is not None:
                # NOTE(review): pred_label is already a softmax output, while
                # label_loss (CrossEntropyLoss) expects logits — confirm intent.
                loss_pre = label_loss(pred_label, self.label)
                loss = (self.params.rec_w * loss_rec) + (self.params.label_w * loss_pre)
            else:
                loss = loss_rec

            loss.backward()
            self.optimizer.step()

            # Read the loss once per epoch (one device sync instead of several).
            loss_value = loss.item()

            # Update progress: estimate time left from this epoch's duration and the
            # number of epochs still to run after this one.
            batch_time = time.time() - start_time
            minutes_left = batch_time * (self.epochs - epoch - 1) / 60
            bar_str = '{} / {} | Left time: {batch_time:.2f} mins| Loss: {loss:.4f}'
            bar.suffix = bar_str.format(epoch + 1, self.epochs,
                                        batch_time=minutes_left, loss=loss_value)
            bar.next()

            # Check convergence
            if (abs(loss_value - prev_loss) <= self.params.convergence_threshold
                    and epoch >= self._MIN_EPOCHS_BEFORE_STOP):
                print('\nConvergence reached. Training stopped.')
                break

            prev_loss = loss_value

        bar.finish()

    def get_latent(self):
        """Return the latent embedding of every node as a NumPy array.

        The model is put in eval mode, so GNN_VAE_Model.reparameterize returns ``mu``
        and the result is deterministic.
        """
        self.model.eval()
        # Inference only: skip building the autograd graph to save memory.
        with torch.no_grad():
            _, _, latent_z, _, _ = self.model(self.node_X, self.adj_norm)
        return latent_z.cpu().numpy()
gsMap/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ '''
2
+ Genetics-informed pathogenic spatial mapping
3
+ '''
4
+
5
# Package version string; matches the released wheel (gsMap 1.60).
__version__ = '1.60'
gsMap/__main__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .main import main
2
# Let ``python -m gsMap`` dispatch to the command-line entry point imported above.
if __name__ == '__main__':
    main()
@@ -0,0 +1,163 @@
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scanpy as sc
7
+ import scipy as sp
8
+
9
+ from gsMap.config import CauchyCombinationConfig, add_Cauchy_combination_args
10
+
11
# Cauchy combination (ACAT) of p-values.
def acat_test(pvalues, weights=None):
    '''acat_test()
    Aggregated Cauchy Association Test
    A p-value combination method using the Cauchy distribution.

    Inspired by: https://github.com/yaowuliu/ACAT/blob/master/R/ACAT.R
    Inputs:
        pvalues: <list or numpy array>
            The p-values you want to combine. Must be non-empty, contain no NaN,
            and lie in [0, 1].
        weights: <list or numpy array>, default=None
            Non-negative weights, one per p-value. If None, equal weights are used.
            Weights are rescaled to sum to 1, as in the reference R implementation.

    Returns:
        pval: <float>
            The ACAT combined p-value. If any p-value is exactly 0, returns 0
            (with a warning); if any p-value is exactly 1, returns 1.

    Raises:
        Exception: on empty input, NaN or out-of-range p-values, both 0 and 1 present,
            or weights of the wrong length, negative, or all zero.
    '''
    from scipy.stats import cauchy

    pvalues = np.asarray(pvalues, dtype=float).ravel()
    if pvalues.size == 0:
        raise Exception("Cannot combine an empty set of p-values.")
    if np.isnan(pvalues).any():
        raise Exception("Cannot have NAs in the p-values.")
    if ((pvalues > 1) | (pvalues < 0)).any():
        raise Exception("P-values must be between 0 and 1.")
    has_zero = bool((pvalues == 0).any())
    has_one = bool((pvalues == 1).any())
    if has_zero and has_one:
        raise Exception("Cannot have both 0 and 1 p-values.")
    if has_zero:
        print("Warn: p-values are exactly 0.")
        return 0
    if has_one:
        print("Warn: p-values are exactly 1.")
        return 1

    if weights is None:
        weights = np.full(pvalues.shape, 1.0 / pvalues.size)
    else:
        weights = np.asarray(weights, dtype=float).ravel()
        if weights.size != pvalues.size:
            raise Exception("Length of weights and p-values differs.")
        if (weights < 0).any():
            raise Exception("All weights must be positive.")
        total = weights.sum()
        if total <= 0:
            raise Exception("Weights must not all be zero.")
        # Normalise by the total weight (not the count), as the R ACAT reference does.
        weights = weights / total

    is_small = pvalues < 1e-16
    # For tiny p, tan((0.5 - p) * pi) = cot(p * pi) ~ 1 / (p * pi); use that form to keep precision.
    cct_stat = np.sum(weights[is_small] / pvalues[is_small] / np.pi)
    cct_stat += np.sum(weights[~is_small] * np.tan((0.5 - pvalues[~is_small]) * np.pi))

    if cct_stat > 1e15:
        # Far in the tail the Cauchy survival function is ~ 1 / (pi * t).
        return (1 / cct_stat) / np.pi
    # sf is more accurate than 1 - cdf for large statistics.
    return cauchy.sf(cct_stat)
66
+
67
+
68
def _strip_float_suffix(ids):
    """Return ``ids`` as strings with a trailing '.0' removed (e.g. 123.0 -> '123')."""
    # Numeric IDs read back from CSV can appear as '123.0'. The pattern is anchored to the
    # end so that a '.0' inside a barcode is kept. h5ad obs_names are compared unmodified,
    # so stripping it there would break matching.
    return ids.astype(str).str.replace(r'\.0$', '', regex=True)


def _cauchy_with_outlier_removal(p_temp, ct):
    """ACAT-combine the p-values of one annotation, first dropping a few extreme outliers.

    A spot is an outlier when -log10(p) > median + 3 * IQR, with the median and IQR
    computed on -log10(p). Outliers are removed only when there are between 1 and 19
    of them; otherwise all p-values are combined.
    """
    # The Cauchy test is dominated by its smallest p-values. A handful of extreme spots
    # (e.g. mis-annotated ones) could otherwise drive the result for the whole annotation.
    with np.errstate(divide='ignore'):
        p_temp_log = -np.log10(p_temp)
    median_log = np.median(p_temp_log)
    IQR_log = np.percentile(p_temp_log, 75) - np.percentile(p_temp_log, 25)

    # '<=' keeps values on the threshold. With '<', an IQR of 0 discarded every spot.
    p_use = p_temp[p_temp_log <= median_log + 3 * IQR_log]
    n_remove = len(p_temp) - len(p_use)

    if 0 < n_remove < 20 and len(p_use) > 0:
        print(f'Remove {n_remove}/{len(p_temp)} outliers (median + 3IQR) for {ct}.')
        return acat_test(p_use)
    return acat_test(p_temp)


def run_Cauchy_combination(config:CauchyCombinationConfig):
    """Aggregate spot-level LDSC p-values into one p-value per annotation.

    Reads ``{input_ldsc_dir}/{sample_name}_{trait_name}.csv.gz`` (the ``spot`` and ``p``
    columns are used) and attaches an annotation to each spot. For each annotation it
    writes the ACAT-combined p-value and the median p-value to
    ``{output_cauchy_dir}/{sample_name}_{trait_name}.Cauchy.csv.gz``.

    Source of the annotation (column ``config.annotation``):
      * ``config.meta`` is None: the ``obs`` table of the h5ad at ``config.input_hdf5_path``;
      * otherwise: the CSV at ``config.meta``, keeping only rows whose ``slide`` equals
        ``config.slide``, keyed by ``cell_id``.

    Only spots present in both tables are used. Spots with a missing annotation are dropped.
    """
    # Load the LDSC results
    print(f'------Loading LDSC results of {config.input_ldsc_dir}...')
    ldsc_input_file = Path(config.input_ldsc_dir) / f'{config.sample_name}_{config.trait_name}.csv.gz'
    ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
    ldsc['spot'] = _strip_float_suffix(ldsc['spot'])
    ldsc.index = ldsc['spot']

    if config.meta is None:
        # Load the spatial data and take the annotation from its obs table.
        print(f'------Loading ST data of {config.input_hdf5_path}...')
        spe = sc.read_h5ad(f'{config.input_hdf5_path}')
        annotation_table = spe.obs
    else:
        # Or load the additional annotation (just for the macaque data at this stage: 2023Nov25)
        print('------Loading additional annotation...')
        meta = pd.read_csv(config.meta, index_col=0)
        meta = meta.loc[meta.slide == config.slide]
        meta.index = _strip_float_suffix(meta.cell_id)
        annotation_table = meta

    # Add the annotation for the spots present in both tables.
    common_cell = np.intersect1d(ldsc.index, annotation_table.index)
    ldsc = ldsc.loc[common_cell].copy()
    ldsc['annotation'] = annotation_table.loc[common_cell, config.annotation].to_list()

    # Missing annotations cannot be grouped (and break np.unique on mixed types).
    missing = ldsc['annotation'].isna()
    if missing.any():
        print(f'Drop {int(missing.sum())} spots without a {config.annotation} annotation.')
        ldsc = ldsc.loc[~missing]

    # Perform the Cauchy combination based on the given annotations
    annotations = np.unique(ldsc['annotation'])
    p_cauchy = []
    p_median = []
    for ct in annotations:
        p_temp = ldsc.loc[ldsc['annotation'] == ct, 'p']
        p_cauchy.append(_cauchy_with_outlier_removal(p_temp, ct))
        p_median.append(np.median(p_temp))

    p_tissue = pd.DataFrame({'p_cauchy': p_cauchy, 'p_median': p_median, 'annotation': annotations})

    # Save the results
    output_dir = Path(config.output_cauchy_dir)
    output_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
    output_file = output_dir / f'{config.sample_name}_{config.trait_name}.Cauchy.csv.gz'
    p_tissue.to_csv(
        output_file,
        compression='gzip',
        index=False,
    )
140
+
141
+
142
if __name__ == '__main__':
    # Always honour the command-line arguments. The former hard-coded TEST switch ran
    # on fixed developer paths and ignored every CLI option.
    parser = argparse.ArgumentParser(description="Run Cauchy Combination Analysis")
    add_Cauchy_combination_args(parser)
    args = parser.parse_args()
    config = CauchyCombinationConfig(**vars(args))
    run_Cauchy_combination(config)