gsMap 1.65__py3-none-any.whl → 1.67__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- gsMap/GNN_VAE/adjacency_matrix.py +48 -68
- gsMap/GNN_VAE/model.py +68 -66
- gsMap/GNN_VAE/train.py +50 -61
- gsMap/__init__.py +1 -1
- gsMap/config.py +4 -4
- gsMap/find_latent_representation.py +103 -103
- gsMap/format_sumstats.py +20 -20
- gsMap/latent_to_gene.py +125 -109
- gsMap/spatial_ldsc_multiple_sumstats.py +0 -2
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/METADATA +2 -2
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/RECORD +14 -14
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/LICENSE +0 -0
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/WHEEL +0 -0
- {gsmap-1.65.dist-info → gsmap-1.67.dist-info}/entry_points.txt +0 -0
@@ -1,95 +1,75 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
"""
|
4
|
-
Created on Tue Jul 4 21:31:27 2023
|
5
|
-
|
6
|
-
@author: songliyang
|
7
|
-
"""
|
8
1
|
import numpy as np
|
9
2
|
import pandas as pd
|
10
3
|
import scipy.sparse as sp
|
11
|
-
|
4
|
+
from sklearn.neighbors import NearestNeighbors
|
12
5
|
import torch
|
13
6
|
|
14
|
-
|
15
|
-
|
16
|
-
"""\
|
17
|
-
Construct the spatial neighbor networks.
|
18
|
-
"""
|
19
|
-
#-
|
7
|
+
def cal_spatial_net(adata, n_neighbors=5, verbose=True):
|
8
|
+
"""Construct the spatial neighbor network."""
|
20
9
|
if verbose:
|
21
10
|
print('------Calculating spatial graph...')
|
22
|
-
coor = pd.DataFrame(adata.obsm['spatial'])
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance']>0,]
|
37
|
-
id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))
|
38
|
-
Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
|
39
|
-
Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
|
40
|
-
#-
|
41
|
-
return Spatial_Net
|
42
|
-
|
11
|
+
coor = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
|
12
|
+
nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
|
13
|
+
distances, indices = nbrs.kneighbors(coor)
|
14
|
+
n_cells, n_neighbors = indices.shape
|
15
|
+
cell_indices = np.arange(n_cells)
|
16
|
+
cell1 = np.repeat(cell_indices, n_neighbors)
|
17
|
+
cell2 = indices.flatten()
|
18
|
+
distance = distances.flatten()
|
19
|
+
knn_df = pd.DataFrame({'Cell1': cell1, 'Cell2': cell2, 'Distance': distance})
|
20
|
+
knn_df = knn_df[knn_df['Distance'] > 0].copy()
|
21
|
+
cell_id_map = dict(zip(cell_indices, coor.index))
|
22
|
+
knn_df['Cell1'] = knn_df['Cell1'].map(cell_id_map)
|
23
|
+
knn_df['Cell2'] = knn_df['Cell2'].map(cell_id_map)
|
24
|
+
return knn_df
|
43
25
|
|
44
26
|
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
|
45
27
|
"""Convert a scipy sparse matrix to a torch sparse tensor."""
|
46
28
|
sparse_mx = sparse_mx.tocoo().astype(np.float32)
|
47
|
-
indices = torch.from_numpy(
|
29
|
+
indices = torch.from_numpy(
|
30
|
+
np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
|
31
|
+
)
|
48
32
|
values = torch.from_numpy(sparse_mx.data)
|
49
33
|
shape = torch.Size(sparse_mx.shape)
|
50
34
|
return torch.sparse.FloatTensor(indices, values, shape)
|
51
35
|
|
52
|
-
|
53
36
|
def preprocess_graph(adj):
|
37
|
+
"""Symmetrically normalize the adjacency matrix."""
|
54
38
|
adj = sp.coo_matrix(adj)
|
55
39
|
adj_ = adj + sp.eye(adj.shape[0])
|
56
|
-
rowsum = np.array(adj_.sum(1))
|
57
|
-
degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5)
|
40
|
+
rowsum = np.array(adj_.sum(1)).flatten()
|
41
|
+
degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5))
|
58
42
|
adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
|
59
43
|
return sparse_mx_to_torch_sparse_tensor(adj_normalized)
|
60
44
|
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
# Construct the neighbor graph
|
65
|
-
Spatial_Net = Cal_Spatial_Net(adata, n_neighbors=Params.n_neighbors)
|
66
|
-
#-
|
45
|
+
def construct_adjacency_matrix(adata, params, verbose=True):
|
46
|
+
"""Construct the adjacency matrix from spatial data."""
|
47
|
+
spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
|
67
48
|
if verbose:
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
49
|
+
num_edges = spatial_net.shape[0]
|
50
|
+
num_cells = adata.n_obs
|
51
|
+
print(f'The graph contains {num_edges} edges, {num_cells} cells.')
|
52
|
+
print(f'{num_edges / num_cells:.2f} neighbors per cell on average.')
|
53
|
+
cell_ids = {cell: idx for idx, cell in enumerate(adata.obs.index)}
|
54
|
+
spatial_net['Cell1'] = spatial_net['Cell1'].map(cell_ids)
|
55
|
+
spatial_net['Cell2'] = spatial_net['Cell2'].map(cell_ids)
|
56
|
+
if params.weighted_adj:
|
57
|
+
distance_normalized = spatial_net['Distance'] / (spatial_net['Distance'].max() + 1)
|
58
|
+
weights = np.exp(-0.5 * distance_normalized ** 2)
|
59
|
+
adj_org = sp.coo_matrix(
|
60
|
+
(weights, (spatial_net['Cell1'], spatial_net['Cell2'])),
|
61
|
+
shape=(adata.n_obs, adata.n_obs)
|
62
|
+
)
|
81
63
|
else:
|
82
|
-
adj_org = sp.coo_matrix(
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
#-
|
64
|
+
adj_org = sp.coo_matrix(
|
65
|
+
(np.ones(spatial_net.shape[0]), (spatial_net['Cell1'], spatial_net['Cell2'])),
|
66
|
+
shape=(adata.n_obs, adata.n_obs)
|
67
|
+
)
|
68
|
+
adj_norm = preprocess_graph(adj_org)
|
69
|
+
norm_value = adj_org.shape[0] ** 2 / ((adj_org.shape[0] ** 2 - adj_org.sum()) * 2)
|
89
70
|
graph_dict = {
|
90
71
|
"adj_org": adj_org,
|
91
|
-
"adj_norm":
|
92
|
-
"norm_value":
|
72
|
+
"adj_norm": adj_norm,
|
73
|
+
"norm_value": norm_value
|
93
74
|
}
|
94
|
-
#-
|
95
75
|
return graph_dict
|
gsMap/GNN_VAE/model.py
CHANGED
@@ -1,87 +1,89 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
"""
|
4
|
-
Created on Mon Jul 3 11:42:44 2023
|
5
|
-
|
6
|
-
@author: songliyang
|
7
|
-
"""
|
8
|
-
|
9
1
|
import torch
|
10
2
|
import torch.nn as nn
|
11
3
|
import torch.nn.functional as F
|
12
4
|
from torch_geometric.nn import GATConv
|
13
5
|
|
14
|
-
|
15
6
|
def full_block(in_features, out_features, p_drop):
|
16
|
-
return nn.Sequential(
|
17
|
-
|
18
|
-
|
19
|
-
|
7
|
+
return nn.Sequential(
|
8
|
+
nn.Linear(in_features, out_features),
|
9
|
+
nn.BatchNorm1d(out_features),
|
10
|
+
nn.ELU(),
|
11
|
+
nn.Dropout(p=p_drop)
|
12
|
+
)
|
20
13
|
|
21
|
-
|
22
|
-
|
23
|
-
def __init__(self, in_features, out_features, dr=0, act=F.relu,heads=1):
|
14
|
+
class GATModel(nn.Module):
|
15
|
+
def __init__(self, input_dim, params, num_classes=1):
|
24
16
|
super().__init__()
|
25
|
-
self.conv1 = GATConv(in_features, out_features,heads)
|
26
|
-
self.act = act
|
27
|
-
self.dr = dr
|
28
|
-
#-
|
29
|
-
def forward(self, x, edge_index):
|
30
|
-
out = self.conv1(x, edge_index)
|
31
|
-
out = self.act(out)
|
32
|
-
out = F.dropout(out, self.dr, self.training)
|
33
|
-
return out
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
class GNN_VAE_Model(nn.Module):
|
38
|
-
def __init__(self, input_dim,params,num_classes=1):
|
39
|
-
super(GNN_VAE_Model, self).__init__()
|
40
17
|
self.var = params.var
|
41
18
|
self.num_classes = num_classes
|
42
|
-
|
19
|
+
self.params = params
|
20
|
+
|
43
21
|
# Encoder
|
44
|
-
self.encoder = nn.Sequential(
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
self.
|
51
|
-
|
52
|
-
|
22
|
+
self.encoder = nn.Sequential(
|
23
|
+
full_block(input_dim, params.feat_hidden1, params.p_drop),
|
24
|
+
full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop)
|
25
|
+
)
|
26
|
+
|
27
|
+
# GAT Layers
|
28
|
+
self.gat1 = GATConv(
|
29
|
+
in_channels=params.feat_hidden2,
|
30
|
+
out_channels=params.gat_hidden1,
|
31
|
+
heads=params.nheads,
|
32
|
+
dropout=params.p_drop
|
33
|
+
)
|
34
|
+
self.gat2 = GATConv(
|
35
|
+
in_channels=params.gat_hidden1 * params.nheads,
|
36
|
+
out_channels=params.gat_hidden2,
|
37
|
+
heads=1,
|
38
|
+
concat=False,
|
39
|
+
dropout=params.p_drop
|
40
|
+
)
|
41
|
+
if self.var:
|
42
|
+
self.gat3 = GATConv(
|
43
|
+
in_channels=params.gat_hidden1 * params.nheads,
|
44
|
+
out_channels=params.gat_hidden2,
|
45
|
+
heads=1,
|
46
|
+
concat=False,
|
47
|
+
dropout=params.p_drop
|
48
|
+
)
|
49
|
+
|
53
50
|
# Decoder
|
54
|
-
self.decoder = nn.Sequential(
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
self.cluster
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
51
|
+
self.decoder = nn.Sequential(
|
52
|
+
full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
|
53
|
+
full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop),
|
54
|
+
nn.Linear(params.feat_hidden1, input_dim)
|
55
|
+
)
|
56
|
+
|
57
|
+
# Clustering Layer
|
58
|
+
self.cluster = nn.Sequential(
|
59
|
+
full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
|
60
|
+
nn.Linear(params.feat_hidden2, self.num_classes)
|
61
|
+
)
|
62
|
+
|
63
|
+
def encode(self, x, edge_index):
|
64
|
+
x = self.encoder(x)
|
65
|
+
x = self.gat1(x, edge_index)
|
66
|
+
x = F.relu(x)
|
67
|
+
x = F.dropout(x, p=self.params.p_drop, training=self.training)
|
68
|
+
|
69
|
+
mu = self.gat2(x, edge_index)
|
68
70
|
if self.var:
|
69
|
-
logvar = self.
|
71
|
+
logvar = self.gat3(x, edge_index)
|
70
72
|
return mu, logvar
|
71
73
|
else:
|
72
74
|
return mu, None
|
73
|
-
|
75
|
+
|
74
76
|
def reparameterize(self, mu, logvar):
|
75
77
|
if self.training and logvar is not None:
|
76
|
-
std = torch.exp(logvar)
|
78
|
+
std = torch.exp(0.5 * logvar)
|
77
79
|
eps = torch.randn_like(std)
|
78
|
-
return eps
|
80
|
+
return eps * std + mu
|
79
81
|
else:
|
80
82
|
return mu
|
81
|
-
|
82
|
-
def forward(self, x,
|
83
|
-
mu, logvar = self.encode(x,
|
84
|
-
|
85
|
-
x_reconstructed = self.decoder(
|
86
|
-
pred_label = F.softmax(self.cluster(
|
87
|
-
return pred_label, x_reconstructed,
|
83
|
+
|
84
|
+
def forward(self, x, edge_index):
|
85
|
+
mu, logvar = self.encode(x, edge_index)
|
86
|
+
z = self.reparameterize(mu, logvar)
|
87
|
+
x_reconstructed = self.decoder(z)
|
88
|
+
pred_label = F.softmax(self.cluster(z), dim=1)
|
89
|
+
return pred_label, x_reconstructed, z, mu, logvar
|
gsMap/GNN_VAE/train.py
CHANGED
@@ -1,97 +1,86 @@
|
|
1
|
-
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
"""
|
4
|
-
Created on Tue Jul 4 19:58:58 2023
|
5
|
-
|
6
|
-
@author: songliyang
|
7
|
-
"""
|
1
|
+
import logging
|
8
2
|
import time
|
9
3
|
|
10
4
|
import torch
|
5
|
+
import torch.nn.functional as F
|
11
6
|
from progress.bar import Bar
|
12
7
|
|
13
|
-
from gsMap.GNN_VAE.model import
|
8
|
+
from gsMap.GNN_VAE.model import GATModel
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)
|
14
11
|
|
15
12
|
|
16
13
|
def reconstruction_loss(decoded, x):
|
17
|
-
|
18
|
-
|
19
|
-
return loss
|
14
|
+
"""Compute the mean squared error loss."""
|
15
|
+
return F.mse_loss(decoded, x)
|
20
16
|
|
21
17
|
|
22
18
|
def label_loss(pred_label, true_label):
|
23
|
-
|
24
|
-
|
25
|
-
return loss
|
26
|
-
|
27
|
-
|
28
|
-
class Model_Train:
|
29
|
-
def __init__(self, node_X, graph_dict, params, label=None):
|
30
|
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
31
|
-
torch.cuda.empty_cache()
|
19
|
+
"""Compute the cross-entropy loss."""
|
20
|
+
return F.cross_entropy(pred_label, true_label)
|
32
21
|
|
22
|
+
class ModelTrainer:
|
23
|
+
def __init__(self, node_x, graph_dict, params, label=None):
|
24
|
+
"""Initialize the ModelTrainer with data and hyperparameters."""
|
25
|
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
33
26
|
self.params = params
|
34
|
-
self.device = device
|
35
27
|
self.epochs = params.epochs
|
36
|
-
self.
|
37
|
-
self.adj_norm = graph_dict["adj_norm"].to(device).coalesce()
|
28
|
+
self.node_x = torch.FloatTensor(node_x).to(self.device)
|
29
|
+
self.adj_norm = graph_dict["adj_norm"].to(self.device).coalesce()
|
38
30
|
self.label = label
|
39
31
|
self.num_classes = 1
|
40
|
-
|
41
|
-
if
|
32
|
+
|
33
|
+
if self.label is not None:
|
42
34
|
self.label = torch.tensor(self.label).to(self.device)
|
43
|
-
self.num_classes = len(self.label
|
44
|
-
|
45
|
-
# Set
|
46
|
-
self.model =
|
47
|
-
self.optimizer = torch.optim.Adam(
|
48
|
-
|
49
|
-
|
50
|
-
|
35
|
+
self.num_classes = len(torch.unique(self.label))
|
36
|
+
|
37
|
+
# Set up the model
|
38
|
+
self.model = GATModel(self.params.feat_cell, self.params, self.num_classes).to(self.device)
|
39
|
+
self.optimizer = torch.optim.Adam(
|
40
|
+
self.model.parameters(),
|
41
|
+
lr=self.params.gat_lr,
|
42
|
+
weight_decay=self.params.gcn_decay
|
43
|
+
)
|
44
|
+
|
51
45
|
def run_train(self):
|
46
|
+
"""Train the model."""
|
52
47
|
self.model.train()
|
53
48
|
prev_loss = float('inf')
|
54
|
-
|
55
|
-
bar =
|
56
|
-
|
49
|
+
bar = Bar('GAT-AE model train:', max=self.epochs)
|
50
|
+
bar.check_tty = False
|
51
|
+
|
52
|
+
logger.info('Start training...')
|
57
53
|
for epoch in range(self.epochs):
|
58
54
|
start_time = time.time()
|
59
|
-
self.model.train()
|
60
55
|
self.optimizer.zero_grad()
|
61
|
-
pred_label, de_feat, latent_z, mu, logvar = self.model(self.
|
62
|
-
loss_rec = reconstruction_loss(de_feat, self.
|
63
|
-
|
64
|
-
|
65
|
-
if not self.label is None:
|
56
|
+
pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_x, self.adj_norm)
|
57
|
+
loss_rec = reconstruction_loss(de_feat, self.node_x)
|
58
|
+
|
59
|
+
if self.label is not None:
|
66
60
|
loss_pre = label_loss(pred_label, self.label)
|
67
|
-
loss =
|
61
|
+
loss = self.params.rec_w * loss_rec + self.params.label_w * loss_pre
|
68
62
|
else:
|
69
63
|
loss = loss_rec
|
70
|
-
|
64
|
+
|
71
65
|
loss.backward()
|
72
66
|
self.optimizer.step()
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
bar_str = '{} / {} | Left time: {batch_time:.2f} mins| Loss: {loss:.4f}'
|
80
|
-
bar.suffix = bar_str.format(epoch + 1,self.epochs,
|
81
|
-
batch_time = batch_time * (self.epochs - epoch) / 60, loss=loss.item())
|
67
|
+
|
68
|
+
batch_time = time.time() - start_time
|
69
|
+
left_time = batch_time * (self.epochs - epoch - 1) / 60 # in minutes
|
70
|
+
|
71
|
+
bar.suffix = f'{epoch + 1} / {self.epochs} | Left time: {left_time:.2f} mins | Loss: {loss.item():.4f}'
|
82
72
|
bar.next()
|
83
|
-
|
84
|
-
# Check convergence
|
73
|
+
|
85
74
|
if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
|
86
|
-
|
75
|
+
logger.info('\nConvergence reached. Training stopped.')
|
87
76
|
break
|
88
77
|
|
89
78
|
prev_loss = loss.item()
|
90
|
-
|
91
79
|
bar.finish()
|
92
|
-
|
80
|
+
|
93
81
|
def get_latent(self):
|
82
|
+
"""Retrieve the latent representation from the model."""
|
94
83
|
self.model.eval()
|
95
|
-
|
96
|
-
|
97
|
-
return latent_z
|
84
|
+
with torch.no_grad():
|
85
|
+
_, _, latent_z, _, _ = self.model(self.node_x, self.adj_norm)
|
86
|
+
return latent_z.cpu().numpy()
|
gsMap/__init__.py
CHANGED
gsMap/config.py
CHANGED
@@ -55,7 +55,8 @@ def add_find_latent_representations_args(parser):
|
|
55
55
|
add_shared_args(parser)
|
56
56
|
parser.add_argument('--input_hdf5_path', required=True, type=str, help='Path to the input HDF5 file.')
|
57
57
|
parser.add_argument('--annotation', required=True, type=str, help='Name of the annotation in adata.obs to use.')
|
58
|
-
parser.add_argument('--data_layer',
|
58
|
+
parser.add_argument('--data_layer', type=str, default='counts', required=True,
|
59
|
+
help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
|
59
60
|
parser.add_argument('--epochs', type=int, default=300, help='Number of training epochs.')
|
60
61
|
parser.add_argument('--feat_hidden1', type=int, default=256, help='Neurons in the first hidden layer.')
|
61
62
|
parser.add_argument('--feat_hidden2', type=int, default=128, help='Neurons in the second hidden layer.')
|
@@ -66,7 +67,6 @@ def add_find_latent_representations_args(parser):
|
|
66
67
|
parser.add_argument('--n_neighbors', type=int, default=11, help='Number of neighbors for GAT.')
|
67
68
|
parser.add_argument('--n_comps', type=int, default=300, help='Number of principal components for PCA.')
|
68
69
|
parser.add_argument('--weighted_adj', action='store_true', help='Use weighted adjacency in GAT.')
|
69
|
-
parser.add_argument('--var', action='store_true', help='Enable variance calculations.')
|
70
70
|
parser.add_argument('--convergence_threshold', type=float, default=1e-4, help='Threshold for convergence.')
|
71
71
|
parser.add_argument('--hierarchically', action='store_true', help='Enable hierarchical latent representation finding.')
|
72
72
|
|
@@ -236,8 +236,8 @@ def add_run_all_mode_args(parser):
|
|
236
236
|
help='Path to the input spatial transcriptomics data (H5AD format).')
|
237
237
|
parser.add_argument('--annotation', type=str, required=True,
|
238
238
|
help='Name of the annotation in adata.obs to use.')
|
239
|
-
parser.add_argument('--data_layer', type=str, default='
|
240
|
-
help='Data layer
|
239
|
+
parser.add_argument('--data_layer', type=str, default='counts', required=True,
|
240
|
+
help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
|
241
241
|
|
242
242
|
# GWAS Data Parameters
|
243
243
|
parser.add_argument('--trait_name', type=str, help='Name of the trait for GWAS analysis (required if sumstats_file is provided).')
|