gsMap 1.65__py3-none-any.whl → 1.67__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -1,95 +1,75 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Created on Tue Jul 4 21:31:27 2023
5
-
6
- @author: songliyang
7
- """
8
1
  import numpy as np
9
2
  import pandas as pd
10
3
  import scipy.sparse as sp
11
- import sklearn.neighbors
4
+ from sklearn.neighbors import NearestNeighbors
12
5
  import torch
13
6
 
14
-
15
- def Cal_Spatial_Net(adata, n_neighbors=5, verbose=True):
16
- """\
17
- Construct the spatial neighbor networks.
18
- """
19
- #-
7
+ def cal_spatial_net(adata, n_neighbors=5, verbose=True):
8
+ """Construct the spatial neighbor network."""
20
9
  if verbose:
21
10
  print('------Calculating spatial graph...')
22
- coor = pd.DataFrame(adata.obsm['spatial'])
23
- coor.index = adata.obs.index
24
- #-
25
- nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
26
- #-
27
- distances, indices = nbrs.kneighbors(coor, return_distance=True)
28
- KNN_list = []
29
- for it in range(indices.shape[0]):
30
- KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it])))
31
- #-
32
- KNN_df = pd.concat(KNN_list)
33
- KNN_df.columns = ['Cell1', 'Cell2', 'Distance']
34
- #-
35
- Spatial_Net = KNN_df.copy()
36
- Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance']>0,]
37
- id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))
38
- Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
39
- Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
40
- #-
41
- return Spatial_Net
42
-
11
+ coor = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
12
+ nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
13
+ distances, indices = nbrs.kneighbors(coor)
14
+ n_cells, n_neighbors = indices.shape
15
+ cell_indices = np.arange(n_cells)
16
+ cell1 = np.repeat(cell_indices, n_neighbors)
17
+ cell2 = indices.flatten()
18
+ distance = distances.flatten()
19
+ knn_df = pd.DataFrame({'Cell1': cell1, 'Cell2': cell2, 'Distance': distance})
20
+ knn_df = knn_df[knn_df['Distance'] > 0].copy()
21
+ cell_id_map = dict(zip(cell_indices, coor.index))
22
+ knn_df['Cell1'] = knn_df['Cell1'].map(cell_id_map)
23
+ knn_df['Cell2'] = knn_df['Cell2'].map(cell_id_map)
24
+ return knn_df
43
25
 
44
26
  def sparse_mx_to_torch_sparse_tensor(sparse_mx):
45
27
  """Convert a scipy sparse matrix to a torch sparse tensor."""
46
28
  sparse_mx = sparse_mx.tocoo().astype(np.float32)
47
- indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
29
+ indices = torch.from_numpy(
30
+ np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
31
+ )
48
32
  values = torch.from_numpy(sparse_mx.data)
49
33
  shape = torch.Size(sparse_mx.shape)
50
34
  return torch.sparse.FloatTensor(indices, values, shape)
51
35
 
52
-
53
36
  def preprocess_graph(adj):
37
+ """Symmetrically normalize the adjacency matrix."""
54
38
  adj = sp.coo_matrix(adj)
55
39
  adj_ = adj + sp.eye(adj.shape[0])
56
- rowsum = np.array(adj_.sum(1))
57
- degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
40
+ rowsum = np.array(adj_.sum(1)).flatten()
41
+ degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5))
58
42
  adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
59
43
  return sparse_mx_to_torch_sparse_tensor(adj_normalized)
60
44
 
61
-
62
-
63
- def Construct_Adjacency_Matrix(adata,Params, verbose=True):
64
- # Construct the neighbor graph
65
- Spatial_Net = Cal_Spatial_Net(adata, n_neighbors=Params.n_neighbors)
66
- #-
45
+ def construct_adjacency_matrix(adata, params, verbose=True):
46
+ """Construct the adjacency matrix from spatial data."""
47
+ spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
67
48
  if verbose:
68
- print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))
69
- print('%.2f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))
70
- #-
71
- cells = np.array(adata.obs.index)
72
- cells_id_tran = dict(zip(cells, range(cells.shape[0])))
73
- #-
74
- G_df = Spatial_Net.copy()
75
- G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
76
- G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)
77
- #-
78
- if Params.weighted_adj:
79
- distance_normalized = G_df.Distance/(max(G_df.Distance)+1)
80
- adj_org = sp.coo_matrix((np.exp(-distance_normalized**2/(2)), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
49
+ num_edges = spatial_net.shape[0]
50
+ num_cells = adata.n_obs
51
+ print(f'The graph contains {num_edges} edges, {num_cells} cells.')
52
+ print(f'{num_edges / num_cells:.2f} neighbors per cell on average.')
53
+ cell_ids = {cell: idx for idx, cell in enumerate(adata.obs.index)}
54
+ spatial_net['Cell1'] = spatial_net['Cell1'].map(cell_ids)
55
+ spatial_net['Cell2'] = spatial_net['Cell2'].map(cell_ids)
56
+ if params.weighted_adj:
57
+ distance_normalized = spatial_net['Distance'] / (spatial_net['Distance'].max() + 1)
58
+ weights = np.exp(-0.5 * distance_normalized ** 2)
59
+ adj_org = sp.coo_matrix(
60
+ (weights, (spatial_net['Cell1'], spatial_net['Cell2'])),
61
+ shape=(adata.n_obs, adata.n_obs)
62
+ )
81
63
  else:
82
- adj_org = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
83
- #-
84
- adj_m1 = adj_org
85
- adj_norm_m1 = preprocess_graph(adj_m1)
86
- adj_label_m1 = adj_m1 + sp.eye(adj_m1.shape[0])
87
- norm_m1 = adj_m1.shape[0] * adj_m1.shape[0] / float((adj_m1.shape[0] * adj_m1.shape[0] - adj_m1.sum()) * 2)
88
- #-
64
+ adj_org = sp.coo_matrix(
65
+ (np.ones(spatial_net.shape[0]), (spatial_net['Cell1'], spatial_net['Cell2'])),
66
+ shape=(adata.n_obs, adata.n_obs)
67
+ )
68
+ adj_norm = preprocess_graph(adj_org)
69
+ norm_value = adj_org.shape[0] ** 2 / ((adj_org.shape[0] ** 2 - adj_org.sum()) * 2)
89
70
  graph_dict = {
90
71
  "adj_org": adj_org,
91
- "adj_norm": adj_norm_m1,
92
- "norm_value": norm_m1
72
+ "adj_norm": adj_norm,
73
+ "norm_value": norm_value
93
74
  }
94
- #-
95
75
  return graph_dict
gsMap/GNN_VAE/model.py CHANGED
@@ -1,87 +1,89 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Created on Mon Jul 3 11:42:44 2023
5
-
6
- @author: songliyang
7
- """
8
-
9
1
  import torch
10
2
  import torch.nn as nn
11
3
  import torch.nn.functional as F
12
4
  from torch_geometric.nn import GATConv
13
5
 
14
-
15
6
  def full_block(in_features, out_features, p_drop):
16
- return nn.Sequential(nn.Linear(in_features, out_features),
17
- nn.BatchNorm1d(out_features),
18
- nn.ELU(),
19
- nn.Dropout(p=p_drop))
7
+ return nn.Sequential(
8
+ nn.Linear(in_features, out_features),
9
+ nn.BatchNorm1d(out_features),
10
+ nn.ELU(),
11
+ nn.Dropout(p=p_drop)
12
+ )
20
13
 
21
-
22
- class GNN(nn.Module):
23
- def __init__(self, in_features, out_features, dr=0, act=F.relu,heads=1):
14
+ class GATModel(nn.Module):
15
+ def __init__(self, input_dim, params, num_classes=1):
24
16
  super().__init__()
25
- self.conv1 = GATConv(in_features, out_features,heads)
26
- self.act = act
27
- self.dr = dr
28
- #-
29
- def forward(self, x, edge_index):
30
- out = self.conv1(x, edge_index)
31
- out = self.act(out)
32
- out = F.dropout(out, self.dr, self.training)
33
- return out
34
-
35
-
36
-
37
- class GNN_VAE_Model(nn.Module):
38
- def __init__(self, input_dim,params,num_classes=1):
39
- super(GNN_VAE_Model, self).__init__()
40
17
  self.var = params.var
41
18
  self.num_classes = num_classes
42
-
19
+ self.params = params
20
+
43
21
  # Encoder
44
- self.encoder = nn.Sequential()
45
- self.encoder.add_module('encoder_L1', full_block(input_dim, params.feat_hidden1, params.p_drop))
46
- self.encoder.add_module('encoder_L2', full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop))
47
-
48
- # GNN (GAT)
49
- self.gn1 = GNN(params.feat_hidden2, params.gat_hidden1, params.p_drop, act=F.relu,heads = params.nheads)
50
- self.gn2 = GNN(params.gat_hidden1*params.nheads, params.gat_hidden2, params.p_drop, act=lambda x: x)
51
- self.gn3 = GNN(params.gat_hidden1*params.nheads, params.gat_hidden2, params.p_drop, act=lambda x: x)
52
-
22
+ self.encoder = nn.Sequential(
23
+ full_block(input_dim, params.feat_hidden1, params.p_drop),
24
+ full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop)
25
+ )
26
+
27
+ # GAT Layers
28
+ self.gat1 = GATConv(
29
+ in_channels=params.feat_hidden2,
30
+ out_channels=params.gat_hidden1,
31
+ heads=params.nheads,
32
+ dropout=params.p_drop
33
+ )
34
+ self.gat2 = GATConv(
35
+ in_channels=params.gat_hidden1 * params.nheads,
36
+ out_channels=params.gat_hidden2,
37
+ heads=1,
38
+ concat=False,
39
+ dropout=params.p_drop
40
+ )
41
+ if self.var:
42
+ self.gat3 = GATConv(
43
+ in_channels=params.gat_hidden1 * params.nheads,
44
+ out_channels=params.gat_hidden2,
45
+ heads=1,
46
+ concat=False,
47
+ dropout=params.p_drop
48
+ )
49
+
53
50
  # Decoder
54
- self.decoder = nn.Sequential()
55
- self.decoder.add_module('decoder_L1', full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop))
56
- self.decoder.add_module('decoder_L2', full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop))
57
- self.decoder.add_module('decoder_output', nn.Sequential(nn.Linear(params.feat_hidden1, input_dim)))
58
-
59
- # Cluster
60
- self.cluster = nn.Sequential()
61
- self.cluster.add_module('cluster_L1', full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop))
62
- self.cluster.add_module('cluster_output', nn.Linear(params.feat_hidden2, self.num_classes))
63
-
64
- def encode(self, x, adj):
65
- feat_x = self.encoder(x)
66
- hidden1 = self.gn1(feat_x, adj)
67
- mu = self.gn2(hidden1, adj)
51
+ self.decoder = nn.Sequential(
52
+ full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
53
+ full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop),
54
+ nn.Linear(params.feat_hidden1, input_dim)
55
+ )
56
+
57
+ # Clustering Layer
58
+ self.cluster = nn.Sequential(
59
+ full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
60
+ nn.Linear(params.feat_hidden2, self.num_classes)
61
+ )
62
+
63
+ def encode(self, x, edge_index):
64
+ x = self.encoder(x)
65
+ x = self.gat1(x, edge_index)
66
+ x = F.relu(x)
67
+ x = F.dropout(x, p=self.params.p_drop, training=self.training)
68
+
69
+ mu = self.gat2(x, edge_index)
68
70
  if self.var:
69
- logvar = self.gn3(hidden1, adj)
71
+ logvar = self.gat3(x, edge_index)
70
72
  return mu, logvar
71
73
  else:
72
74
  return mu, None
73
-
75
+
74
76
  def reparameterize(self, mu, logvar):
75
77
  if self.training and logvar is not None:
76
- std = torch.exp(logvar)
78
+ std = torch.exp(0.5 * logvar)
77
79
  eps = torch.randn_like(std)
78
- return eps.mul(std).add_(mu)
80
+ return eps * std + mu
79
81
  else:
80
82
  return mu
81
-
82
- def forward(self, x, adj):
83
- mu, logvar = self.encode(x, adj)
84
- gnn_z = self.reparameterize(mu, logvar)
85
- x_reconstructed = self.decoder(gnn_z)
86
- pred_label = F.softmax(self.cluster(gnn_z),dim=1)
87
- return pred_label, x_reconstructed, gnn_z, mu, logvar
83
+
84
+ def forward(self, x, edge_index):
85
+ mu, logvar = self.encode(x, edge_index)
86
+ z = self.reparameterize(mu, logvar)
87
+ x_reconstructed = self.decoder(z)
88
+ pred_label = F.softmax(self.cluster(z), dim=1)
89
+ return pred_label, x_reconstructed, z, mu, logvar
gsMap/GNN_VAE/train.py CHANGED
@@ -1,97 +1,86 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Created on Tue Jul 4 19:58:58 2023
5
-
6
- @author: songliyang
7
- """
1
+ import logging
8
2
  import time
9
3
 
10
4
  import torch
5
+ import torch.nn.functional as F
11
6
  from progress.bar import Bar
12
7
 
13
- from gsMap.GNN_VAE.model import GNN_VAE_Model
8
+ from gsMap.GNN_VAE.model import GATModel
9
+
10
+ logger = logging.getLogger(__name__)
14
11
 
15
12
 
16
13
  def reconstruction_loss(decoded, x):
17
- loss_fn = torch.nn.MSELoss()
18
- loss = loss_fn(decoded, x)
19
- return loss
14
+ """Compute the mean squared error loss."""
15
+ return F.mse_loss(decoded, x)
20
16
 
21
17
 
22
18
  def label_loss(pred_label, true_label):
23
- loss_fn = torch.nn.CrossEntropyLoss()
24
- loss = loss_fn(pred_label, true_label)
25
- return loss
26
-
27
-
28
- class Model_Train:
29
- def __init__(self, node_X, graph_dict, params, label=None):
30
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
31
- torch.cuda.empty_cache()
19
+ """Compute the cross-entropy loss."""
20
+ return F.cross_entropy(pred_label, true_label)
32
21
 
22
+ class ModelTrainer:
23
+ def __init__(self, node_x, graph_dict, params, label=None):
24
+ """Initialize the ModelTrainer with data and hyperparameters."""
25
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
26
  self.params = params
34
- self.device = device
35
27
  self.epochs = params.epochs
36
- self.node_X = torch.FloatTensor(node_X.copy()).to(device)
37
- self.adj_norm = graph_dict["adj_norm"].to(device).coalesce()
28
+ self.node_x = torch.FloatTensor(node_x).to(self.device)
29
+ self.adj_norm = graph_dict["adj_norm"].to(self.device).coalesce()
38
30
  self.label = label
39
31
  self.num_classes = 1
40
-
41
- if not self.label is None:
32
+
33
+ if self.label is not None:
42
34
  self.label = torch.tensor(self.label).to(self.device)
43
- self.num_classes = len(self.label.unique())
44
-
45
- # Set Model
46
- self.model = GNN_VAE_Model(self.params.feat_cell,self.params,self.num_classes).to(device)
47
- self.optimizer = torch.optim.Adam(params = list(self.model.parameters()),
48
- lr = self.params.gat_lr, weight_decay = self.params.gcn_decay)
49
-
50
- # Train
35
+ self.num_classes = len(torch.unique(self.label))
36
+
37
+ # Set up the model
38
+ self.model = GATModel(self.params.feat_cell, self.params, self.num_classes).to(self.device)
39
+ self.optimizer = torch.optim.Adam(
40
+ self.model.parameters(),
41
+ lr=self.params.gat_lr,
42
+ weight_decay=self.params.gcn_decay
43
+ )
44
+
51
45
  def run_train(self):
46
+ """Train the model."""
52
47
  self.model.train()
53
48
  prev_loss = float('inf')
54
-
55
- bar = Bar('GAT-AE model train:', max = self.epochs)
56
- bar.check_tty = False
49
+ bar = Bar('GAT-AE model train:', max=self.epochs)
50
+ bar.check_tty = False
51
+
52
+ logger.info('Start training...')
57
53
  for epoch in range(self.epochs):
58
54
  start_time = time.time()
59
- self.model.train()
60
55
  self.optimizer.zero_grad()
61
- pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
62
- loss_rec = reconstruction_loss(de_feat, self.node_X)
63
-
64
- # Check whether annotation was provided
65
- if not self.label is None:
56
+ pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_x, self.adj_norm)
57
+ loss_rec = reconstruction_loss(de_feat, self.node_x)
58
+
59
+ if self.label is not None:
66
60
  loss_pre = label_loss(pred_label, self.label)
67
- loss = (self.params.rec_w * loss_rec) + (self.params.label_w * loss_pre)
61
+ loss = self.params.rec_w * loss_rec + self.params.label_w * loss_pre
68
62
  else:
69
63
  loss = loss_rec
70
-
64
+
71
65
  loss.backward()
72
66
  self.optimizer.step()
73
-
74
- # Update process
75
- end_time = time.time()
76
- batch_time = end_time - start_time
77
-
78
-
79
- bar_str = '{} / {} | Left time: {batch_time:.2f} mins| Loss: {loss:.4f}'
80
- bar.suffix = bar_str.format(epoch + 1,self.epochs,
81
- batch_time = batch_time * (self.epochs - epoch) / 60, loss=loss.item())
67
+
68
+ batch_time = time.time() - start_time
69
+ left_time = batch_time * (self.epochs - epoch - 1) / 60 # in minutes
70
+
71
+ bar.suffix = f'{epoch + 1} / {self.epochs} | Left time: {left_time:.2f} mins | Loss: {loss.item():.4f}'
82
72
  bar.next()
83
-
84
- # Check convergence
73
+
85
74
  if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
86
- print('\nConvergence reached. Training stopped.')
75
+ logger.info('\nConvergence reached. Training stopped.')
87
76
  break
88
77
 
89
78
  prev_loss = loss.item()
90
-
91
79
  bar.finish()
92
- #-
80
+
93
81
  def get_latent(self):
82
+ """Retrieve the latent representation from the model."""
94
83
  self.model.eval()
95
- pred, de_fea, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
96
- latent_z = latent_z.data.cpu().numpy()
97
- return latent_z
84
+ with torch.no_grad():
85
+ _, _, latent_z, _, _ = self.model(self.node_x, self.adj_norm)
86
+ return latent_z.cpu().numpy()
gsMap/__init__.py CHANGED
@@ -2,4 +2,4 @@
2
2
  Genetics-informed pathogenic spatial mapping
3
3
  '''
4
4
 
5
- __version__ = '1.65'
5
+ __version__ = '1.67'
gsMap/config.py CHANGED
@@ -55,7 +55,8 @@ def add_find_latent_representations_args(parser):
55
55
  add_shared_args(parser)
56
56
  parser.add_argument('--input_hdf5_path', required=True, type=str, help='Path to the input HDF5 file.')
57
57
  parser.add_argument('--annotation', required=True, type=str, help='Name of the annotation in adata.obs to use.')
58
- parser.add_argument('--data_layer', required=True, type=str, help='Data layer for gene expression (e.g., "counts", "log1p").')
58
+ parser.add_argument('--data_layer', type=str, default='counts', required=True,
59
+ help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
59
60
  parser.add_argument('--epochs', type=int, default=300, help='Number of training epochs.')
60
61
  parser.add_argument('--feat_hidden1', type=int, default=256, help='Neurons in the first hidden layer.')
61
62
  parser.add_argument('--feat_hidden2', type=int, default=128, help='Neurons in the second hidden layer.')
@@ -66,7 +67,6 @@ def add_find_latent_representations_args(parser):
66
67
  parser.add_argument('--n_neighbors', type=int, default=11, help='Number of neighbors for GAT.')
67
68
  parser.add_argument('--n_comps', type=int, default=300, help='Number of principal components for PCA.')
68
69
  parser.add_argument('--weighted_adj', action='store_true', help='Use weighted adjacency in GAT.')
69
- parser.add_argument('--var', action='store_true', help='Enable variance calculations.')
70
70
  parser.add_argument('--convergence_threshold', type=float, default=1e-4, help='Threshold for convergence.')
71
71
  parser.add_argument('--hierarchically', action='store_true', help='Enable hierarchical latent representation finding.')
72
72
 
@@ -236,8 +236,8 @@ def add_run_all_mode_args(parser):
236
236
  help='Path to the input spatial transcriptomics data (H5AD format).')
237
237
  parser.add_argument('--annotation', type=str, required=True,
238
238
  help='Name of the annotation in adata.obs to use.')
239
- parser.add_argument('--data_layer', type=str, default='X',
240
- help='Data layer of h5ad for gene expression (e.g., "counts", "log1p", "X").')
239
+ parser.add_argument('--data_layer', type=str, default='counts', required=True,
240
+ help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
241
241
 
242
242
  # GWAS Data Parameters
243
243
  parser.add_argument('--trait_name', type=str, help='Name of the trait for GWAS analysis (required if sumstats_file is provided).')