gsMap 1.66.tar.gz → 1.67.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. {gsmap-1.66 → gsmap-1.67}/PKG-INFO +2 -2
  2. {gsmap-1.66 → gsmap-1.67}/docs/source/mouse_example.md +1 -1
  3. {gsmap-1.66 → gsmap-1.67}/pyproject.toml +1 -1
  4. gsmap-1.67/src/gsMap/GNN_VAE/adjacency_matrix.py +75 -0
  5. gsmap-1.67/src/gsMap/GNN_VAE/model.py +89 -0
  6. gsmap-1.67/src/gsMap/GNN_VAE/train.py +86 -0
  7. {gsmap-1.66 → gsmap-1.67}/src/gsMap/__init__.py +1 -1
  8. {gsmap-1.66 → gsmap-1.67}/src/gsMap/config.py +4 -4
  9. gsmap-1.67/src/gsMap/find_latent_representation.py +145 -0
  10. {gsmap-1.66 → gsmap-1.67}/src/gsMap/format_sumstats.py +20 -20
  11. {gsmap-1.66 → gsmap-1.67}/src/gsMap/latent_to_gene.py +0 -7
  12. {gsmap-1.66 → gsmap-1.67}/src/gsMap/spatial_ldsc_multiple_sumstats.py +0 -2
  13. gsmap-1.66/src/gsMap/GNN_VAE/adjacency_matrix.py +0 -95
  14. gsmap-1.66/src/gsMap/GNN_VAE/model.py +0 -87
  15. gsmap-1.66/src/gsMap/GNN_VAE/train.py +0 -97
  16. gsmap-1.66/src/gsMap/find_latent_representation.py +0 -145
  17. {gsmap-1.66 → gsmap-1.67}/.github/workflows/publish-to-pypi.yml +0 -0
  18. {gsmap-1.66 → gsmap-1.67}/.gitignore +0 -0
  19. {gsmap-1.66 → gsmap-1.67}/LICENSE +0 -0
  20. {gsmap-1.66 → gsmap-1.67}/README.md +0 -0
  21. {gsmap-1.66 → gsmap-1.67}/docs/Makefile +0 -0
  22. {gsmap-1.66 → gsmap-1.67}/docs/make.bat +0 -0
  23. {gsmap-1.66 → gsmap-1.67}/docs/requirements.txt +0 -0
  24. {gsmap-1.66 → gsmap-1.67}/docs/source/_static/schematic.svg +0 -0
  25. {gsmap-1.66 → gsmap-1.67}/docs/source/api/cauchy_combination.rst +0 -0
  26. {gsmap-1.66 → gsmap-1.67}/docs/source/api/find_latent_representations.rst +0 -0
  27. {gsmap-1.66 → gsmap-1.67}/docs/source/api/format_sumstats.rst +0 -0
  28. {gsmap-1.66 → gsmap-1.67}/docs/source/api/generate_ldscore.rst +0 -0
  29. {gsmap-1.66 → gsmap-1.67}/docs/source/api/latent_to_gene.rst +0 -0
  30. {gsmap-1.66 → gsmap-1.67}/docs/source/api/quick_mode.rst +0 -0
  31. {gsmap-1.66 → gsmap-1.67}/docs/source/api/report.rst +0 -0
  32. {gsmap-1.66 → gsmap-1.67}/docs/source/api/spatial_ldsc.rst +0 -0
  33. {gsmap-1.66 → gsmap-1.67}/docs/source/api.rst +0 -0
  34. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_Height.json +0 -0
  35. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_IQ.json +0 -0
  36. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_MCHC.json +0 -0
  37. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_SCZ.json +0 -0
  38. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_Height.json +0 -0
  39. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_IQ.json +0 -0
  40. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_MCHC.json +0 -0
  41. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_SCZ.json +0 -0
  42. {gsmap-1.66 → gsmap-1.67}/docs/source/charts/test.json +0 -0
  43. {gsmap-1.66 → gsmap-1.67}/docs/source/conf.py +0 -0
  44. {gsmap-1.66 → gsmap-1.67}/docs/source/data.rst +0 -0
  45. {gsmap-1.66 → gsmap-1.67}/docs/source/data_format.md +0 -0
  46. {gsmap-1.66 → gsmap-1.67}/docs/source/index.rst +0 -0
  47. {gsmap-1.66 → gsmap-1.67}/docs/source/install.rst +0 -0
  48. {gsmap-1.66 → gsmap-1.67}/docs/source/mouse.rst +0 -0
  49. {gsmap-1.66 → gsmap-1.67}/docs/source/quick_mode.md +0 -0
  50. {gsmap-1.66 → gsmap-1.67}/docs/source/release.rst +0 -0
  51. {gsmap-1.66 → gsmap-1.67}/docs/source/tutorials.rst +0 -0
  52. {gsmap-1.66 → gsmap-1.67}/schematic.png +0 -0
  53. {gsmap-1.66 → gsmap-1.67}/src/gsMap/GNN_VAE/__init__.py +0 -0
  54. {gsmap-1.66 → gsmap-1.67}/src/gsMap/__main__.py +0 -0
  55. {gsmap-1.66 → gsmap-1.67}/src/gsMap/cauchy_combination_test.py +0 -0
  56. {gsmap-1.66 → gsmap-1.67}/src/gsMap/diagnosis.py +0 -0
  57. {gsmap-1.66 → gsmap-1.67}/src/gsMap/generate_ldscore.py +0 -0
  58. {gsmap-1.66 → gsmap-1.67}/src/gsMap/main.py +0 -0
  59. {gsmap-1.66 → gsmap-1.67}/src/gsMap/report.py +0 -0
  60. {gsmap-1.66 → gsmap-1.67}/src/gsMap/run_all_mode.py +0 -0
  61. {gsmap-1.66 → gsmap-1.67}/src/gsMap/setup.py +0 -0
  62. {gsmap-1.66 → gsmap-1.67}/src/gsMap/templates/report_template.html +0 -0
  63. {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/__init__.py +0 -0
  64. {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/generate_r2_matrix.py +0 -0
  65. {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/jackknife.py +0 -0
  66. {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/make_annotations.py +0 -0
  67. {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/manhattan_plot.py +0 -0
  68. {gsmap-1.66 → gsmap-1.67}/src/gsMap/utils/regression_read.py +0 -0
  69. {gsmap-1.66 → gsmap-1.67}/src/gsMap/visualize.py +0 -0

{gsmap-1.66 → gsmap-1.67}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: gsMap
-Version: 1.66
+Version: 1.67
 Summary: Genetics-informed pathogenic spatial mapping
 Author-email: liyang <songliyang@westlake.edu.cn>, wenhao <chenwenhao@westlake.edu.cn>
 Requires-Python: >=3.8
@@ -27,7 +27,7 @@ Requires-Dist: pyfiglet
 Requires-Dist: plotly
 Requires-Dist: kaleido
 Requires-Dist: jinja2
-Requires-Dist: scanpy
+Requires-Dist: scanpy >=1.8.0
 Requires-Dist: zarr
 Requires-Dist: bitarray
 Requires-Dist: pyarrow

{gsmap-1.66 → gsmap-1.67}/docs/source/mouse_example.md
@@ -209,7 +209,7 @@ gsmap run_cauchy_combination \
     --annotation 'annotation'
 ```
 
-### 6. report generation
+### 6. report generation (optional)
 
 **Objective**: Generate gsMap reports, including visualizations of mapping results and diagnostic plots.
 

{gsmap-1.66 → gsmap-1.67}/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
     'plotly',
     'kaleido',
     'jinja2',
-    'scanpy',
+    'scanpy >=1.8.0',
     'zarr',
     'bitarray',
     'pyarrow',

gsmap-1.67/src/gsMap/GNN_VAE/adjacency_matrix.py (new file)
@@ -0,0 +1,75 @@
+import numpy as np
+import pandas as pd
+import scipy.sparse as sp
+from sklearn.neighbors import NearestNeighbors
+import torch
+
+def cal_spatial_net(adata, n_neighbors=5, verbose=True):
+    """Construct the spatial neighbor network."""
+    if verbose:
+        print('------Calculating spatial graph...')
+    coor = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
+    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
+    distances, indices = nbrs.kneighbors(coor)
+    n_cells, n_neighbors = indices.shape
+    cell_indices = np.arange(n_cells)
+    cell1 = np.repeat(cell_indices, n_neighbors)
+    cell2 = indices.flatten()
+    distance = distances.flatten()
+    knn_df = pd.DataFrame({'Cell1': cell1, 'Cell2': cell2, 'Distance': distance})
+    knn_df = knn_df[knn_df['Distance'] > 0].copy()
+    cell_id_map = dict(zip(cell_indices, coor.index))
+    knn_df['Cell1'] = knn_df['Cell1'].map(cell_id_map)
+    knn_df['Cell2'] = knn_df['Cell2'].map(cell_id_map)
+    return knn_df
+
+def sparse_mx_to_torch_sparse_tensor(sparse_mx):
+    """Convert a scipy sparse matrix to a torch sparse tensor."""
+    sparse_mx = sparse_mx.tocoo().astype(np.float32)
+    indices = torch.from_numpy(
+        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
+    )
+    values = torch.from_numpy(sparse_mx.data)
+    shape = torch.Size(sparse_mx.shape)
+    return torch.sparse.FloatTensor(indices, values, shape)
+
+def preprocess_graph(adj):
+    """Symmetrically normalize the adjacency matrix."""
+    adj = sp.coo_matrix(adj)
+    adj_ = adj + sp.eye(adj.shape[0])
+    rowsum = np.array(adj_.sum(1)).flatten()
+    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5))
+    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
+    return sparse_mx_to_torch_sparse_tensor(adj_normalized)
+
+def construct_adjacency_matrix(adata, params, verbose=True):
+    """Construct the adjacency matrix from spatial data."""
+    spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
+    if verbose:
+        num_edges = spatial_net.shape[0]
+        num_cells = adata.n_obs
+        print(f'The graph contains {num_edges} edges, {num_cells} cells.')
+        print(f'{num_edges / num_cells:.2f} neighbors per cell on average.')
+    cell_ids = {cell: idx for idx, cell in enumerate(adata.obs.index)}
+    spatial_net['Cell1'] = spatial_net['Cell1'].map(cell_ids)
+    spatial_net['Cell2'] = spatial_net['Cell2'].map(cell_ids)
+    if params.weighted_adj:
+        distance_normalized = spatial_net['Distance'] / (spatial_net['Distance'].max() + 1)
+        weights = np.exp(-0.5 * distance_normalized ** 2)
+        adj_org = sp.coo_matrix(
+            (weights, (spatial_net['Cell1'], spatial_net['Cell2'])),
+            shape=(adata.n_obs, adata.n_obs)
+        )
+    else:
+        adj_org = sp.coo_matrix(
+            (np.ones(spatial_net.shape[0]), (spatial_net['Cell1'], spatial_net['Cell2'])),
+            shape=(adata.n_obs, adata.n_obs)
+        )
+    adj_norm = preprocess_graph(adj_org)
+    norm_value = adj_org.shape[0] ** 2 / ((adj_org.shape[0] ** 2 - adj_org.sum()) * 2)
+    graph_dict = {
+        "adj_org": adj_org,
+        "adj_norm": adj_norm,
+        "norm_value": norm_value
+    }
+    return graph_dict
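
A minimal usage sketch of the new module (hypothetical toy data; the SimpleNamespace stands in for the gsMap config object, of which only the two attributes the function actually reads are filled in):

    import numpy as np
    import anndata as ad
    from types import SimpleNamespace

    from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix

    # Toy AnnData: 100 cells with random 2-D spatial coordinates.
    adata = ad.AnnData(X=np.random.rand(100, 20).astype(np.float32))
    adata.obsm['spatial'] = np.random.rand(100, 2)

    # Stand-in for the gsMap config; only n_neighbors and weighted_adj are read here.
    params = SimpleNamespace(n_neighbors=5, weighted_adj=False)

    graph_dict = construct_adjacency_matrix(adata, params)
    print(graph_dict['adj_norm'].shape, graph_dict['norm_value'])

graph_dict['adj_norm'] is the symmetrically normalized adjacency as a torch sparse tensor, ready to be handed to the trainer further below.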

gsmap-1.67/src/gsMap/GNN_VAE/model.py (new file)
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.nn import GATConv
+
+def full_block(in_features, out_features, p_drop):
+    return nn.Sequential(
+        nn.Linear(in_features, out_features),
+        nn.BatchNorm1d(out_features),
+        nn.ELU(),
+        nn.Dropout(p=p_drop)
+    )
+
+class GATModel(nn.Module):
+    def __init__(self, input_dim, params, num_classes=1):
+        super().__init__()
+        self.var = params.var
+        self.num_classes = num_classes
+        self.params = params
+
+        # Encoder
+        self.encoder = nn.Sequential(
+            full_block(input_dim, params.feat_hidden1, params.p_drop),
+            full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop)
+        )
+
+        # GAT Layers
+        self.gat1 = GATConv(
+            in_channels=params.feat_hidden2,
+            out_channels=params.gat_hidden1,
+            heads=params.nheads,
+            dropout=params.p_drop
+        )
+        self.gat2 = GATConv(
+            in_channels=params.gat_hidden1 * params.nheads,
+            out_channels=params.gat_hidden2,
+            heads=1,
+            concat=False,
+            dropout=params.p_drop
+        )
+        if self.var:
+            self.gat3 = GATConv(
+                in_channels=params.gat_hidden1 * params.nheads,
+                out_channels=params.gat_hidden2,
+                heads=1,
+                concat=False,
+                dropout=params.p_drop
+            )
+
+        # Decoder
+        self.decoder = nn.Sequential(
+            full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
+            full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop),
+            nn.Linear(params.feat_hidden1, input_dim)
+        )
+
+        # Clustering Layer
+        self.cluster = nn.Sequential(
+            full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
+            nn.Linear(params.feat_hidden2, self.num_classes)
+        )
+
+    def encode(self, x, edge_index):
+        x = self.encoder(x)
+        x = self.gat1(x, edge_index)
+        x = F.relu(x)
+        x = F.dropout(x, p=self.params.p_drop, training=self.training)
+
+        mu = self.gat2(x, edge_index)
+        if self.var:
+            logvar = self.gat3(x, edge_index)
+            return mu, logvar
+        else:
+            return mu, None
+
+    def reparameterize(self, mu, logvar):
+        if self.training and logvar is not None:
+            std = torch.exp(0.5 * logvar)
+            eps = torch.randn_like(std)
+            return eps * std + mu
+        else:
+            return mu
+
+    def forward(self, x, edge_index):
+        mu, logvar = self.encode(x, edge_index)
+        z = self.reparameterize(mu, logvar)
+        x_reconstructed = self.decoder(z)
+        pred_label = F.softmax(self.cluster(z), dim=1)
+        return pred_label, x_reconstructed, z, mu, logvar
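
A hedged forward-pass sketch of GATModel (the hyperparameter names mirror the attributes the class reads; the values are illustrative, not gsMap defaults):

    import torch
    from types import SimpleNamespace

    from gsMap.GNN_VAE.model import GATModel

    params = SimpleNamespace(var=False, feat_hidden1=256, feat_hidden2=128,
                             gat_hidden1=64, gat_hidden2=32, nheads=4, p_drop=0.1)

    model = GATModel(input_dim=300, params=params, num_classes=5)
    x = torch.randn(100, 300)                     # 100 cells, 300 features
    edge_index = torch.randint(0, 100, (2, 400))  # random toy graph in PyG edge_index form
    pred_label, x_rec, z, mu, logvar = model(x, edge_index)
    print(z.shape)  # torch.Size([100, 32]), i.e. (n_cells, gat_hidden2)

With var=False the model degenerates to a plain graph autoencoder: encode returns (mu, None) and reparameterize passes mu through unchanged.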

gsmap-1.67/src/gsMap/GNN_VAE/train.py (new file)
@@ -0,0 +1,86 @@
+import logging
+import time
+
+import torch
+import torch.nn.functional as F
+from progress.bar import Bar
+
+from gsMap.GNN_VAE.model import GATModel
+
+logger = logging.getLogger(__name__)
+
+
+def reconstruction_loss(decoded, x):
+    """Compute the mean squared error loss."""
+    return F.mse_loss(decoded, x)
+
+
+def label_loss(pred_label, true_label):
+    """Compute the cross-entropy loss."""
+    return F.cross_entropy(pred_label, true_label)
+
+class ModelTrainer:
+    def __init__(self, node_x, graph_dict, params, label=None):
+        """Initialize the ModelTrainer with data and hyperparameters."""
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.params = params
+        self.epochs = params.epochs
+        self.node_x = torch.FloatTensor(node_x).to(self.device)
+        self.adj_norm = graph_dict["adj_norm"].to(self.device).coalesce()
+        self.label = label
+        self.num_classes = 1
+
+        if self.label is not None:
+            self.label = torch.tensor(self.label).to(self.device)
+            self.num_classes = len(torch.unique(self.label))
+
+        # Set up the model
+        self.model = GATModel(self.params.feat_cell, self.params, self.num_classes).to(self.device)
+        self.optimizer = torch.optim.Adam(
+            self.model.parameters(),
+            lr=self.params.gat_lr,
+            weight_decay=self.params.gcn_decay
+        )
+
+    def run_train(self):
+        """Train the model."""
+        self.model.train()
+        prev_loss = float('inf')
+        bar = Bar('GAT-AE model train:', max=self.epochs)
+        bar.check_tty = False
+
+        logger.info('Start training...')
+        for epoch in range(self.epochs):
+            start_time = time.time()
+            self.optimizer.zero_grad()
+            pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_x, self.adj_norm)
+            loss_rec = reconstruction_loss(de_feat, self.node_x)
+
+            if self.label is not None:
+                loss_pre = label_loss(pred_label, self.label)
+                loss = self.params.rec_w * loss_rec + self.params.label_w * loss_pre
+            else:
+                loss = loss_rec
+
+            loss.backward()
+            self.optimizer.step()
+
+            batch_time = time.time() - start_time
+            left_time = batch_time * (self.epochs - epoch - 1) / 60  # in minutes
+
+            bar.suffix = f'{epoch + 1} / {self.epochs} | Left time: {left_time:.2f} mins | Loss: {loss.item():.4f}'
+            bar.next()
+
+            if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
+                logger.info('\nConvergence reached. Training stopped.')
+                break
+
+            prev_loss = loss.item()
+        bar.finish()
+
+    def get_latent(self):
+        """Retrieve the latent representation from the model."""
+        self.model.eval()
+        with torch.no_grad():
+            _, _, latent_z, _, _ = self.model(self.node_x, self.adj_norm)
+        return latent_z.cpu().numpy()
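
Tying the two previous files together, a sketch of the training entry point, mirroring how run_gnn_vae invokes it (assumes graph_dict from the adjacency sketch above; parameter values are again illustrative, and epochs is kept small for the toy run):

    import numpy as np
    from types import SimpleNamespace

    from gsMap.GNN_VAE.train import ModelTrainer

    # Hypothetical hyperparameters covering everything ModelTrainer and GATModel read.
    params = SimpleNamespace(
        epochs=10, feat_cell=300, gat_lr=1e-3, gcn_decay=1e-2,
        rec_w=1.0, label_w=1.0, convergence_threshold=1e-4,
        var=False, feat_hidden1=256, feat_hidden2=128,
        gat_hidden1=64, gat_hidden2=32, nheads=4, p_drop=0.1,
    )

    node_x = np.random.rand(100, params.feat_cell).astype(np.float32)
    trainer = ModelTrainer(node_x, graph_dict, params, label=None)  # graph_dict from above
    trainer.run_train()
    latent = trainer.get_latent()  # numpy array of shape (100, gat_hidden2)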

{gsmap-1.66 → gsmap-1.67}/src/gsMap/__init__.py
@@ -2,4 +2,4 @@
 Genetics-informed pathogenic spatial mapping
 '''
 
-__version__ = '1.66'
+__version__ = '1.67'

{gsmap-1.66 → gsmap-1.67}/src/gsMap/config.py
@@ -55,7 +55,8 @@ def add_find_latent_representations_args(parser):
     add_shared_args(parser)
     parser.add_argument('--input_hdf5_path', required=True, type=str, help='Path to the input HDF5 file.')
     parser.add_argument('--annotation', required=True, type=str, help='Name of the annotation in adata.obs to use.')
-    parser.add_argument('--data_layer', required=True, type=str, help='Data layer for gene expression (e.g., "counts", "log1p").')
+    parser.add_argument('--data_layer', type=str, default='counts', required=True,
+                        help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
     parser.add_argument('--epochs', type=int, default=300, help='Number of training epochs.')
     parser.add_argument('--feat_hidden1', type=int, default=256, help='Neurons in the first hidden layer.')
     parser.add_argument('--feat_hidden2', type=int, default=128, help='Neurons in the second hidden layer.')
@@ -66,7 +67,6 @@ def add_find_latent_representations_args(parser):
     parser.add_argument('--n_neighbors', type=int, default=11, help='Number of neighbors for GAT.')
     parser.add_argument('--n_comps', type=int, default=300, help='Number of principal components for PCA.')
     parser.add_argument('--weighted_adj', action='store_true', help='Use weighted adjacency in GAT.')
-    parser.add_argument('--var', action='store_true', help='Enable variance calculations.')
     parser.add_argument('--convergence_threshold', type=float, default=1e-4, help='Threshold for convergence.')
     parser.add_argument('--hierarchically', action='store_true', help='Enable hierarchical latent representation finding.')
 
@@ -236,8 +236,8 @@ def add_run_all_mode_args(parser):
                         help='Path to the input spatial transcriptomics data (H5AD format).')
     parser.add_argument('--annotation', type=str, required=True,
                         help='Name of the annotation in adata.obs to use.')
-    parser.add_argument('--data_layer', type=str, default='X',
-                        help='Data layer of h5ad for gene expression (e.g., "counts", "log1p", "X").')
+    parser.add_argument('--data_layer', type=str, default='counts', required=True,
+                        help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
 
     # GWAS Data Parameters
     parser.add_argument('--trait_name', type=str, help='Name of the trait for GWAS analysis (required if sumstats_file is provided).')
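
One behavioral detail of the new --data_layer definition: argparse never applies a default to an argument declared required=True, so the 'counts' default is effectively inert and the flag must always be passed. A stand-alone illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_layer', type=str, default='counts', required=True)

    # parser.parse_args([]) exits with:
    #   error: the following arguments are required: --data_layer
    args = parser.parse_args(['--data_layer', 'counts'])
    print(args.data_layer)  # 'counts'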

gsmap-1.67/src/gsMap/find_latent_representation.py (new file)
@@ -0,0 +1,145 @@
+import logging
+import random
+import numpy as np
+import scanpy as sc
+import torch
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import LabelEncoder
+from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix
+from gsMap.GNN_VAE.train import ModelTrainer
+from gsMap.config import FindLatentRepresentationsConfig
+
+logger = logging.getLogger(__name__)
+
+
+def set_seed(seed_value):
+    """
+    Set seed for reproducibility in PyTorch and other libraries.
+    """
+    torch.manual_seed(seed_value)
+    np.random.seed(seed_value)
+    random.seed(seed_value)
+    if torch.cuda.is_available():
+        logger.info('Using GPU for computations.')
+        torch.cuda.manual_seed(seed_value)
+        torch.cuda.manual_seed_all(seed_value)
+    else:
+        logger.info('Using CPU for computations.')
+
+def preprocess_data(adata, params):
+    """
+    Preprocess the AnnData
+    """
+    logger.info('Preprocessing data...')
+    adata.var_names_make_unique()
+
+    sc.pp.filter_genes(adata, min_cells=30)
+    if params.data_layer in adata.layers.keys():
+        adata.X = adata.layers[params.data_layer]
+    else:
+        raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
+
+    if params.data_layer in ['count', 'counts']:
+
+        sc.pp.normalize_total(adata, target_sum=1e4)
+        sc.pp.log1p(adata)
+
+        # Identify highly variable genes
+        sc.pp.highly_variable_genes(
+            adata,
+            flavor="seurat_v3",
+            n_top_genes=params.feat_cell,
+        )
+
+    elif params.data_layer in adata.layers.keys():
+        logger.info(f'Using {params.data_layer} data...')
+        sc.pp.highly_variable_genes(
+            adata,
+            flavor="seurat",
+            n_top_genes=params.feat_cell,
+        )
+
+    return adata
+
+
+class LatentRepresentationFinder:
+    def __init__(self, adata, args: FindLatentRepresentationsConfig):
+        self.params = args
+
+        self.expression_array = adata[:, adata.var.highly_variable].X.copy()
+
+        if self.params.data_layer in ['count', 'counts']:
+            self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
+
+        # Construct the neighboring graph
+        self.graph_dict = construct_adjacency_matrix(adata, self.params)
+
+    def compute_pca(self):
+        self.latent_pca = PCA(n_components=self.params.feat_cell).fit_transform(self.expression_array)
+        return self.latent_pca
+
+    def run_gnn_vae(self, label, verbose='whole ST data'):
+
+        # Use PCA if specified
+        if self.params.input_pca:
+            node_X = self.compute_pca()
+        else:
+            node_X = self.expression_array
+
+        # Update the input shape
+        self.params.n_nodes = node_X.shape[0]
+        self.params.feat_cell = node_X.shape[1]
+
+        # Run GNN
+        logger.info(f'Finding latent representations for {verbose}...')
+        gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
+        gvae.run_train()
+
+        del self.graph_dict
+
+        return gvae.get_latent()
+
+
+def run_find_latent_representation(args: FindLatentRepresentationsConfig):
+    set_seed(2024)
+
+    # Load the ST data
+    logger.info(f'Loading ST data of {args.sample_name}...')
+    adata = sc.read_h5ad(args.input_hdf5_path)
+    logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')
+
+    # Load the cell type annotation
+    if args.annotation is not None:
+        # Remove cells without enough annotations
+        adata = adata[~adata.obs[args.annotation].isnull()]
+        num = adata.obs[args.annotation].value_counts()
+        valid_annotations = num[num >= 30].index.to_list()
+        adata = adata[adata.obs[args.annotation].isin(valid_annotations)]
+
+        le = LabelEncoder()
+        adata.obs['categorical_label'] = le.fit_transform(adata.obs[args.annotation])
+        label = adata.obs['categorical_label'].to_numpy()
+    else:
+        label = None
+
+    # Preprocess data
+    adata = preprocess_data(adata, args)
+
+    latent_rep = LatentRepresentationFinder(adata, args)
+    latent_gvae = latent_rep.run_gnn_vae(label)
+    latent_pca = latent_rep.compute_pca()
+
+    # Add latent representations to the AnnData object
+    logger.info('Adding latent representations...')
+    adata.obsm["latent_GVAE"] = latent_gvae
+    adata.obsm["latent_PCA"] = latent_pca
+
+    # Run UMAP based on latent representations
+    for name in ['latent_GVAE', 'latent_PCA']:
+        sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
+        sc.tl.umap(adata)
+        adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
+
+    # Save the AnnData object
+    logger.info('Saving ST data...')
+    adata.write(args.hdf5_with_latent_path)
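
A small sketch exercising the new preprocess_data path in isolation (toy AnnData with a synthetic 'counts' layer; the seurat_v3 flavor additionally requires the scikit-misc package to be installed):

    import numpy as np
    import anndata as ad
    from types import SimpleNamespace

    from gsMap.find_latent_representation import preprocess_data

    counts = np.random.poisson(1.0, size=(200, 500)).astype(np.float32)
    adata = ad.AnnData(X=counts.copy())
    adata.layers['counts'] = counts

    # Stand-in config; only data_layer and feat_cell are read by preprocess_data.
    params = SimpleNamespace(data_layer='counts', feat_cell=100)
    adata = preprocess_data(adata, params)
    print(int(adata.var.highly_variable.sum()))  # 100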

{gsmap-1.66 → gsmap-1.67}/src/gsMap/format_sumstats.py
@@ -150,10 +150,10 @@ def gwas_checkname(gwas, config):
         'Pos': 'SNP positions.'
     }
 
-    print(f'\nIterpreting column names as follows:')
+    logger.info(f'\nIterpreting column names as follows:')
     for key, value in interpreting.items():
         if key in new_name:
-            print(f'{name_dict[key]}: {interpreting[key]}')
+            logger.info(f'{name_dict[key]}: {interpreting[key]}')
 
     return gwas
 
@@ -242,7 +242,7 @@ def gwas_qc(gwas, config):
     Filter out SNPs based on INFO, FRQ, MAF, N, and Genotypes.
     '''
     old = len(gwas)
-    print(f'\nFiltering SNPs as follows:')
+    logger.info(f'\nFiltering SNPs as follows:')
     # filter: SNPs with missing values
     drops = {'NA': 0, 'P': 0, 'INFO': 0, 'FRQ': 0, 'A': 0, 'SNP': 0, 'Dup': 0, 'N': 0}
 
@@ -250,28 +250,28 @@ def gwas_qc(gwas, config):
         lambda x: x != 'INFO', gwas.columns)).reset_index(drop=True)
 
     drops['NA'] = old - len(gwas)
-    print(f'Removed {drops["NA"]} SNPs with missing values.')
+    logger.info(f'Removed {drops["NA"]} SNPs with missing values.')
 
     # filter: SNPs with Info < 0.9
     if 'INFO' in gwas.columns:
         old = len(gwas)
         gwas = gwas.loc[filter_info(gwas['INFO'], config)]
         drops['INFO'] = old - len(gwas)
-        print(f'Removed {drops["INFO"]} SNPs with INFO <= 0.9.')
+        logger.info(f'Removed {drops["INFO"]} SNPs with INFO <= 0.9.')
 
     # filter: SNPs with MAF <= 0.01
     if 'FRQ' in gwas.columns:
         old = len(gwas)
         gwas = gwas.loc[filter_frq(gwas['FRQ'], config)]
         drops['FRQ'] += old - len(gwas)
-        print(f'Removed {drops["FRQ"]} SNPs with MAF <= 0.01.')
+        logger.info(f'Removed {drops["FRQ"]} SNPs with MAF <= 0.01.')
 
     # filter: P-value that out-of-bounds [0,1]
     if 'P' in gwas.columns:
         old = len(gwas)
         gwas = gwas.loc[filter_pvals(gwas['P'], config)]
         drops['P'] += old - len(gwas)
-        print(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')
+        logger.info(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')
 
     # filter: Variants that are strand-ambiguous
     if 'A1' in gwas.columns and 'A2' in gwas.columns:
@@ -279,21 +279,21 @@ def gwas_qc(gwas, config):
         gwas.A2 = gwas.A2.str.upper()
         gwas = gwas.loc[filter_alleles(gwas.A1 + gwas.A2)]
         drops['A'] += old - len(gwas)
-        print(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')
+        logger.info(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')
 
     # filter: Duplicated rs numbers
     if 'SNP' in gwas.columns:
         old = len(gwas)
         gwas = gwas.drop_duplicates(subset='SNP').reset_index(drop=True)
         drops['Dup'] += old - len(gwas)
-        print(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')
+        logger.info(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')
 
     # filter:Sample size
     n_min = gwas.N.quantile(0.9) / 1.5
     old = len(gwas)
     gwas = gwas[gwas.N >= n_min].reset_index(drop=True)
     drops['N'] += old - len(gwas)
-    print(f'Removed {drops["N"]} SNPs with N < {n_min}.')
+    logger.info(f'Removed {drops["N"]} SNPs with N < {n_min}.')
 
     return gwas
 
@@ -302,7 +302,7 @@ def variant_to_rsid(gwas, config):
     '''
     Convert variant id (Chr, Pos) to rsid
     '''
-    print("\nConverting the SNP position to rsid. This process may take some time.")
+    logger.info("\nConverting the SNP position to rsid. This process may take some time.")
     unique_ids = set(gwas['id'])
     chr_format = gwas['Chr'].unique().astype(str)
     chr_format = [re.sub(r'\d+', '', value) for value in chr_format][1]
@@ -347,7 +347,7 @@ def clean_SNP_id(gwas, config):
     gwas = gwas.loc[matching_id.id]
     gwas['SNP'] = matching_id.dbsnp
     num_fail = old - len(gwas)
-    print(f'Removed {num_fail} SNPs that did not convert to rsid.')
+    logger.info(f'Removed {num_fail} SNPs that did not convert to rsid.')
 
     return gwas
 
@@ -356,27 +356,27 @@ def gwas_metadata(gwas, config):
     '''
     Report key features of GWAS data
     '''
-    print('\nMetadata:')
+    logger.info('\nSummary of GWAS data:')
     CHISQ = (gwas.Z ** 2)
     mean_chisq = CHISQ.mean()
-    print('Mean chi^2 = ' + str(round(mean_chisq, 3)))
+    logger.info('Mean chi^2 = ' + str(round(mean_chisq, 3)))
     if mean_chisq < 1.02:
         logger.warning("Mean chi^2 may be too small.")
 
-    print('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
-    print('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
-    print('{N} Genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))
+    logger.info('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
+    logger.info('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
+    logger.info('{N} Genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))
 
 
 def gwas_format(config: FormatSumstatsConfig):
     '''
     Format GWAS data
     '''
-    print(f'------Formating gwas data for {config.sumstats}...')
+    logger.info(f'------Formating gwas data for {config.sumstats}...')
     compression_type = get_compression(config.sumstats)
     gwas = pd.read_csv(config.sumstats, delim_whitespace=True, header=0, compression=compression_type,
                        na_values=['.', 'NA'])
-    print(f'Read {len(gwas)} SNPs from {config.sumstats}.')
+    logger.info(f'Read {len(gwas)} SNPs from {config.sumstats}.')
 
     # Check name and format
     gwas = gwas_checkname(gwas, config)
@@ -402,6 +402,6 @@ def gwas_format(config: FormatSumstatsConfig):
     gwas = gwas[keep]
     out_name = config.out + appendix + '.gz'
 
-    print(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
+    logger.info(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
    gwas.to_csv(out_name, sep="\t", index=False,
                float_format='%.3f', compression='gzip')
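
Because these messages now go through the module logger instead of print, they only appear once logging is configured in the calling process; a minimal sketch (the logger name is assumed to follow the module path, e.g. gsMap.format_sumstats):

    import logging

    # Emit INFO-level records from gsMap modules to the console.
    logging.basicConfig(level=logging.INFO, format='%(name)s: %(message)s')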

{gsmap-1.66 → gsmap-1.67}/src/gsMap/latent_to_gene.py
@@ -4,12 +4,10 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import scanpy as sc
-from scipy.sparse import csr_matrix
 from scipy.stats import gmean
 from scipy.stats import rankdata
 from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.neighbors import NearestNeighbors
-from joblib import Parallel, delayed
 from tqdm import tqdm
 
 from gsMap.config import LatentToGeneConfig
@@ -152,11 +150,6 @@ def run_latent_to_gene(config: LatentToGeneConfig):
     adata.var_names = homologs.loc[adata.var_names, 'HUMAN_GENE_SYM'].values
     adata = adata[:, ~adata.var_names.duplicated()]
 
-    # Remove cells and genes that are not expressed
-    logger.info(f'Number of cells, genes of the input data: {adata.shape[0]},{adata.shape[1]}')
-    adata = adata[adata.X.sum(axis=1) > 0, adata.X.sum(axis=0) > 0]
-    logger.info(f'Number of cells, genes after transformation: {adata.shape[0]},{adata.shape[1]}')
-
     # Create mappings
     n_cells = adata.n_obs
     n_genes = adata.n_vars

{gsmap-1.66 → gsmap-1.67}/src/gsMap/spatial_ldsc_multiple_sumstats.py
@@ -20,8 +20,6 @@ logger = logging.getLogger('gsMap.spatial_ldsc')
 
 # %%
 def _coef_new(jknife):
-    # return coef[0], coef_se[0], z[0]]
-    # est_ = jknife.est[0, 0] / Nbar
     est_ = jknife.jknife_est[0, 0] / Nbar
     se_ = jknife.jknife_se[0, 0] / Nbar
     return est_, se_

gsmap-1.66/src/gsMap/GNN_VAE/adjacency_matrix.py (removed)
@@ -1,95 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Tue Jul 4 21:31:27 2023
-
-@author: songliyang
-"""
-import numpy as np
-import pandas as pd
-import scipy.sparse as sp
-import sklearn.neighbors
-import torch
-
-
-def Cal_Spatial_Net(adata, n_neighbors=5, verbose=True):
-    """\
-    Construct the spatial neighbor networks.
-    """
-    #-
-    if verbose:
-        print('------Calculating spatial graph...')
-    coor = pd.DataFrame(adata.obsm['spatial'])
-    coor.index = adata.obs.index
-    #-
-    nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
-    #-
-    distances, indices = nbrs.kneighbors(coor, return_distance=True)
-    KNN_list = []
-    for it in range(indices.shape[0]):
-        KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it])))
-    #-
-    KNN_df = pd.concat(KNN_list)
-    KNN_df.columns = ['Cell1', 'Cell2', 'Distance']
-    #-
-    Spatial_Net = KNN_df.copy()
-    Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance']>0,]
-    id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))
-    Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
-    Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
-    #-
-    return Spatial_Net
-
-
-def sparse_mx_to_torch_sparse_tensor(sparse_mx):
-    """Convert a scipy sparse matrix to a torch sparse tensor."""
-    sparse_mx = sparse_mx.tocoo().astype(np.float32)
-    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
-    values = torch.from_numpy(sparse_mx.data)
-    shape = torch.Size(sparse_mx.shape)
-    return torch.sparse.FloatTensor(indices, values, shape)
-
-
-def preprocess_graph(adj):
-    adj = sp.coo_matrix(adj)
-    adj_ = adj + sp.eye(adj.shape[0])
-    rowsum = np.array(adj_.sum(1))
-    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
-    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
-    return sparse_mx_to_torch_sparse_tensor(adj_normalized)
-
-
-
-def Construct_Adjacency_Matrix(adata,Params, verbose=True):
-    # Construct the neighbor graph
-    Spatial_Net = Cal_Spatial_Net(adata, n_neighbors=Params.n_neighbors)
-    #-
-    if verbose:
-        print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))
-        print('%.2f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))
-    #-
-    cells = np.array(adata.obs.index)
-    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
-    #-
-    G_df = Spatial_Net.copy()
-    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
-    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)
-    #-
-    if Params.weighted_adj:
-        distance_normalized = G_df.Distance/(max(G_df.Distance)+1)
-        adj_org = sp.coo_matrix((np.exp(-distance_normalized**2/(2)), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
-    else:
-        adj_org = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
-    #-
-    adj_m1 = adj_org
-    adj_norm_m1 = preprocess_graph(adj_m1)
-    adj_label_m1 = adj_m1 + sp.eye(adj_m1.shape[0])
-    norm_m1 = adj_m1.shape[0] * adj_m1.shape[0] / float((adj_m1.shape[0] * adj_m1.shape[0] - adj_m1.sum()) * 2)
-    #-
-    graph_dict = {
-        "adj_org": adj_org,
-        "adj_norm": adj_norm_m1,
-        "norm_value": norm_m1
-    }
-    #-
-    return graph_dict

gsmap-1.66/src/gsMap/GNN_VAE/model.py (removed)
@@ -1,87 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Mon Jul 3 11:42:44 2023
-
-@author: songliyang
-"""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch_geometric.nn import GATConv
-
-
-def full_block(in_features, out_features, p_drop):
-    return nn.Sequential(nn.Linear(in_features, out_features),
-                         nn.BatchNorm1d(out_features),
-                         nn.ELU(),
-                         nn.Dropout(p=p_drop))
-
-
-class GNN(nn.Module):
-    def __init__(self, in_features, out_features, dr=0, act=F.relu, heads=1):
-        super().__init__()
-        self.conv1 = GATConv(in_features, out_features, heads)
-        self.act = act
-        self.dr = dr
-    #-
-    def forward(self, x, edge_index):
-        out = self.conv1(x, edge_index)
-        out = self.act(out)
-        out = F.dropout(out, self.dr, self.training)
-        return out
-
-
-
-class GNN_VAE_Model(nn.Module):
-    def __init__(self, input_dim, params, num_classes=1):
-        super(GNN_VAE_Model, self).__init__()
-        self.var = params.var
-        self.num_classes = num_classes
-
-        # Encoder
-        self.encoder = nn.Sequential()
-        self.encoder.add_module('encoder_L1', full_block(input_dim, params.feat_hidden1, params.p_drop))
-        self.encoder.add_module('encoder_L2', full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop))
-
-        # GNN (GAT)
-        self.gn1 = GNN(params.feat_hidden2, params.gat_hidden1, params.p_drop, act=F.relu, heads=params.nheads)
-        self.gn2 = GNN(params.gat_hidden1*params.nheads, params.gat_hidden2, params.p_drop, act=lambda x: x)
-        self.gn3 = GNN(params.gat_hidden1*params.nheads, params.gat_hidden2, params.p_drop, act=lambda x: x)
-
-        # Decoder
-        self.decoder = nn.Sequential()
-        self.decoder.add_module('decoder_L1', full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop))
-        self.decoder.add_module('decoder_L2', full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop))
-        self.decoder.add_module('decoder_output', nn.Sequential(nn.Linear(params.feat_hidden1, input_dim)))
-
-        # Cluster
-        self.cluster = nn.Sequential()
-        self.cluster.add_module('cluster_L1', full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop))
-        self.cluster.add_module('cluster_output', nn.Linear(params.feat_hidden2, self.num_classes))
-
-    def encode(self, x, adj):
-        feat_x = self.encoder(x)
-        hidden1 = self.gn1(feat_x, adj)
-        mu = self.gn2(hidden1, adj)
-        if self.var:
-            logvar = self.gn3(hidden1, adj)
-            return mu, logvar
-        else:
-            return mu, None
-
-    def reparameterize(self, mu, logvar):
-        if self.training and logvar is not None:
-            std = torch.exp(logvar)
-            eps = torch.randn_like(std)
-            return eps.mul(std).add_(mu)
-        else:
-            return mu
-
-    def forward(self, x, adj):
-        mu, logvar = self.encode(x, adj)
-        gnn_z = self.reparameterize(mu, logvar)
-        x_reconstructed = self.decoder(gnn_z)
-        pred_label = F.softmax(self.cluster(gnn_z), dim=1)
-        return pred_label, x_reconstructed, gnn_z, mu, logvar

gsmap-1.66/src/gsMap/GNN_VAE/train.py (removed)
@@ -1,97 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Tue Jul 4 19:58:58 2023
-
-@author: songliyang
-"""
-import time
-
-import torch
-from progress.bar import Bar
-
-from gsMap.GNN_VAE.model import GNN_VAE_Model
-
-
-def reconstruction_loss(decoded, x):
-    loss_fn = torch.nn.MSELoss()
-    loss = loss_fn(decoded, x)
-    return loss
-
-
-def label_loss(pred_label, true_label):
-    loss_fn = torch.nn.CrossEntropyLoss()
-    loss = loss_fn(pred_label, true_label)
-    return loss
-
-
-class Model_Train:
-    def __init__(self, node_X, graph_dict, params, label=None):
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        torch.cuda.empty_cache()
-
-        self.params = params
-        self.device = device
-        self.epochs = params.epochs
-        self.node_X = torch.FloatTensor(node_X.copy()).to(device)
-        self.adj_norm = graph_dict["adj_norm"].to(device).coalesce()
-        self.label = label
-        self.num_classes = 1
-
-        if not self.label is None:
-            self.label = torch.tensor(self.label).to(self.device)
-            self.num_classes = len(self.label.unique())
-
-        # Set Model
-        self.model = GNN_VAE_Model(self.params.feat_cell, self.params, self.num_classes).to(device)
-        self.optimizer = torch.optim.Adam(params=list(self.model.parameters()),
-                                          lr=self.params.gat_lr, weight_decay=self.params.gcn_decay)
-
-    # Train
-    def run_train(self):
-        self.model.train()
-        prev_loss = float('inf')
-
-        bar = Bar('GAT-AE model train:', max=self.epochs)
-        bar.check_tty = False
-        for epoch in range(self.epochs):
-            start_time = time.time()
-            self.model.train()
-            self.optimizer.zero_grad()
-            pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
-            loss_rec = reconstruction_loss(de_feat, self.node_X)
-
-            # Check whether annotation was provided
-            if not self.label is None:
-                loss_pre = label_loss(pred_label, self.label)
-                loss = (self.params.rec_w * loss_rec) + (self.params.label_w * loss_pre)
-            else:
-                loss = loss_rec
-
-            loss.backward()
-            self.optimizer.step()
-
-            # Update process
-            end_time = time.time()
-            batch_time = end_time - start_time
-
-
-            bar_str = '{} / {} | Left time: {batch_time:.2f} mins| Loss: {loss:.4f}'
-            bar.suffix = bar_str.format(epoch + 1, self.epochs,
-                                        batch_time=batch_time * (self.epochs - epoch) / 60, loss=loss.item())
-            bar.next()
-
-            # Check convergence
-            if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
-                print('\nConvergence reached. Training stopped.')
-                break
-
-            prev_loss = loss.item()
-
-        bar.finish()
-    #-
-    def get_latent(self):
-        self.model.eval()
-        pred, de_fea, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
-        latent_z = latent_z.data.cpu().numpy()
-        return latent_z

gsmap-1.66/src/gsMap/find_latent_representation.py (removed)
@@ -1,145 +0,0 @@
-import logging
-import random
-
-import numpy as np
-import pandas as pd
-import scanpy as sc
-import torch
-from sklearn import preprocessing
-
-from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix
-from gsMap.GNN_VAE.train import Model_Train
-from gsMap.config import FindLatentRepresentationsConfig
-
-logger = logging.getLogger(__name__)
-
-def set_seed(seed_value):
-    """
-    Set seed for reproducibility in PyTorch.
-    """
-    torch.manual_seed(seed_value)  # Set the seed for PyTorch
-    np.random.seed(seed_value)  # Set the seed for NumPy
-    random.seed(seed_value)  # Set the seed for Python random module
-    if torch.cuda.is_available():
-        logger.info('Running use GPU')
-        torch.cuda.manual_seed(seed_value)  # Set seed for all CUDA devices
-        torch.cuda.manual_seed_all(seed_value)  # Set seed for all CUDA devices
-    else:
-        logger.info('Running use CPU')
-
-
-# The class for finding latent representations
-class Latent_Representation_Finder:
-
-    def __init__(self, adata, args: FindLatentRepresentationsConfig):
-        self.adata = adata.copy()
-        self.Params = args
-
-        # Standard process
-        if self.Params.data_layer == 'count' or self.Params.data_layer == 'counts':
-            self.adata.X = self.adata.layers[self.Params.data_layer]
-            sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3", n_top_genes=self.Params.feat_cell)
-            sc.pp.normalize_total(self.adata, target_sum=1e4)
-            sc.pp.log1p(self.adata)
-            sc.pp.scale(self.adata)
-        else:
-            if self.Params.data_layer != 'X':
-                self.adata.X = self.adata.layers[self.Params.data_layer]
-            sc.pp.highly_variable_genes(self.adata, n_top_genes=self.Params.feat_cell)
-
-    def Run_GNN_VAE(self, label, verbose='whole ST data'):
-
-        # Construct the neighbouring graph
-        graph_dict = Construct_Adjacency_Matrix(self.adata, self.Params)
-
-        # Process the feature matrix
-        node_X = self.adata[:, self.adata.var.highly_variable].X
-        logger.info(f'The shape of feature matrix is {node_X.shape}.')
-        if self.Params.input_pca:
-            node_X = sc.pp.pca(node_X, n_comps=self.Params.n_comps)
-
-        # Update the input shape
-        self.Params.n_nodes = node_X.shape[0]
-        self.Params.feat_cell = node_X.shape[1]
-
-        # Run GNN-VAE
-        logger.info(f'------Finding latent representations for {verbose}...')
-        gvae = Model_Train(node_X, graph_dict, self.Params, label)
-        gvae.run_train()
-
-        return gvae.get_latent()
-
-    def Run_PCA(self):
-        sc.tl.pca(self.adata)
-        return self.adata.obsm['X_pca'][:, 0:self.Params.n_comps]
-
-
-def run_find_latent_representation(args: FindLatentRepresentationsConfig):
-    set_seed(2024)
-    num_features = args.feat_cell
-    args.hdf5_with_latent_path.parent.mkdir(parents=True, exist_ok=True, mode=0o755)
-    # Load the ST data
-    logger.info(f'------Loading ST data of {args.sample_name}...')
-    adata = sc.read_h5ad(f'{args.input_hdf5_path}')
-    adata.var_names_make_unique()
-    adata.X = adata.layers[args.data_layer] if args.data_layer in adata.layers.keys() else adata.X
-    logger.info('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))
-    # Load the cell type annotation
-    if not args.annotation is None:
-        # remove cells without enough annotations
-        adata = adata[~pd.isnull(adata.obs[args.annotation]), :]
-        num = adata.obs[args.annotation].value_counts()
-        adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list())]
-
-        le = preprocessing.LabelEncoder()
-        le.fit(adata.obs[args.annotation])
-        adata.obs['categorical_label'] = le.transform(adata.obs[args.annotation])
-        label = adata.obs['categorical_label'].to_list()
-    else:
-        label = None
-    # Find latent representations
-    latent_rep = Latent_Representation_Finder(adata, args)
-    latent_GVAE = latent_rep.Run_GNN_VAE(label)
-    latent_PCA = latent_rep.Run_PCA()
-    # Add latent representations to the spe data
-    logger.info(f'------Adding latent representations...')
-    adata.obsm["latent_GVAE"] = latent_GVAE
-    adata.obsm["latent_PCA"] = latent_PCA
-    # Run umap based on latent representations
-    for name in ['latent_GVAE', 'latent_PCA']:
-        sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
-        sc.tl.umap(adata)
-        adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
-
-    # Find the latent representations hierarchically (optionally)
-    if not args.annotation is None and args.hierarchically:
-        logger.info(f'------Finding latent representations hierarchically...')
-        PCA_all = pd.DataFrame()
-        GVAE_all = pd.DataFrame()
-
-        for ct in adata.obs[args.annotation].unique():
-            adata_part = adata[adata.obs[args.annotation] == ct, :]
-            logger.info(adata_part.shape)
-
-            # Find latent representations for the selected ct
-            latent_rep = Latent_Representation_Finder(adata_part, args)
-
-            latent_PCA_part = pd.DataFrame(latent_rep.Run_PCA())
-            if adata_part.shape[0] <= args.n_comps:
-                latent_GVAE_part = latent_PCA_part
-            else:
-                latent_GVAE_part = pd.DataFrame(latent_rep.Run_GNN_VAE(label=None, verbose=ct))
-
-            latent_GVAE_part.index = adata_part.obs_names
-            latent_PCA_part.index = adata_part.obs_names
-
-            GVAE_all = pd.concat((GVAE_all, latent_GVAE_part), axis=0)
-            PCA_all = pd.concat((PCA_all, latent_PCA_part), axis=0)
-
-            args.feat_cell = num_features
-
-        adata.obsm["latent_GVAE_hierarchy"] = np.array(GVAE_all.loc[adata.obs_names,])
-        adata.obsm["latent_PCA_hierarchy"] = np.array(PCA_all.loc[adata.obs_names,])
-    logger.info(f'------Saving ST data...')
-    adata.write(args.hdf5_with_latent_path)