gsMap 1.66__tar.gz → 1.67__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gsmap-1.66 → gsmap-1.67}/PKG-INFO +2 -2
- {gsmap-1.66 → gsmap-1.67}/docs/source/mouse_example.md +1 -1
- {gsmap-1.66 → gsmap-1.67}/pyproject.toml +1 -1
- gsmap-1.67/src/gsMap/GNN_VAE/adjacency_matrix.py +75 -0
- gsmap-1.67/src/gsMap/GNN_VAE/model.py +89 -0
- gsmap-1.67/src/gsMap/GNN_VAE/train.py +86 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/__init__.py +1 -1
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/config.py +4 -4
- gsmap-1.67/src/gsMap/find_latent_representation.py +145 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/format_sumstats.py +20 -20
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/latent_to_gene.py +0 -7
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/spatial_ldsc_multiple_sumstats.py +0 -2
- gsmap-1.66/src/gsMap/GNN_VAE/adjacency_matrix.py +0 -95
- gsmap-1.66/src/gsMap/GNN_VAE/model.py +0 -87
- gsmap-1.66/src/gsMap/GNN_VAE/train.py +0 -97
- gsmap-1.66/src/gsMap/find_latent_representation.py +0 -145
- {gsmap-1.66 → gsmap-1.67}/.github/workflows/publish-to-pypi.yml +0 -0
- {gsmap-1.66 → gsmap-1.67}/.gitignore +0 -0
- {gsmap-1.66 → gsmap-1.67}/LICENSE +0 -0
- {gsmap-1.66 → gsmap-1.67}/README.md +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/Makefile +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/make.bat +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/requirements.txt +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/_static/schematic.svg +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api/cauchy_combination.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api/find_latent_representations.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api/format_sumstats.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api/generate_ldscore.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api/latent_to_gene.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api/quick_mode.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api/report.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api/spatial_ldsc.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/api.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_Height.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_IQ.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_MCHC.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_SCZ.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_Height.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_IQ.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_MCHC.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_SCZ.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/charts/test.json +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/conf.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/data.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/data_format.md +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/index.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/install.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/mouse.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/quick_mode.md +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/release.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/docs/source/tutorials.rst +0 -0
- {gsmap-1.66 → gsmap-1.67}/schematic.png +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/GNN_VAE/__init__.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/__main__.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/cauchy_combination_test.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/diagnosis.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/generate_ldscore.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/main.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/report.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/run_all_mode.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/setup.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/templates/report_template.html +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/__init__.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/generate_r2_matrix.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/jackknife.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/make_annotations.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/manhattan_plot.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/regression_read.py +0 -0
- {gsmap-1.66 → gsmap-1.67}/src/gsMap/visualize.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: gsMap
|
3
|
-
Version: 1.66
|
3
|
+
Version: 1.67
|
4
4
|
Summary: Genetics-informed pathogenic spatial mapping
|
5
5
|
Author-email: liyang <songliyang@westlake.edu.cn>, wenhao <chenwenhao@westlake.edu.cn>
|
6
6
|
Requires-Python: >=3.8
|
@@ -27,7 +27,7 @@ Requires-Dist: pyfiglet
|
|
27
27
|
Requires-Dist: plotly
|
28
28
|
Requires-Dist: kaleido
|
29
29
|
Requires-Dist: jinja2
|
30
|
-
Requires-Dist: scanpy
|
30
|
+
Requires-Dist: scanpy >=1.8.0
|
31
31
|
Requires-Dist: zarr
|
32
32
|
Requires-Dist: bitarray
|
33
33
|
Requires-Dist: pyarrow
|
@@ -209,7 +209,7 @@ gsmap run_cauchy_combination \
|
|
209
209
|
--annotation 'annotation'
|
210
210
|
```
|
211
211
|
|
212
|
-
### 6. report generation
|
212
|
+
### 6. report generation (optional)
|
213
213
|
|
214
214
|
**Objective**: Generate gsMap reports, including visualizations of mapping results and diagnostic plots.
|
215
215
|
|
@@ -0,0 +1,75 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import pandas as pd
|
3
|
+
import scipy.sparse as sp
|
4
|
+
from sklearn.neighbors import NearestNeighbors
|
5
|
+
import torch
|
6
|
+
|
7
|
+
def cal_spatial_net(adata, n_neighbors=5, verbose=True):
    """Build a k-nearest-neighbour edge list from spatial coordinates.

    Parameters
    ----------
    adata : AnnData-like
        Must expose ``obsm['spatial']`` (coordinates) and ``obs.index``
        (cell identifiers).
    n_neighbors : int
        Number of neighbours queried per cell. Because the fitted points are
        queried again, each cell is returned as its own neighbour at distance
        0; such zero-distance pairs are dropped.
    verbose : bool
        Print a progress message when True.

    Returns
    -------
    pandas.DataFrame
        Columns ``Cell1`` and ``Cell2`` (cell identifiers) and ``Distance``,
        one row per neighbour pair with a strictly positive distance.
    """
    if verbose:
        print('------Calculating spatial graph...')
    coords = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
    knn = NearestNeighbors(n_neighbors=n_neighbors).fit(coords)
    dist_mat, idx_mat = knn.kneighbors(coords)

    n_rows, k = idx_mat.shape
    positions = np.arange(n_rows)
    edges = pd.DataFrame({
        'Cell1': np.repeat(positions, k),
        'Cell2': idx_mat.flatten(),
        'Distance': dist_mat.flatten(),
    })
    # Distance 0 means the query cell itself (or a spot at identical coordinates).
    edges = edges.loc[edges['Distance'] > 0].copy()

    # Translate positional indices into cell identifiers.
    position_to_id = dict(zip(positions, coords.index))
    for column in ('Cell1', 'Cell2'):
        edges[column] = edges[column].map(position_to_id)
    return edges
|
25
|
+
|
26
|
+
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

    Builds the tensor with ``torch.sparse_coo_tensor``. The legacy
    ``torch.sparse.FloatTensor`` constructor it replaces is deprecated, and
    the result is the same float32 COO tensor.

    Parameters
    ----------
    sparse_mx : scipy.sparse matrix
        Any SciPy sparse format; it is converted to COO first.

    Returns
    -------
    torch.Tensor
        Sparse COO tensor with the shape and entries of ``sparse_mx``.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape), dtype=torch.float32)
|
35
|
+
|
36
|
+
def preprocess_graph(adj):
    """Add self-loops, degree-normalise, and return a torch sparse tensor.

    With ``A_hat = A + I`` and ``D`` the diagonal of row sums of ``A_hat``,
    this returns ``(A_hat D^{-1/2})^T D^{-1/2}``. That equals
    ``D^{-1/2} A_hat D^{-1/2}`` when ``A`` is symmetric.
    """
    adj_coo = sp.coo_matrix(adj)
    adj_hat = adj_coo + sp.eye(adj_coo.shape[0])
    degrees = np.asarray(adj_hat.sum(axis=1)).ravel()
    d_inv_sqrt = sp.diags(degrees ** -0.5)
    normalized = adj_hat.dot(d_inv_sqrt).transpose().dot(d_inv_sqrt).tocoo()
    return sparse_mx_to_torch_sparse_tensor(normalized)
|
44
|
+
|
45
|
+
def construct_adjacency_matrix(adata, params, verbose=True):
    """Construct the spatial adjacency matrix and its normalised form.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obsm['spatial']`` and unique ``obs`` names.
    params : config object
        Reads ``n_neighbors`` and ``weighted_adj``.
    verbose : bool
        Print graph statistics when True.

    Returns
    -------
    dict
        Contains:

        - ``"adj_org"``: SciPy COO matrix of shape ``(n_obs, n_obs)``.
        - ``"adj_norm"``: output of ``preprocess_graph(adj_org)``.
        - ``"norm_value"``: ``n^2 / (2 * (n^2 - sum(adj_org)))``.

    Raises
    ------
    ValueError
        If the ``adata.obs`` names are not unique. The edge list is keyed by
        cell name, so duplicate names would silently map several cells onto
        the same row/column of the adjacency matrix.
    """
    if not adata.obs.index.is_unique:
        raise ValueError(
            'Cell names (adata.obs_names) must be unique to build the spatial graph; '
            'call adata.obs_names_make_unique() first.'
        )

    spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
    if verbose:
        num_edges = spatial_net.shape[0]
        num_cells = adata.n_obs
        print(f'The graph contains {num_edges} edges, {num_cells} cells.')
        print(f'{num_edges / num_cells:.2f} neighbors per cell on average.')

    # Map cell names back to integer row/column positions.
    cell_ids = {cell: idx for idx, cell in enumerate(adata.obs.index)}
    spatial_net['Cell1'] = spatial_net['Cell1'].map(cell_ids)
    spatial_net['Cell2'] = spatial_net['Cell2'].map(cell_ids)

    if params.weighted_adj:
        # Gaussian decay of the distance rescaled by (max distance + 1).
        distance_normalized = spatial_net['Distance'] / (spatial_net['Distance'].max() + 1)
        weights = np.exp(-0.5 * distance_normalized ** 2)
    else:
        weights = np.ones(spatial_net.shape[0])
    adj_org = sp.coo_matrix(
        (weights, (spatial_net['Cell1'], spatial_net['Cell2'])),
        shape=(adata.n_obs, adata.n_obs)
    )

    adj_norm = preprocess_graph(adj_org)
    norm_value = adj_org.shape[0] ** 2 / ((adj_org.shape[0] ** 2 - adj_org.sum()) * 2)
    graph_dict = {
        "adj_org": adj_org,
        "adj_norm": adj_norm,
        "norm_value": norm_value
    }
    return graph_dict
|
@@ -0,0 +1,89 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from torch_geometric.nn import GATConv
|
5
|
+
|
6
|
+
def full_block(in_features, out_features, p_drop):
    """Return a ``Linear -> BatchNorm1d -> ELU -> Dropout`` stack.

    Parameters
    ----------
    in_features, out_features : int
        Input and output widths of the linear layer.
    p_drop : float
        Dropout probability.
    """
    layers = [
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.ELU(),
        nn.Dropout(p=p_drop),
    ]
    return nn.Sequential(*layers)
|
13
|
+
|
14
|
+
class GATModel(nn.Module):
    """Graph-attention autoencoder with an optional variance head and a classifier.

    Pipeline:

    - An MLP encoder (two ``full_block``s) maps features to ``feat_hidden2``.
    - A multi-head ``GATConv`` produces ``gat_hidden1 * nheads`` features,
      the input width of the next layer.
    - A single-head ``GATConv`` produces ``mu``. When ``params.var`` is
      truthy, a parallel ``GATConv`` produces ``logvar``.
    - An MLP decoder maps the latent code back to ``input_dim``.
    - A separate MLP head maps the latent code to ``num_classes`` softmax
      probabilities.

    ``params`` must provide ``var``, ``p_drop``, ``feat_hidden1``,
    ``feat_hidden2``, ``gat_hidden1``, ``gat_hidden2`` and ``nheads``.
    """

    def __init__(self, input_dim, params, num_classes=1):
        super().__init__()
        self.var = params.var
        self.num_classes = num_classes
        self.params = params

        drop = params.p_drop
        hidden1, hidden2 = params.feat_hidden1, params.feat_hidden2
        gat_out1, gat_out2 = params.gat_hidden1, params.gat_hidden2
        heads = params.nheads

        # Encoder: input_dim -> hidden1 -> hidden2
        self.encoder = nn.Sequential(
            full_block(input_dim, hidden1, drop),
            full_block(hidden1, hidden2, drop)
        )

        # Graph attention layers (construction order kept: gat1, gat2, gat3)
        self.gat1 = GATConv(
            in_channels=hidden2,
            out_channels=gat_out1,
            heads=heads,
            dropout=drop
        )
        self.gat2 = self._single_head_gat(gat_out1 * heads, gat_out2, drop)
        if self.var:
            self.gat3 = self._single_head_gat(gat_out1 * heads, gat_out2, drop)

        # Decoder: latent -> hidden2 -> hidden1 -> input_dim
        self.decoder = nn.Sequential(
            full_block(gat_out2, hidden2, drop),
            full_block(hidden2, hidden1, drop),
            nn.Linear(hidden1, input_dim)
        )

        # Classification head on the latent code
        self.cluster = nn.Sequential(
            full_block(gat_out2, hidden2, drop),
            nn.Linear(hidden2, self.num_classes)
        )

    @staticmethod
    def _single_head_gat(in_channels, out_channels, dropout):
        """Build a one-head ``GATConv`` whose output is not concatenated."""
        return GATConv(
            in_channels=in_channels,
            out_channels=out_channels,
            heads=1,
            concat=False,
            dropout=dropout
        )

    def encode(self, x, edge_index):
        """Return ``(mu, logvar)``; ``logvar`` is None when ``var`` is disabled."""
        hidden = self.encoder(x)
        hidden = F.relu(self.gat1(hidden, edge_index))
        hidden = F.dropout(hidden, p=self.params.p_drop, training=self.training)
        mu = self.gat2(hidden, edge_index)
        logvar = self.gat3(hidden, edge_index) if self.var else None
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode; otherwise return ``mu``."""
        if not self.training or logvar is None:
            return mu
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def forward(self, x, edge_index):
        """Return ``(pred_label, x_reconstructed, z, mu, logvar)``.

        ``pred_label`` holds softmax probabilities over ``num_classes``.
        """
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decoder(z)
        probs = F.softmax(self.cluster(z), dim=1)
        return probs, reconstructed, z, mu, logvar
|
@@ -0,0 +1,86 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn.functional as F
|
6
|
+
from progress.bar import Bar
|
7
|
+
|
8
|
+
from gsMap.GNN_VAE.model import GATModel
|
9
|
+
|
10
|
+
# Logger named after this module.
logger = logging.getLogger(name=__name__)
|
11
|
+
|
12
|
+
|
13
|
+
def reconstruction_loss(decoded, x):
    """Mean squared error between the reconstruction ``decoded`` and the input ``x``."""
    return F.mse_loss(input=decoded, target=x, reduction='mean')
|
16
|
+
|
17
|
+
|
18
|
+
def label_loss(pred_label, true_label):
    """Cross-entropy of ``true_label`` under predicted class *probabilities*.

    ``GATModel.forward`` already applies ``softmax`` to the classifier
    output, so ``pred_label`` holds probabilities, not logits. Passing them
    to ``F.cross_entropy`` would apply ``log_softmax`` a second time (a
    "double softmax"). That compresses the loss into a narrow range and
    weakens its gradients. Taking the log of the probabilities and using
    ``F.nll_loss`` gives the intended cross-entropy.

    Parameters
    ----------
    pred_label : torch.Tensor
        ``(n_samples, n_classes)`` probabilities.
    true_label : torch.Tensor
        ``(n_samples,)`` integer class indices.

    Returns
    -------
    torch.Tensor
        Scalar mean negative log-likelihood.
    """
    # Clamp so a probability that underflows to 0 gives a large finite loss, not inf.
    log_probs = torch.log(pred_label.clamp_min(torch.finfo(pred_label.dtype).tiny))
    return F.nll_loss(log_probs, true_label)
|
21
|
+
|
22
|
+
class ModelTrainer:
    """Train a ``GATModel`` full-batch on node features over a fixed graph.

    Parameters
    ----------
    node_x : array-like
        Node feature matrix; converted to a float32 tensor.
    graph_dict : dict
        Must contain ``"adj_norm"``, a torch sparse tensor (``.coalesce()``
        is called on it) that is passed to the model as the graph.
    params : config object
        Reads ``epochs``, ``feat_cell``, ``gat_lr``, ``gcn_decay``,
        ``rec_w``, ``label_w`` and ``convergence_threshold``, plus whatever
        ``GATModel`` reads.
    label : array-like or None
        Integer class labels. When given, a weighted classification loss is
        added to the reconstruction loss.
    """

    def __init__(self, node_x, graph_dict, params, label=None):
        """Move the data to the available device and build the model and optimiser."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.params = params
        self.epochs = params.epochs
        self.node_x = torch.FloatTensor(node_x).to(self.device)
        self.adj_norm = graph_dict["adj_norm"].to(self.device).coalesce()

        if label is None:
            self.label = None
            self.num_classes = 1
        else:
            self.label = torch.tensor(label).to(self.device)
            self.num_classes = len(torch.unique(self.label))

        self.model = GATModel(self.params.feat_cell, self.params, self.num_classes).to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.params.gat_lr,
            weight_decay=self.params.gcn_decay,
        )

    def run_train(self):
        """Optimise the model for up to ``epochs`` full-batch steps.

        Training stops early when the absolute change in loss between
        consecutive epochs is at most ``convergence_threshold``. This check
        only applies from epoch index 200 onward.
        """
        self.model.train()
        last_loss = float('inf')
        progress = Bar('GAT-AE model train:', max=self.epochs)
        progress.check_tty = False

        logger.info('Start training...')
        for epoch in range(self.epochs):
            tic = time.time()
            self.optimizer.zero_grad()
            pred_label, de_feat, _, _, _ = self.model(self.node_x, self.adj_norm)

            rec = reconstruction_loss(de_feat, self.node_x)
            if self.label is None:
                loss = rec
            else:
                cls = label_loss(pred_label, self.label)
                loss = self.params.rec_w * rec + self.params.label_w * cls

            loss.backward()
            self.optimizer.step()

            elapsed = time.time() - tic
            minutes_left = elapsed * (self.epochs - epoch - 1) / 60
            current = loss.item()

            progress.suffix = f'{epoch + 1} / {self.epochs} | Left time: {minutes_left:.2f} mins | Loss: {current:.4f}'
            progress.next()

            if epoch >= 200 and abs(current - last_loss) <= self.params.convergence_threshold:
                logger.info('\nConvergence reached. Training stopped.')
                break
            last_loss = current
        progress.finish()

    def get_latent(self):
        """Return the latent code ``z`` for every node as a NumPy array (eval mode, no grad)."""
        self.model.eval()
        with torch.no_grad():
            latent = self.model(self.node_x, self.adj_norm)[2]
        return latent.cpu().numpy()
|
@@ -55,7 +55,8 @@ def add_find_latent_representations_args(parser):
|
|
55
55
|
add_shared_args(parser)
|
56
56
|
parser.add_argument('--input_hdf5_path', required=True, type=str, help='Path to the input HDF5 file.')
|
57
57
|
parser.add_argument('--annotation', required=True, type=str, help='Name of the annotation in adata.obs to use.')
|
58
|
-
parser.add_argument('--data_layer',
|
58
|
+
parser.add_argument('--data_layer', type=str, default='counts', required=True,
|
59
|
+
help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
|
59
60
|
parser.add_argument('--epochs', type=int, default=300, help='Number of training epochs.')
|
60
61
|
parser.add_argument('--feat_hidden1', type=int, default=256, help='Neurons in the first hidden layer.')
|
61
62
|
parser.add_argument('--feat_hidden2', type=int, default=128, help='Neurons in the second hidden layer.')
|
@@ -66,7 +67,6 @@ def add_find_latent_representations_args(parser):
|
|
66
67
|
parser.add_argument('--n_neighbors', type=int, default=11, help='Number of neighbors for GAT.')
|
67
68
|
parser.add_argument('--n_comps', type=int, default=300, help='Number of principal components for PCA.')
|
68
69
|
parser.add_argument('--weighted_adj', action='store_true', help='Use weighted adjacency in GAT.')
|
69
|
-
parser.add_argument('--var', action='store_true', help='Enable variance calculations.')
|
70
70
|
parser.add_argument('--convergence_threshold', type=float, default=1e-4, help='Threshold for convergence.')
|
71
71
|
parser.add_argument('--hierarchically', action='store_true', help='Enable hierarchical latent representation finding.')
|
72
72
|
|
@@ -236,8 +236,8 @@ def add_run_all_mode_args(parser):
|
|
236
236
|
help='Path to the input spatial transcriptomics data (H5AD format).')
|
237
237
|
parser.add_argument('--annotation', type=str, required=True,
|
238
238
|
help='Name of the annotation in adata.obs to use.')
|
239
|
-
parser.add_argument('--data_layer', type=str, default='
|
240
|
-
help='Data layer
|
239
|
+
parser.add_argument('--data_layer', type=str, default='counts', required=True,
|
240
|
+
help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
|
241
241
|
|
242
242
|
# GWAS Data Parameters
|
243
243
|
parser.add_argument('--trait_name', type=str, help='Name of the trait for GWAS analysis (required if sumstats_file is provided).')
|
@@ -0,0 +1,145 @@
|
|
1
|
+
import logging
|
2
|
+
import random
|
3
|
+
import numpy as np
|
4
|
+
import scanpy as sc
|
5
|
+
import torch
|
6
|
+
from sklearn.decomposition import PCA
|
7
|
+
from sklearn.preprocessing import LabelEncoder
|
8
|
+
from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix
|
9
|
+
from gsMap.GNN_VAE.train import ModelTrainer
|
10
|
+
from gsMap.config import FindLatentRepresentationsConfig
|
11
|
+
|
12
|
+
# Logger named after this module.
logger = logging.getLogger(name=__name__)
|
13
|
+
|
14
|
+
|
15
|
+
def set_seed(seed_value):
    """
    Seed Python's ``random``, NumPy and PyTorch for reproducibility.

    When CUDA is available, all CUDA devices are seeded as well. Logs
    whether computations will run on the GPU or the CPU.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if not torch.cuda.is_available():
        logger.info('Using CPU for computations.')
        return
    logger.info('Using GPU for computations.')
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
|
28
|
+
|
29
|
+
def preprocess_data(adata, params):
    """
    Preprocess the AnnData and flag highly variable genes (HVGs).

    Steps:

    1. Make gene names unique.
    2. Drop genes detected in fewer than 30 cells.
    3. Load ``params.data_layer`` into ``adata.X``.
    4. Select ``params.feat_cell`` HVGs into ``adata.var['highly_variable']``.

    For count layers (``'count'``/``'counts'``), HVGs are selected with the
    ``seurat_v3`` flavor *before* normalisation, because that flavor models
    raw counts. The data are then normalised to 1e4 per cell and
    log1p-transformed. Other layers are used as-is with the ``seurat`` flavor.

    Raises
    ------
    ValueError
        If ``params.data_layer`` is not a key of ``adata.layers``.
    """
    logger.info('Preprocessing data...')
    adata.var_names_make_unique()

    sc.pp.filter_genes(adata, min_cells=30)
    if params.data_layer not in adata.layers.keys():
        raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
    # NOTE(review): this assignment may not copy the layer, in which case the in-place
    # normalisation below also modifies adata.layers[data_layer] — confirm.
    adata.X = adata.layers[params.data_layer]

    if params.data_layer in ['count', 'counts']:
        # seurat_v3 expects raw counts, so pick HVGs before normalising.
        sc.pp.highly_variable_genes(
            adata,
            flavor="seurat_v3",
            n_top_genes=params.feat_cell,
        )
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
    else:
        logger.info(f'Using {params.data_layer} data...')
        sc.pp.highly_variable_genes(
            adata,
            flavor="seurat",
            n_top_genes=params.feat_cell,
        )

    return adata
|
63
|
+
|
64
|
+
|
65
|
+
class LatentRepresentationFinder:
    """Hold the HVG expression matrix and spatial graph used to compute latent embeddings."""

    def __init__(self, adata, args: FindLatentRepresentationsConfig):
        """
        Parameters
        ----------
        adata : AnnData
            Preprocessed data with ``var['highly_variable']`` set.
        args : FindLatentRepresentationsConfig
            Run configuration; stored as ``self.params``.
        """
        self.params = args

        expression = adata[:, adata.var.highly_variable].X.copy()
        # PCA and torch.FloatTensor (in ModelTrainer) need a dense array; a sparse
        # non-count layer would otherwise reach them unconverted.
        if hasattr(expression, 'toarray'):
            expression = expression.toarray()
        if self.params.data_layer in ['count', 'counts']:
            expression = sc.pp.scale(expression, max_value=10)
        self.expression_array = expression

        # Construct the neighboring graph
        self.graph_dict = construct_adjacency_matrix(adata, self.params)

    def compute_pca(self):
        """Project the expression matrix onto its leading ``params.n_comps`` principal components.

        Previously ``params.feat_cell`` (the HVG count, which ``run_gnn_vae``
        later overwrites) was used, so PCA did not reduce dimensionality. The
        component count is capped at ``min(n_cells, n_genes)``, the maximum
        PCA accepts.
        """
        n_components = min(self.params.n_comps, *self.expression_array.shape)
        self.latent_pca = PCA(n_components=n_components).fit_transform(self.expression_array)
        return self.latent_pca

    def run_gnn_vae(self, label, verbose='whole ST data'):
        """Train the graph autoencoder and return its latent embedding.

        Node features are the PCA projection when ``params.input_pca`` is set,
        otherwise the expression matrix. ``params.n_nodes`` and
        ``params.feat_cell`` are updated to the node-feature shape. The
        instance's graph is deleted afterwards, so this can run only once per
        instance.
        """
        node_X = self.compute_pca() if self.params.input_pca else self.expression_array

        # Update the input shape
        self.params.n_nodes = node_X.shape[0]
        self.params.feat_cell = node_X.shape[1]

        logger.info(f'Finding latent representations for {verbose}...')
        gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
        gvae.run_train()

        # Drop this instance's reference to the graph.
        del self.graph_dict

        return gvae.get_latent()
|
101
|
+
|
102
|
+
|
103
|
+
def run_find_latent_representation(args: FindLatentRepresentationsConfig):
    """Compute GNN-VAE and PCA embeddings plus UMAPs for an ST dataset and save it.

    Reads ``args.input_hdf5_path``. When ``args.annotation`` is set, drops
    cells whose annotation is missing or whose category has fewer than 30
    cells, and uses the encoded categories as training labels.

    Writes ``args.hdf5_with_latent_path`` with these ``obsm`` entries:
    ``latent_GVAE``, ``latent_PCA``, ``X_umap_latent_GVAE`` and
    ``X_umap_latent_PCA``.
    """
    set_seed(2024)

    logger.info(f'Loading ST data of {args.sample_name}...')
    adata = sc.read_h5ad(args.input_hdf5_path)
    logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')

    label = None
    if args.annotation is not None:
        annot = args.annotation
        # Drop unannotated cells, then categories with fewer than 30 cells.
        adata = adata[~adata.obs[annot].isnull()]
        counts = adata.obs[annot].value_counts()
        keep = counts[counts >= 30].index.to_list()
        adata = adata[adata.obs[annot].isin(keep)]

        adata.obs['categorical_label'] = LabelEncoder().fit_transform(adata.obs[annot])
        label = adata.obs['categorical_label'].to_numpy()

    adata = preprocess_data(adata, args)

    finder = LatentRepresentationFinder(adata, args)
    latent_gvae = finder.run_gnn_vae(label)
    latent_pca = finder.compute_pca()

    logger.info('Adding latent representations...')
    adata.obsm["latent_GVAE"] = latent_gvae
    adata.obsm["latent_PCA"] = latent_pca

    # One UMAP per representation; 'X_umap' is overwritten each pass, so copy it out.
    for rep in ('latent_GVAE', 'latent_PCA'):
        sc.pp.neighbors(adata, n_neighbors=10, use_rep=rep)
        sc.tl.umap(adata)
        adata.obsm['X_umap_' + rep] = adata.obsm['X_umap']

    logger.info('Saving ST data...')
    adata.write(args.hdf5_with_latent_path)
|
@@ -150,10 +150,10 @@ def gwas_checkname(gwas, config):
|
|
150
150
|
'Pos': 'SNP positions.'
|
151
151
|
}
|
152
152
|
|
153
|
-
|
153
|
+
logger.info(f'\nIterpreting column names as follows:')
|
154
154
|
for key, value in interpreting.items():
|
155
155
|
if key in new_name:
|
156
|
-
|
156
|
+
logger.info(f'{name_dict[key]}: {interpreting[key]}')
|
157
157
|
|
158
158
|
return gwas
|
159
159
|
|
@@ -242,7 +242,7 @@ def gwas_qc(gwas, config):
|
|
242
242
|
Filter out SNPs based on INFO, FRQ, MAF, N, and Genotypes.
|
243
243
|
'''
|
244
244
|
old = len(gwas)
|
245
|
-
|
245
|
+
logger.info(f'\nFiltering SNPs as follows:')
|
246
246
|
# filter: SNPs with missing values
|
247
247
|
drops = {'NA': 0, 'P': 0, 'INFO': 0, 'FRQ': 0, 'A': 0, 'SNP': 0, 'Dup': 0, 'N': 0}
|
248
248
|
|
@@ -250,28 +250,28 @@ def gwas_qc(gwas, config):
|
|
250
250
|
lambda x: x != 'INFO', gwas.columns)).reset_index(drop=True)
|
251
251
|
|
252
252
|
drops['NA'] = old - len(gwas)
|
253
|
-
|
253
|
+
logger.info(f'Removed {drops["NA"]} SNPs with missing values.')
|
254
254
|
|
255
255
|
# filter: SNPs with Info < 0.9
|
256
256
|
if 'INFO' in gwas.columns:
|
257
257
|
old = len(gwas)
|
258
258
|
gwas = gwas.loc[filter_info(gwas['INFO'], config)]
|
259
259
|
drops['INFO'] = old - len(gwas)
|
260
|
-
|
260
|
+
logger.info(f'Removed {drops["INFO"]} SNPs with INFO <= 0.9.')
|
261
261
|
|
262
262
|
# filter: SNPs with MAF <= 0.01
|
263
263
|
if 'FRQ' in gwas.columns:
|
264
264
|
old = len(gwas)
|
265
265
|
gwas = gwas.loc[filter_frq(gwas['FRQ'], config)]
|
266
266
|
drops['FRQ'] += old - len(gwas)
|
267
|
-
|
267
|
+
logger.info(f'Removed {drops["FRQ"]} SNPs with MAF <= 0.01.')
|
268
268
|
|
269
269
|
# filter: P-value that out-of-bounds [0,1]
|
270
270
|
if 'P' in gwas.columns:
|
271
271
|
old = len(gwas)
|
272
272
|
gwas = gwas.loc[filter_pvals(gwas['P'], config)]
|
273
273
|
drops['P'] += old - len(gwas)
|
274
|
-
|
274
|
+
logger.info(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')
|
275
275
|
|
276
276
|
# filter: Variants that are strand-ambiguous
|
277
277
|
if 'A1' in gwas.columns and 'A2' in gwas.columns:
|
@@ -279,21 +279,21 @@ def gwas_qc(gwas, config):
|
|
279
279
|
gwas.A2 = gwas.A2.str.upper()
|
280
280
|
gwas = gwas.loc[filter_alleles(gwas.A1 + gwas.A2)]
|
281
281
|
drops['A'] += old - len(gwas)
|
282
|
-
|
282
|
+
logger.info(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')
|
283
283
|
|
284
284
|
# filter: Duplicated rs numbers
|
285
285
|
if 'SNP' in gwas.columns:
|
286
286
|
old = len(gwas)
|
287
287
|
gwas = gwas.drop_duplicates(subset='SNP').reset_index(drop=True)
|
288
288
|
drops['Dup'] += old - len(gwas)
|
289
|
-
|
289
|
+
logger.info(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')
|
290
290
|
|
291
291
|
# filter:Sample size
|
292
292
|
n_min = gwas.N.quantile(0.9) / 1.5
|
293
293
|
old = len(gwas)
|
294
294
|
gwas = gwas[gwas.N >= n_min].reset_index(drop=True)
|
295
295
|
drops['N'] += old - len(gwas)
|
296
|
-
|
296
|
+
logger.info(f'Removed {drops["N"]} SNPs with N < {n_min}.')
|
297
297
|
|
298
298
|
return gwas
|
299
299
|
|
@@ -302,7 +302,7 @@ def variant_to_rsid(gwas, config):
|
|
302
302
|
'''
|
303
303
|
Convert variant id (Chr, Pos) to rsid
|
304
304
|
'''
|
305
|
-
|
305
|
+
logger.info("\nConverting the SNP position to rsid. This process may take some time.")
|
306
306
|
unique_ids = set(gwas['id'])
|
307
307
|
chr_format = gwas['Chr'].unique().astype(str)
|
308
308
|
chr_format = [re.sub(r'\d+', '', value) for value in chr_format][1]
|
@@ -347,7 +347,7 @@ def clean_SNP_id(gwas, config):
|
|
347
347
|
gwas = gwas.loc[matching_id.id]
|
348
348
|
gwas['SNP'] = matching_id.dbsnp
|
349
349
|
num_fail = old - len(gwas)
|
350
|
-
|
350
|
+
logger.info(f'Removed {num_fail} SNPs that did not convert to rsid.')
|
351
351
|
|
352
352
|
return gwas
|
353
353
|
|
@@ -356,27 +356,27 @@ def gwas_metadata(gwas, config):
|
|
356
356
|
'''
|
357
357
|
Report key features of GWAS data
|
358
358
|
'''
|
359
|
-
|
359
|
+
logger.info('\nSummary of GWAS data:')
|
360
360
|
CHISQ = (gwas.Z ** 2)
|
361
361
|
mean_chisq = CHISQ.mean()
|
362
|
-
|
362
|
+
logger.info('Mean chi^2 = ' + str(round(mean_chisq, 3)))
|
363
363
|
if mean_chisq < 1.02:
|
364
364
|
logger.warning("Mean chi^2 may be too small.")
|
365
365
|
|
366
|
-
|
367
|
-
|
368
|
-
|
366
|
+
logger.info('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
|
367
|
+
logger.info('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
|
368
|
+
logger.info('{N} Genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))
|
369
369
|
|
370
370
|
|
371
371
|
def gwas_format(config: FormatSumstatsConfig):
|
372
372
|
'''
|
373
373
|
Format GWAS data
|
374
374
|
'''
|
375
|
-
|
375
|
+
logger.info(f'------Formating gwas data for {config.sumstats}...')
|
376
376
|
compression_type = get_compression(config.sumstats)
|
377
377
|
gwas = pd.read_csv(config.sumstats, delim_whitespace=True, header=0, compression=compression_type,
|
378
378
|
na_values=['.', 'NA'])
|
379
|
-
|
379
|
+
logger.info(f'Read {len(gwas)} SNPs from {config.sumstats}.')
|
380
380
|
|
381
381
|
# Check name and format
|
382
382
|
gwas = gwas_checkname(gwas, config)
|
@@ -402,6 +402,6 @@ def gwas_format(config: FormatSumstatsConfig):
|
|
402
402
|
gwas = gwas[keep]
|
403
403
|
out_name = config.out + appendix + '.gz'
|
404
404
|
|
405
|
-
|
405
|
+
logger.info(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
|
406
406
|
gwas.to_csv(out_name, sep="\t", index=False,
|
407
407
|
float_format='%.3f', compression='gzip')
|
@@ -4,12 +4,10 @@ from pathlib import Path
|
|
4
4
|
import numpy as np
|
5
5
|
import pandas as pd
|
6
6
|
import scanpy as sc
|
7
|
-
from scipy.sparse import csr_matrix
|
8
7
|
from scipy.stats import gmean
|
9
8
|
from scipy.stats import rankdata
|
10
9
|
from sklearn.metrics.pairwise import cosine_similarity
|
11
10
|
from sklearn.neighbors import NearestNeighbors
|
12
|
-
from joblib import Parallel, delayed
|
13
11
|
from tqdm import tqdm
|
14
12
|
|
15
13
|
from gsMap.config import LatentToGeneConfig
|
@@ -152,11 +150,6 @@ def run_latent_to_gene(config: LatentToGeneConfig):
|
|
152
150
|
adata.var_names = homologs.loc[adata.var_names, 'HUMAN_GENE_SYM'].values
|
153
151
|
adata = adata[:, ~adata.var_names.duplicated()]
|
154
152
|
|
155
|
-
# Remove cells and genes that are not expressed
|
156
|
-
logger.info(f'Number of cells, genes of the input data: {adata.shape[0]},{adata.shape[1]}')
|
157
|
-
adata = adata[adata.X.sum(axis=1) > 0, adata.X.sum(axis=0) > 0]
|
158
|
-
logger.info(f'Number of cells, genes after transformation: {adata.shape[0]},{adata.shape[1]}')
|
159
|
-
|
160
153
|
# Create mappings
|
161
154
|
n_cells = adata.n_obs
|
162
155
|
n_genes = adata.n_vars
|
@@ -20,8 +20,6 @@ logger = logging.getLogger('gsMap.spatial_ldsc')
|
|
20
20
|
|
21
21
|
# %%
|
22
22
|
def _coef_new(jknife):
    """Return the jackknife coefficient estimate and its standard error, both divided by ``Nbar``.

    ``Nbar`` is not defined in this function; presumably it is a module-level
    global — confirm against the rest of the module.
    """
    scale = Nbar
    return jknife.jknife_est[0, 0] / scale, jknife.jknife_se[0, 0] / scale
|
@@ -1,95 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
"""
|
4
|
-
Created on Tue Jul 4 21:31:27 2023
|
5
|
-
|
6
|
-
@author: songliyang
|
7
|
-
"""
|
8
|
-
import numpy as np
|
9
|
-
import pandas as pd
|
10
|
-
import scipy.sparse as sp
|
11
|
-
import sklearn.neighbors
|
12
|
-
import torch
|
13
|
-
|
14
|
-
|
15
|
-
def Cal_Spatial_Net(adata, n_neighbors=5, verbose=True):
    """\
    Construct the spatial neighbor networks.

    Returns a DataFrame with columns ``Cell1``, ``Cell2`` (cell names from
    ``adata.obs.index``) and ``Distance``. It holds one row for each of the
    ``n_neighbors`` nearest spatial neighbours of every cell, excluding
    zero-distance pairs (the query cell itself).

    The edge list is built with vectorised NumPy operations. Previously one
    small DataFrame was created per cell and then concatenated, which was
    slow for large datasets. The row index now comes from the flattened
    (cell, neighbour) position rather than repeated per-cell 0..k-1 labels.
    """
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index

    nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
    distances, indices = nbrs.kneighbors(coor, return_distance=True)

    n_cells, k = indices.shape
    Spatial_Net = pd.DataFrame({
        'Cell1': np.repeat(np.arange(n_cells), k),
        'Cell2': indices.ravel(),
        'Distance': distances.ravel(),
    })
    # Copy the filtered slice so the column assignments below act on a real frame.
    Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0].copy()

    id_cell_trans = dict(zip(range(n_cells), np.array(coor.index)))
    Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
    Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
    return Spatial_Net
|
42
|
-
|
43
|
-
|
44
|
-
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor.

    Parameters
    ----------
    sparse_mx :
        Any SciPy sparse matrix. It is converted to COO format with float32 values.

    Returns
    -------
    torch.Tensor
        A sparse COO tensor with the same shape and ``dtype=torch.float32``. It is not
        coalesced; callers that need that must call ``.coalesce()``.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # torch.sparse_coo_tensor is the supported replacement.
    return torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
|
51
|
-
|
52
|
-
|
53
|
-
def preprocess_graph(adj):
    """Normalise an adjacency matrix (with self-loops added) and return it as a torch tensor.

    With ``A' = adj + I`` and ``D`` the diagonal of ``A'``'s row sums, the result is
    ``D^-1/2 A'^T D^-1/2``. For a symmetric ``adj`` this equals ``D^-1/2 A' D^-1/2``.

    Parameters
    ----------
    adj :
        Square adjacency matrix accepted by ``scipy.sparse.coo_matrix``.

    Returns
    -------
    torch.Tensor
        Sparse tensor produced by ``sparse_mx_to_torch_sparse_tensor``.
    """
    coo = sp.coo_matrix(adj)
    with_loops = coo + sp.eye(coo.shape[0])
    degrees = np.array(with_loops.sum(1))
    inv_sqrt_deg = sp.diags(np.power(degrees, -0.5).flatten())
    scaled = with_loops.dot(inv_sqrt_deg)
    normalised = scaled.transpose().dot(inv_sqrt_deg).tocoo()
    return sparse_mx_to_torch_sparse_tensor(normalised)
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
def Construct_Adjacency_Matrix(adata, Params, verbose=True):
    """Build the spatial kNN graph of ``adata`` and its normalised adjacency.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obsm['spatial']``. ``obs.index`` is used as cell names.
    Params :
        Config object; reads ``n_neighbors`` and ``weighted_adj``.
    verbose : bool
        Print graph statistics. It is also forwarded to ``Cal_Spatial_Net``, so
        ``verbose=False`` silences that function's progress message too.

    Returns
    -------
    dict
        ``"adj_org"``: SciPy COO adjacency of shape ``(n_obs, n_obs)``. Edge weights are
        1, or ``exp(-(d / (max_d + 1))**2 / 2)`` when ``Params.weighted_adj`` is set.
        ``"adj_norm"``: output of ``preprocess_graph(adj_org)``.
        ``"norm_value"``: ``n^2 / ((n^2 - sum(adj_org)) * 2)``.
    """
    # Construct the neighbor graph
    Spatial_Net = Cal_Spatial_Net(adata, n_neighbors=Params.n_neighbors, verbose=verbose)
    n_cells = adata.n_obs
    if verbose:
        print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], n_cells))
        print('%.2f neighbors per cell on average.' % (Spatial_Net.shape[0] / n_cells))

    # Translate cell names back to integer row/column positions.
    cells = np.array(adata.obs.index)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
    G_df = Spatial_Net.copy()
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)

    if Params.weighted_adj:
        # Gaussian-style kernel on distances rescaled into [0, 1).
        distance_normalized = G_df.Distance / (max(G_df.Distance) + 1)
        weights = np.exp(-distance_normalized ** 2 / 2)
    else:
        weights = np.ones(G_df.shape[0])
    adj_org = sp.coo_matrix((weights, (G_df['Cell1'], G_df['Cell2'])), shape=(n_cells, n_cells))

    adj_norm = preprocess_graph(adj_org)
    norm_value = n_cells * n_cells / float((n_cells * n_cells - adj_org.sum()) * 2)

    return {
        "adj_org": adj_org,
        "adj_norm": adj_norm,
        "norm_value": norm_value,
    }
|
@@ -1,87 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
"""
|
4
|
-
Created on Mon Jul 3 11:42:44 2023
|
5
|
-
|
6
|
-
@author: songliyang
|
7
|
-
"""
|
8
|
-
|
9
|
-
import torch
|
10
|
-
import torch.nn as nn
|
11
|
-
import torch.nn.functional as F
|
12
|
-
from torch_geometric.nn import GATConv
|
13
|
-
|
14
|
-
|
15
|
-
def full_block(in_features, out_features, p_drop):
    """Return a ``Linear -> BatchNorm1d -> ELU -> Dropout`` stack as one ``nn.Sequential``.

    Parameters
    ----------
    in_features, out_features : int
        Input and output width of the linear layer; the batch norm acts on ``out_features``.
    p_drop : float
        Dropout probability.
    """
    stages = [
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.ELU(),
        nn.Dropout(p=p_drop),
    ]
    return nn.Sequential(*stages)
|
20
|
-
|
21
|
-
|
22
|
-
class GNN(nn.Module):
    """One ``GATConv`` layer followed by an activation and dropout.

    Parameters
    ----------
    in_features, out_features : int
        Passed to ``GATConv`` as input and output channels.
    dr : float
        Dropout probability, applied only while ``self.training`` is true.
    act : callable
        Activation applied to the convolution output.
    heads : int
        Number of attention heads passed to ``GATConv``.
    """

    def __init__(self, in_features, out_features, dr=0, act=F.relu, heads=1):
        super().__init__()
        self.conv1 = GATConv(in_features, out_features, heads)
        self.act = act
        self.dr = dr

    def forward(self, x, edge_index):
        """Apply convolution, then activation, then dropout."""
        activated = self.act(self.conv1(x, edge_index))
        return F.dropout(activated, self.dr, self.training)
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
class GNN_VAE_Model(nn.Module):
    """MLP encoder -> GAT layers -> latent code, with a decoder and a cluster head.

    Parameters
    ----------
    input_dim : int
        Width of the input feature matrix.
    params :
        Config object; reads ``var``, ``feat_hidden1``, ``feat_hidden2``, ``gat_hidden1``,
        ``gat_hidden2``, ``nheads`` and ``p_drop``.
    num_classes : int
        Output width of the cluster head.
    """

    def __init__(self, input_dim, params, num_classes=1):
        super().__init__()
        self.var = params.var
        self.num_classes = num_classes

        # Encoder
        self.encoder = nn.Sequential()
        self.encoder.add_module('encoder_L1', full_block(input_dim, params.feat_hidden1, params.p_drop))
        self.encoder.add_module('encoder_L2', full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop))

        # GNN (GAT). gn1 runs with params.nheads heads; gn2 (mean) and gn3 (log-variance)
        # are sized for an input of gat_hidden1 * nheads features.
        self.gn1 = GNN(params.feat_hidden2, params.gat_hidden1, params.p_drop, act=F.relu, heads=params.nheads)
        self.gn2 = GNN(params.gat_hidden1 * params.nheads, params.gat_hidden2, params.p_drop, act=lambda x: x)
        self.gn3 = GNN(params.gat_hidden1 * params.nheads, params.gat_hidden2, params.p_drop, act=lambda x: x)

        # Decoder
        self.decoder = nn.Sequential()
        self.decoder.add_module('decoder_L1', full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop))
        self.decoder.add_module('decoder_L2', full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop))
        self.decoder.add_module('decoder_output', nn.Sequential(nn.Linear(params.feat_hidden1, input_dim)))

        # Cluster
        self.cluster = nn.Sequential()
        self.cluster.add_module('cluster_L1', full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop))
        self.cluster.add_module('cluster_output', nn.Linear(params.feat_hidden2, self.num_classes))

    def encode(self, x, adj):
        """Return ``(mu, logvar)``; ``logvar`` is ``None`` unless ``self.var`` is set."""
        feat_x = self.encoder(x)
        hidden1 = self.gn1(feat_x, adj)
        mu = self.gn2(hidden1, adj)
        if self.var:
            logvar = self.gn3(hidden1, adj)
            return mu, logvar
        return mu, None

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` during training when ``logvar`` is given; otherwise return ``mu``."""
        if self.training and logvar is not None:
            # logvar is a log-variance, so the standard deviation is exp(logvar / 2).
            # The previous exp(logvar) produced the variance instead.
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        return mu

    def forward(self, x, adj):
        """Return ``(pred_label, x_reconstructed, gnn_z, mu, logvar)``.

        ``pred_label`` is the softmax (over dim 1) of the cluster head applied to ``gnn_z``.
        """
        mu, logvar = self.encode(x, adj)
        gnn_z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decoder(gnn_z)
        pred_label = F.softmax(self.cluster(gnn_z), dim=1)
        return pred_label, x_reconstructed, gnn_z, mu, logvar
|
@@ -1,97 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
# -*- coding: utf-8 -*-
|
3
|
-
"""
|
4
|
-
Created on Tue Jul 4 19:58:58 2023
|
5
|
-
|
6
|
-
@author: songliyang
|
7
|
-
"""
|
8
|
-
import time
|
9
|
-
|
10
|
-
import torch
|
11
|
-
from progress.bar import Bar
|
12
|
-
|
13
|
-
from gsMap.GNN_VAE.model import GNN_VAE_Model
|
14
|
-
|
15
|
-
|
16
|
-
def reconstruction_loss(decoded, x):
    """Mean squared error between the decoder output and the original features.

    Parameters
    ----------
    decoded : torch.Tensor
        Reconstruction produced by the model.
    x : torch.Tensor
        Target features.

    Returns
    -------
    torch.Tensor
        Scalar loss averaged over all elements.
    """
    return torch.nn.functional.mse_loss(decoded, x)
|
20
|
-
|
21
|
-
|
22
|
-
def label_loss(pred_label, true_label):
    """Mean cross-entropy between per-class scores and integer class labels.

    ``pred_label`` is treated as unnormalised scores: cross-entropy applies log-softmax
    to it internally.

    NOTE(review): ``GNN_VAE_Model.forward`` already applies ``F.softmax`` before the
    trainer passes ``pred_label`` here, so probabilities are softmaxed a second time.
    Confirm whether that is intended.
    """
    return torch.nn.functional.cross_entropy(pred_label, true_label)
|
26
|
-
|
27
|
-
|
28
|
-
class Model_Train:
    """Train a ``GNN_VAE_Model`` on one feature matrix and spatial graph.

    Parameters
    ----------
    node_X : array-like
        Feature matrix. It is copied and converted with ``torch.FloatTensor``, so it
        must be dense.
    graph_dict : dict
        Must contain ``"adj_norm"``, a torch sparse tensor (coalesced here).
    params :
        Config object. This class reads ``epochs``, ``feat_cell``, ``gat_lr``,
        ``gcn_decay``, ``rec_w``, ``label_w`` and ``convergence_threshold``; the
        model reads further fields.
    label : sequence of int, optional
        Per-cell class labels. When given, a weighted label loss is added to the
        reconstruction loss.
    """

    def __init__(self, node_X, graph_dict, params, label=None):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        torch.cuda.empty_cache()

        self.params = params
        self.device = device
        self.epochs = params.epochs
        self.node_X = torch.FloatTensor(node_X.copy()).to(device)
        self.adj_norm = graph_dict["adj_norm"].to(device).coalesce()
        self.label = label
        self.num_classes = 1

        if self.label is not None:
            self.label = torch.tensor(self.label).to(self.device)
            self.num_classes = len(self.label.unique())

        # Set Model
        self.model = GNN_VAE_Model(self.params.feat_cell, self.params, self.num_classes).to(device)
        self.optimizer = torch.optim.Adam(params=list(self.model.parameters()),
                                          lr=self.params.gat_lr, weight_decay=self.params.gcn_decay)

    # Train
    def run_train(self):
        """Full-batch training for up to ``self.epochs`` epochs.

        Stops early once the epoch-to-epoch loss change is at most
        ``params.convergence_threshold``, but only from epoch 200 on.
        """
        self.model.train()
        prev_loss = float('inf')

        bar = Bar('GAT-AE model train:', max=self.epochs)
        bar.check_tty = False
        for epoch in range(self.epochs):
            start_time = time.time()
            self.model.train()
            self.optimizer.zero_grad()
            pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
            loss_rec = reconstruction_loss(de_feat, self.node_X)

            # Check whether annotation was provided
            if self.label is not None:
                loss_pre = label_loss(pred_label, self.label)
                loss = (self.params.rec_w * loss_rec) + (self.params.label_w * loss_pre)
            else:
                loss = loss_rec

            loss.backward()
            self.optimizer.step()

            # Update process. .item() forces a device sync, so read the loss once.
            batch_time = time.time() - start_time
            loss_value = loss.item()
            # Epochs still to run after this one (the old estimate counted one extra).
            remaining_epochs = self.epochs - epoch - 1
            bar_str = '{} / {} | Left time: {batch_time:.2f} mins| Loss: {loss:.4f}'
            bar.suffix = bar_str.format(epoch + 1, self.epochs,
                                        batch_time=batch_time * remaining_epochs / 60, loss=loss_value)
            bar.next()

            # Check convergence
            if abs(loss_value - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
                print('\nConvergence reached. Training stopped.')
                break

            prev_loss = loss_value

        bar.finish()

    def get_latent(self):
        """Return the latent embedding of every node as a NumPy array (eval mode, no autograd)."""
        self.model.eval()
        # Inference only: do not build an autograd graph over the whole dataset.
        with torch.no_grad():
            _, _, latent_z, _, _ = self.model(self.node_X, self.adj_norm)
        return latent_z.detach().cpu().numpy()
|
@@ -1,145 +0,0 @@
|
|
1
|
-
import logging
|
2
|
-
import random
|
3
|
-
|
4
|
-
import numpy as np
|
5
|
-
import pandas as pd
|
6
|
-
import scanpy as sc
|
7
|
-
import torch
|
8
|
-
from sklearn import preprocessing
|
9
|
-
|
10
|
-
from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix
|
11
|
-
from gsMap.GNN_VAE.train import Model_Train
|
12
|
-
from gsMap.config import FindLatentRepresentationsConfig
|
13
|
-
|
14
|
-
logger = logging.getLogger(__name__)
|
15
|
-
|
16
|
-
def set_seed(seed_value):
    """
    Seed Python's ``random``, NumPy and PyTorch (plus CUDA when available) for reproducibility.

    Also logs whether a GPU or the CPU will be used.
    """
    random.seed(seed_value)  # Python random module
    np.random.seed(seed_value)  # NumPy global RNG
    torch.manual_seed(seed_value)  # PyTorch
    if not torch.cuda.is_available():
        logger.info('Running use CPU')
        return
    logger.info('Running use GPU')
    torch.cuda.manual_seed(seed_value)  # current CUDA device
    torch.cuda.manual_seed_all(seed_value)  # every CUDA device
|
29
|
-
|
30
|
-
|
31
|
-
# The class for finding latent representations
|
32
|
-
class Latent_Representation_Finder:
    """Preprocess one AnnData object and compute GNN-VAE and PCA latent representations.

    The constructor works on a copy of ``adata``. It optionally replaces ``X`` with the
    layer named by ``args.data_layer`` and flags ``args.feat_cell`` highly variable genes.

    NOTE: ``Run_GNN_VAE`` writes ``n_nodes`` and ``feat_cell`` back onto ``args``, which is
    the caller's config object.
    """

    def __init__(self, adata, args: FindLatentRepresentationsConfig):
        self.adata = adata.copy()
        self.Params = args

        # Standard process
        if self.Params.data_layer in ('count', 'counts'):
            # Count layer: HVG selection (seurat_v3 flavor), then normalise to 1e4
            # per cell, log1p-transform and scale.
            self.adata.X = self.adata.layers[self.Params.data_layer]
            sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3", n_top_genes=self.Params.feat_cell)
            sc.pp.normalize_total(self.adata, target_sum=1e4)
            sc.pp.log1p(self.adata)
            sc.pp.scale(self.adata)
        else:
            # Any other layer (or 'X' itself) is used as-is; only HVGs are flagged.
            if self.Params.data_layer != 'X':
                self.adata.X = self.adata.layers[self.Params.data_layer]
            sc.pp.highly_variable_genes(self.adata, n_top_genes=self.Params.feat_cell)

    def Run_GNN_VAE(self, label, verbose='whole ST data'):
        """Train the GNN-VAE on the HVG matrix (optionally PCA-reduced) and return its latent embedding.

        ``label`` is an optional per-cell integer label list. ``verbose`` only names the
        data in the log message.
        """
        # Construct the neighbouring graph
        graph_dict = Construct_Adjacency_Matrix(self.adata, self.Params)

        # Process the feature matrix
        node_X = self.adata[:, self.adata.var.highly_variable].X
        logger.info(f'The shape of feature matrix is {node_X.shape}.')
        if self.Params.input_pca:
            node_X = sc.pp.pca(node_X, n_comps=self.Params.n_comps)

        # Model_Train converts node_X with torch.FloatTensor, which cannot take a SciPy
        # sparse matrix. An unscaled layer can still be sparse at this point, so
        # densify it (after PCA, so PCA keeps its original sparse/dense code path).
        if hasattr(node_X, 'toarray'):
            node_X = node_X.toarray()

        # Update the input shape
        self.Params.n_nodes = node_X.shape[0]
        self.Params.feat_cell = node_X.shape[1]

        # Run GNN-VAE
        logger.info(f'------Finding latent representations for {verbose}...')
        gvae = Model_Train(node_X, graph_dict, self.Params, label)
        gvae.run_train()

        return gvae.get_latent()

    def Run_PCA(self):
        """Run ``sc.tl.pca`` on the processed data and return the first ``n_comps`` columns."""
        sc.tl.pca(self.adata)
        return self.adata.obsm['X_pca'][:, 0:self.Params.n_comps]
|
75
|
-
|
76
|
-
|
77
|
-
def _find_latent_per_cell_type(adata, args, num_features):
    """Compute PCA and GNN-VAE representations separately within each annotated cell type.

    Returns ``(GVAE_all, PCA_all)``, two DataFrames indexed by cell name.
    """
    gvae_parts = []
    pca_parts = []
    for ct in adata.obs[args.annotation].unique():
        adata_part = adata[adata.obs[args.annotation] == ct, :]
        logger.info(adata_part.shape)

        # Run_GNN_VAE overwrites args.feat_cell with the width it trained on (e.g.
        # n_comps when input_pca is set). Restore the configured value so every finder
        # selects the intended number of highly variable genes.
        args.feat_cell = num_features

        # Find latent representations for the selected ct
        latent_rep = Latent_Representation_Finder(adata_part, args)

        latent_PCA_part = pd.DataFrame(latent_rep.Run_PCA())
        if adata_part.shape[0] <= args.n_comps:
            latent_GVAE_part = latent_PCA_part
        else:
            latent_GVAE_part = pd.DataFrame(latent_rep.Run_GNN_VAE(label=None, verbose=ct))

        latent_GVAE_part.index = adata_part.obs_names
        latent_PCA_part.index = adata_part.obs_names
        gvae_parts.append(latent_GVAE_part)
        pca_parts.append(latent_PCA_part)

    args.feat_cell = num_features
    if not gvae_parts:
        return pd.DataFrame(), pd.DataFrame()
    return pd.concat(gvae_parts, axis=0), pd.concat(pca_parts, axis=0)


def run_find_latent_representation(args: FindLatentRepresentationsConfig):
    """Load ST data, add GNN-VAE/PCA latent representations and UMAPs, and write the result.

    Reads ``args.input_hdf5_path`` and writes to ``args.hdf5_with_latent_path``. Adds
    ``obsm['latent_GVAE']``, ``obsm['latent_PCA']`` and ``obsm['X_umap_<name>']``. When
    ``args.annotation`` and ``args.hierarchically`` are both set, it also adds the
    per-cell-type ``*_hierarchy`` representations.
    """
    set_seed(2024)
    # Run_GNN_VAE mutates args.feat_cell; keep the configured value so it can be restored.
    num_features = args.feat_cell
    args.hdf5_with_latent_path.parent.mkdir(parents=True, exist_ok=True, mode=0o755)

    # Load the ST data
    logger.info(f'------Loading ST data of {args.sample_name}...')
    adata = sc.read_h5ad(f'{args.input_hdf5_path}')
    adata.var_names_make_unique()
    adata.X = adata.layers[args.data_layer] if args.data_layer in adata.layers.keys() else adata.X
    logger.info('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))

    # Load the cell type annotation
    if args.annotation is not None:
        # remove cells without enough annotations (unannotated, or type with < 30 cells)
        adata = adata[~pd.isnull(adata.obs[args.annotation]), :]
        num = adata.obs[args.annotation].value_counts()
        # Materialise the subset so the obs assignment below acts on a real object, not a view.
        adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list())].copy()

        le = preprocessing.LabelEncoder()
        le.fit(adata.obs[args.annotation])
        adata.obs['categorical_label'] = le.transform(adata.obs[args.annotation])
        label = adata.obs['categorical_label'].to_list()
    else:
        label = None

    # Find latent representations
    latent_rep = Latent_Representation_Finder(adata, args)
    latent_GVAE = latent_rep.Run_GNN_VAE(label)
    latent_PCA = latent_rep.Run_PCA()
    args.feat_cell = num_features

    # Add latent representations to the spe data
    logger.info(f'------Adding latent representations...')
    adata.obsm["latent_GVAE"] = latent_GVAE
    adata.obsm["latent_PCA"] = latent_PCA

    # Run umap based on latent representations
    for name in ['latent_GVAE', 'latent_PCA']:
        sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
        sc.tl.umap(adata)
        adata.obsm['X_umap_' + name] = adata.obsm['X_umap']

    # Find the latent representations hierarchically (optionally)
    if args.annotation is not None and args.hierarchically:
        logger.info(f'------Finding latent representations hierarchically...')
        GVAE_all, PCA_all = _find_latent_per_cell_type(adata, args, num_features)
        adata.obsm["latent_GVAE_hierarchy"] = np.array(GVAE_all.loc[adata.obs_names,])
        adata.obsm["latent_PCA_hierarchy"] = np.array(PCA_all.loc[adata.obs_names,])

    logger.info(f'------Saving ST data...')
    adata.write(args.hdf5_with_latent_path)
|
145
|
-
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|