gsMap 1.65.tar.gz → 1.67.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. {gsmap-1.65 → gsmap-1.67}/PKG-INFO +2 -2
  2. {gsmap-1.65 → gsmap-1.67}/docs/source/mouse_example.md +1 -1
  3. {gsmap-1.65 → gsmap-1.67}/pyproject.toml +1 -1
  4. gsmap-1.67/src/gsMap/GNN_VAE/adjacency_matrix.py +75 -0
  5. gsmap-1.67/src/gsMap/GNN_VAE/model.py +89 -0
  6. gsmap-1.67/src/gsMap/GNN_VAE/train.py +86 -0
  7. {gsmap-1.65 → gsmap-1.67}/src/gsMap/__init__.py +1 -1
  8. {gsmap-1.65 → gsmap-1.67}/src/gsMap/config.py +4 -4
  9. gsmap-1.67/src/gsMap/find_latent_representation.py +145 -0
  10. {gsmap-1.65 → gsmap-1.67}/src/gsMap/format_sumstats.py +20 -20
  11. gsmap-1.67/src/gsMap/latent_to_gene.py +234 -0
  12. {gsmap-1.65 → gsmap-1.67}/src/gsMap/spatial_ldsc_multiple_sumstats.py +0 -2
  13. gsmap-1.65/src/gsMap/GNN_VAE/adjacency_matrix.py +0 -95
  14. gsmap-1.65/src/gsMap/GNN_VAE/model.py +0 -87
  15. gsmap-1.65/src/gsMap/GNN_VAE/train.py +0 -97
  16. gsmap-1.65/src/gsMap/find_latent_representation.py +0 -145
  17. gsmap-1.65/src/gsMap/latent_to_gene.py +0 -218
  18. {gsmap-1.65 → gsmap-1.67}/.github/workflows/publish-to-pypi.yml +0 -0
  19. {gsmap-1.65 → gsmap-1.67}/.gitignore +0 -0
  20. {gsmap-1.65 → gsmap-1.67}/LICENSE +0 -0
  21. {gsmap-1.65 → gsmap-1.67}/README.md +0 -0
  22. {gsmap-1.65 → gsmap-1.67}/docs/Makefile +0 -0
  23. {gsmap-1.65 → gsmap-1.67}/docs/make.bat +0 -0
  24. {gsmap-1.65 → gsmap-1.67}/docs/requirements.txt +0 -0
  25. {gsmap-1.65 → gsmap-1.67}/docs/source/_static/schematic.svg +0 -0
  26. {gsmap-1.65 → gsmap-1.67}/docs/source/api/cauchy_combination.rst +0 -0
  27. {gsmap-1.65 → gsmap-1.67}/docs/source/api/find_latent_representations.rst +0 -0
  28. {gsmap-1.65 → gsmap-1.67}/docs/source/api/format_sumstats.rst +0 -0
  29. {gsmap-1.65 → gsmap-1.67}/docs/source/api/generate_ldscore.rst +0 -0
  30. {gsmap-1.65 → gsmap-1.67}/docs/source/api/latent_to_gene.rst +0 -0
  31. {gsmap-1.65 → gsmap-1.67}/docs/source/api/quick_mode.rst +0 -0
  32. {gsmap-1.65 → gsmap-1.67}/docs/source/api/report.rst +0 -0
  33. {gsmap-1.65 → gsmap-1.67}/docs/source/api/spatial_ldsc.rst +0 -0
  34. {gsmap-1.65 → gsmap-1.67}/docs/source/api.rst +0 -0
  35. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_Height.json +0 -0
  36. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_IQ.json +0 -0
  37. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_MCHC.json +0 -0
  38. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/cortex/Cortex_151507_SCZ.json +0 -0
  39. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_Height.json +0 -0
  40. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_IQ.json +0 -0
  41. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_MCHC.json +0 -0
  42. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/mouse_embryo/E16.5_E1S1_SCZ.json +0 -0
  43. {gsmap-1.65 → gsmap-1.67}/docs/source/charts/test.json +0 -0
  44. {gsmap-1.65 → gsmap-1.67}/docs/source/conf.py +0 -0
  45. {gsmap-1.65 → gsmap-1.67}/docs/source/data.rst +0 -0
  46. {gsmap-1.65 → gsmap-1.67}/docs/source/data_format.md +0 -0
  47. {gsmap-1.65 → gsmap-1.67}/docs/source/index.rst +0 -0
  48. {gsmap-1.65 → gsmap-1.67}/docs/source/install.rst +0 -0
  49. {gsmap-1.65 → gsmap-1.67}/docs/source/mouse.rst +0 -0
  50. {gsmap-1.65 → gsmap-1.67}/docs/source/quick_mode.md +0 -0
  51. {gsmap-1.65 → gsmap-1.67}/docs/source/release.rst +0 -0
  52. {gsmap-1.65 → gsmap-1.67}/docs/source/tutorials.rst +0 -0
  53. {gsmap-1.65 → gsmap-1.67}/schematic.png +0 -0
  54. {gsmap-1.65 → gsmap-1.67}/src/gsMap/GNN_VAE/__init__.py +0 -0
  55. {gsmap-1.65 → gsmap-1.67}/src/gsMap/__main__.py +0 -0
  56. {gsmap-1.65 → gsmap-1.67}/src/gsMap/cauchy_combination_test.py +0 -0
  57. {gsmap-1.65 → gsmap-1.67}/src/gsMap/diagnosis.py +0 -0
  58. {gsmap-1.65 → gsmap-1.67}/src/gsMap/generate_ldscore.py +0 -0
  59. {gsmap-1.65 → gsmap-1.67}/src/gsMap/main.py +0 -0
  60. {gsmap-1.65 → gsmap-1.67}/src/gsMap/report.py +0 -0
  61. {gsmap-1.65 → gsmap-1.67}/src/gsMap/run_all_mode.py +0 -0
  62. {gsmap-1.65 → gsmap-1.67}/src/gsMap/setup.py +0 -0
  63. {gsmap-1.65 → gsmap-1.67}/src/gsMap/templates/report_template.html +0 -0
  64. {gsmap-1.65 → gsmap-1.67}/src/gsMap/utils/__init__.py +0 -0
  65. {gsmap-1.65 → gsmap-1.67}/src/gsMap/utils/generate_r2_matrix.py +0 -0
  66. {gsmap-1.65 → gsmap-1.67}/src/gsMap/utils/jackknife.py +0 -0
  67. {gsmap-1.65 → gsmap-1.67}/src/gsMap/utils/make_annotations.py +0 -0
  68. {gsmap-1.65 → gsmap-1.67}/src/gsMap/utils/manhattan_plot.py +0 -0
  69. {gsmap-1.65 → gsmap-1.67}/src/gsMap/utils/regression_read.py +0 -0
  70. {gsmap-1.65 → gsmap-1.67}/src/gsMap/visualize.py +0 -0
--- gsmap-1.65/PKG-INFO
+++ gsmap-1.67/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: gsMap
-Version: 1.65
+Version: 1.67
 Summary: Genetics-informed pathogenic spatial mapping
 Author-email: liyang <songliyang@westlake.edu.cn>, wenhao <chenwenhao@westlake.edu.cn>
 Requires-Python: >=3.8
@@ -27,7 +27,7 @@ Requires-Dist: pyfiglet
 Requires-Dist: plotly
 Requires-Dist: kaleido
 Requires-Dist: jinja2
-Requires-Dist: scanpy
+Requires-Dist: scanpy >=1.8.0
 Requires-Dist: zarr
 Requires-Dist: bitarray
 Requires-Dist: pyarrow

--- gsmap-1.65/docs/source/mouse_example.md
+++ gsmap-1.67/docs/source/mouse_example.md
@@ -209,7 +209,7 @@ gsmap run_cauchy_combination \
     --annotation 'annotation'
 ```
 
-### 6. report generation
+### 6. report generation (optional)
 
 **Objective**: Generate gsMap reports, including visualizations of mapping results and diagnostic plots.
 
--- gsmap-1.65/pyproject.toml
+++ gsmap-1.67/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
     'plotly',
     'kaleido',
     'jinja2',
-    'scanpy',
+    'scanpy >=1.8.0',
     'zarr',
     'bitarray',
     'pyarrow',

--- /dev/null
+++ gsmap-1.67/src/gsMap/GNN_VAE/adjacency_matrix.py
@@ -0,0 +1,75 @@
+import numpy as np
+import pandas as pd
+import scipy.sparse as sp
+from sklearn.neighbors import NearestNeighbors
+import torch
+
+def cal_spatial_net(adata, n_neighbors=5, verbose=True):
+    """Construct the spatial neighbor network."""
+    if verbose:
+        print('------Calculating spatial graph...')
+    coor = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
+    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
+    distances, indices = nbrs.kneighbors(coor)
+    n_cells, n_neighbors = indices.shape
+    cell_indices = np.arange(n_cells)
+    cell1 = np.repeat(cell_indices, n_neighbors)
+    cell2 = indices.flatten()
+    distance = distances.flatten()
+    knn_df = pd.DataFrame({'Cell1': cell1, 'Cell2': cell2, 'Distance': distance})
+    knn_df = knn_df[knn_df['Distance'] > 0].copy()
+    cell_id_map = dict(zip(cell_indices, coor.index))
+    knn_df['Cell1'] = knn_df['Cell1'].map(cell_id_map)
+    knn_df['Cell2'] = knn_df['Cell2'].map(cell_id_map)
+    return knn_df
+
+def sparse_mx_to_torch_sparse_tensor(sparse_mx):
+    """Convert a scipy sparse matrix to a torch sparse tensor."""
+    sparse_mx = sparse_mx.tocoo().astype(np.float32)
+    indices = torch.from_numpy(
+        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
+    )
+    values = torch.from_numpy(sparse_mx.data)
+    shape = torch.Size(sparse_mx.shape)
+    return torch.sparse.FloatTensor(indices, values, shape)
+
+def preprocess_graph(adj):
+    """Symmetrically normalize the adjacency matrix."""
+    adj = sp.coo_matrix(adj)
+    adj_ = adj + sp.eye(adj.shape[0])
+    rowsum = np.array(adj_.sum(1)).flatten()
+    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5))
+    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
+    return sparse_mx_to_torch_sparse_tensor(adj_normalized)
+
+def construct_adjacency_matrix(adata, params, verbose=True):
+    """Construct the adjacency matrix from spatial data."""
+    spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
+    if verbose:
+        num_edges = spatial_net.shape[0]
+        num_cells = adata.n_obs
+        print(f'The graph contains {num_edges} edges, {num_cells} cells.')
+        print(f'{num_edges / num_cells:.2f} neighbors per cell on average.')
+    cell_ids = {cell: idx for idx, cell in enumerate(adata.obs.index)}
+    spatial_net['Cell1'] = spatial_net['Cell1'].map(cell_ids)
+    spatial_net['Cell2'] = spatial_net['Cell2'].map(cell_ids)
+    if params.weighted_adj:
+        distance_normalized = spatial_net['Distance'] / (spatial_net['Distance'].max() + 1)
+        weights = np.exp(-0.5 * distance_normalized ** 2)
+        adj_org = sp.coo_matrix(
+            (weights, (spatial_net['Cell1'], spatial_net['Cell2'])),
+            shape=(adata.n_obs, adata.n_obs)
+        )
+    else:
+        adj_org = sp.coo_matrix(
+            (np.ones(spatial_net.shape[0]), (spatial_net['Cell1'], spatial_net['Cell2'])),
+            shape=(adata.n_obs, adata.n_obs)
+        )
+    adj_norm = preprocess_graph(adj_org)
+    norm_value = adj_org.shape[0] ** 2 / ((adj_org.shape[0] ** 2 - adj_org.sum()) * 2)
+    graph_dict = {
+        "adj_org": adj_org,
+        "adj_norm": adj_norm,
+        "norm_value": norm_value
+    }
+    return graph_dict
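
For orientation, the new module builds a k-nearest-neighbor graph from adata.obsm['spatial'], symmetrically normalizes it (D^-1/2 (A + I) D^-1/2 via preprocess_graph), and hands it back as a torch sparse tensor. A minimal sketch of how it might be exercised; the synthetic AnnData and the SimpleNamespace params stand-in are illustrative, not from the package:

import numpy as np
import anndata as ad
from types import SimpleNamespace

from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix

# Synthetic fixture: 100 cells with random 2-D spatial coordinates.
adata = ad.AnnData(X=np.random.rand(100, 50).astype(np.float32))
adata.obsm['spatial'] = np.random.rand(100, 2)

# Only the two attributes construct_adjacency_matrix reads are supplied here.
params = SimpleNamespace(n_neighbors=11, weighted_adj=False)

graph = construct_adjacency_matrix(adata, params)
print(graph['adj_org'].shape)   # (100, 100) raw scipy COO adjacency
print(graph['adj_norm'].shape)  # normalized torch sparse tensor, same shape
print(graph['norm_value'])      # scalar carried along (not consumed by the trainer below)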

--- /dev/null
+++ gsmap-1.67/src/gsMap/GNN_VAE/model.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.nn import GATConv
+
+def full_block(in_features, out_features, p_drop):
+    return nn.Sequential(
+        nn.Linear(in_features, out_features),
+        nn.BatchNorm1d(out_features),
+        nn.ELU(),
+        nn.Dropout(p=p_drop)
+    )
+
+class GATModel(nn.Module):
+    def __init__(self, input_dim, params, num_classes=1):
+        super().__init__()
+        self.var = params.var
+        self.num_classes = num_classes
+        self.params = params
+
+        # Encoder
+        self.encoder = nn.Sequential(
+            full_block(input_dim, params.feat_hidden1, params.p_drop),
+            full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop)
+        )
+
+        # GAT Layers
+        self.gat1 = GATConv(
+            in_channels=params.feat_hidden2,
+            out_channels=params.gat_hidden1,
+            heads=params.nheads,
+            dropout=params.p_drop
+        )
+        self.gat2 = GATConv(
+            in_channels=params.gat_hidden1 * params.nheads,
+            out_channels=params.gat_hidden2,
+            heads=1,
+            concat=False,
+            dropout=params.p_drop
+        )
+        if self.var:
+            self.gat3 = GATConv(
+                in_channels=params.gat_hidden1 * params.nheads,
+                out_channels=params.gat_hidden2,
+                heads=1,
+                concat=False,
+                dropout=params.p_drop
+            )
+
+        # Decoder
+        self.decoder = nn.Sequential(
+            full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
+            full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop),
+            nn.Linear(params.feat_hidden1, input_dim)
+        )
+
+        # Clustering Layer
+        self.cluster = nn.Sequential(
+            full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
+            nn.Linear(params.feat_hidden2, self.num_classes)
+        )
+
+    def encode(self, x, edge_index):
+        x = self.encoder(x)
+        x = self.gat1(x, edge_index)
+        x = F.relu(x)
+        x = F.dropout(x, p=self.params.p_drop, training=self.training)
+
+        mu = self.gat2(x, edge_index)
+        if self.var:
+            logvar = self.gat3(x, edge_index)
+            return mu, logvar
+        else:
+            return mu, None
+
+    def reparameterize(self, mu, logvar):
+        if self.training and logvar is not None:
+            std = torch.exp(0.5 * logvar)
+            eps = torch.randn_like(std)
+            return eps * std + mu
+        else:
+            return mu
+
+    def forward(self, x, edge_index):
+        mu, logvar = self.encode(x, edge_index)
+        z = self.reparameterize(mu, logvar)
+        x_reconstructed = self.decoder(z)
+        pred_label = F.softmax(self.cluster(z), dim=1)
+        return pred_label, x_reconstructed, z, mu, logvar
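
GATModel is a GAT autoencoder with an optional variational branch: encode() yields mu (plus logvar when params.var is set), reparameterize() applies the usual mu + eps * std trick only in training mode, and both the decoder and the clustering head consume the latent z. A quick smoke test, assuming torch_geometric is installed; the params values are invented for illustration:

import torch
from types import SimpleNamespace

from gsMap.GNN_VAE.model import GATModel

params = SimpleNamespace(
    feat_hidden1=256, feat_hidden2=128,
    gat_hidden1=64, gat_hidden2=30,
    nheads=3, p_drop=0.1, var=False,
)

model = GATModel(input_dim=300, params=params, num_classes=5)
x = torch.randn(10, 300)                    # 10 nodes, 300 features each
edge_index = torch.randint(0, 10, (2, 40))  # 40 random directed edges
pred_label, x_rec, z, mu, logvar = model(x, edge_index)
print(z.shape, x_rec.shape, pred_label.shape)  # (10, 30), (10, 300), (10, 5)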

--- /dev/null
+++ gsmap-1.67/src/gsMap/GNN_VAE/train.py
@@ -0,0 +1,86 @@
+import logging
+import time
+
+import torch
+import torch.nn.functional as F
+from progress.bar import Bar
+
+from gsMap.GNN_VAE.model import GATModel
+
+logger = logging.getLogger(__name__)
+
+
+def reconstruction_loss(decoded, x):
+    """Compute the mean squared error loss."""
+    return F.mse_loss(decoded, x)
+
+
+def label_loss(pred_label, true_label):
+    """Compute the cross-entropy loss."""
+    return F.cross_entropy(pred_label, true_label)
+
+class ModelTrainer:
+    def __init__(self, node_x, graph_dict, params, label=None):
+        """Initialize the ModelTrainer with data and hyperparameters."""
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.params = params
+        self.epochs = params.epochs
+        self.node_x = torch.FloatTensor(node_x).to(self.device)
+        self.adj_norm = graph_dict["adj_norm"].to(self.device).coalesce()
+        self.label = label
+        self.num_classes = 1
+
+        if self.label is not None:
+            self.label = torch.tensor(self.label).to(self.device)
+            self.num_classes = len(torch.unique(self.label))
+
+        # Set up the model
+        self.model = GATModel(self.params.feat_cell, self.params, self.num_classes).to(self.device)
+        self.optimizer = torch.optim.Adam(
+            self.model.parameters(),
+            lr=self.params.gat_lr,
+            weight_decay=self.params.gcn_decay
+        )
+
+    def run_train(self):
+        """Train the model."""
+        self.model.train()
+        prev_loss = float('inf')
+        bar = Bar('GAT-AE model train:', max=self.epochs)
+        bar.check_tty = False
+
+        logger.info('Start training...')
+        for epoch in range(self.epochs):
+            start_time = time.time()
+            self.optimizer.zero_grad()
+            pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_x, self.adj_norm)
+            loss_rec = reconstruction_loss(de_feat, self.node_x)
+
+            if self.label is not None:
+                loss_pre = label_loss(pred_label, self.label)
+                loss = self.params.rec_w * loss_rec + self.params.label_w * loss_pre
+            else:
+                loss = loss_rec
+
+            loss.backward()
+            self.optimizer.step()
+
+            batch_time = time.time() - start_time
+            left_time = batch_time * (self.epochs - epoch - 1) / 60  # in minutes
+
+            bar.suffix = f'{epoch + 1} / {self.epochs} | Left time: {left_time:.2f} mins | Loss: {loss.item():.4f}'
+            bar.next()
+
+            if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
+                logger.info('\nConvergence reached. Training stopped.')
+                break
+
+            prev_loss = loss.item()
+        bar.finish()
+
+    def get_latent(self):
+        """Retrieve the latent representation from the model."""
+        self.model.eval()
+        with torch.no_grad():
+            _, _, latent_z, _, _ = self.model(self.node_x, self.adj_norm)
+        return latent_z.cpu().numpy()
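
ModelTrainer ties the two files together: MSE reconstruction loss, an optional weighted cross-entropy term when labels are supplied, Adam with params.gat_lr / params.gcn_decay, and early stopping once the loss change stays within params.convergence_threshold after epoch 200. A hypothetical end-to-end sketch mirroring how run_gnn_vae drives it (hyperparameter names are taken from the diff; every value is invented, and the adjacency is passed to GATConv as a torch sparse tensor exactly as the package does):

import numpy as np
import anndata as ad
from types import SimpleNamespace

from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix
from gsMap.GNN_VAE.train import ModelTrainer

adata = ad.AnnData(X=np.random.rand(100, 300).astype(np.float32))
adata.obsm['spatial'] = np.random.rand(100, 2)

params = SimpleNamespace(
    n_neighbors=11, weighted_adj=False,                   # graph construction
    feat_cell=300, feat_hidden1=256, feat_hidden2=128,    # encoder widths
    gat_hidden1=64, gat_hidden2=30, nheads=3,
    p_drop=0.1, var=False,                                # model
    epochs=50, gat_lr=1e-3, gcn_decay=0.01,
    rec_w=1.0, label_w=1.0, convergence_threshold=1e-4,   # training
)

graph_dict = construct_adjacency_matrix(adata, params)
trainer = ModelTrainer(np.asarray(adata.X), graph_dict, params, label=None)
trainer.run_train()
latent = trainer.get_latent()  # (100, 30) numpy array, one row per cell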

--- gsmap-1.65/src/gsMap/__init__.py
+++ gsmap-1.67/src/gsMap/__init__.py
@@ -2,4 +2,4 @@
 Genetics-informed pathogenic spatial mapping
 '''
 
-__version__ = '1.65'
+__version__ = '1.67'

--- gsmap-1.65/src/gsMap/config.py
+++ gsmap-1.67/src/gsMap/config.py
@@ -55,7 +55,8 @@ def add_find_latent_representations_args(parser):
     add_shared_args(parser)
     parser.add_argument('--input_hdf5_path', required=True, type=str, help='Path to the input HDF5 file.')
     parser.add_argument('--annotation', required=True, type=str, help='Name of the annotation in adata.obs to use.')
-    parser.add_argument('--data_layer', required=True, type=str, help='Data layer for gene expression (e.g., "counts", "log1p").')
+    parser.add_argument('--data_layer', type=str, default='counts', required=True,
+                        help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
     parser.add_argument('--epochs', type=int, default=300, help='Number of training epochs.')
     parser.add_argument('--feat_hidden1', type=int, default=256, help='Neurons in the first hidden layer.')
     parser.add_argument('--feat_hidden2', type=int, default=128, help='Neurons in the second hidden layer.')
@@ -66,7 +67,6 @@ def add_find_latent_representations_args(parser):
     parser.add_argument('--n_neighbors', type=int, default=11, help='Number of neighbors for GAT.')
     parser.add_argument('--n_comps', type=int, default=300, help='Number of principal components for PCA.')
    parser.add_argument('--weighted_adj', action='store_true', help='Use weighted adjacency in GAT.')
-    parser.add_argument('--var', action='store_true', help='Enable variance calculations.')
     parser.add_argument('--convergence_threshold', type=float, default=1e-4, help='Threshold for convergence.')
     parser.add_argument('--hierarchically', action='store_true', help='Enable hierarchical latent representation finding.')
 
@@ -236,8 +236,8 @@ def add_run_all_mode_args(parser):
                         help='Path to the input spatial transcriptomics data (H5AD format).')
     parser.add_argument('--annotation', type=str, required=True,
                         help='Name of the annotation in adata.obs to use.')
-    parser.add_argument('--data_layer', type=str, default='X',
-                        help='Data layer of h5ad for gene expression (e.g., "counts", "log1p", "X").')
+    parser.add_argument('--data_layer', type=str, default='counts', required=True,
+                        help='Data layer for gene expression (e.g., "count", "counts", "log1p").')
 
     # GWAS Data Parameters
     parser.add_argument('--trait_name', type=str, help='Name of the trait for GWAS analysis (required if sumstats_file is provided).')
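
One quirk in the new --data_layer definition is worth flagging: argparse never applies default when required=True, so the 'counts' default is dead code and the flag must always be passed explicitly. A self-contained demonstration of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data_layer', type=str, default='counts', required=True)

print(parser.parse_args(['--data_layer', 'log1p']).data_layer)  # log1p
# parser.parse_args([]) would exit with:
# error: the following arguments are required: --data_layer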

--- /dev/null
+++ gsmap-1.67/src/gsMap/find_latent_representation.py
@@ -0,0 +1,145 @@
+import logging
+import random
+import numpy as np
+import scanpy as sc
+import torch
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import LabelEncoder
+from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix
+from gsMap.GNN_VAE.train import ModelTrainer
+from gsMap.config import FindLatentRepresentationsConfig
+
+logger = logging.getLogger(__name__)
+
+
+def set_seed(seed_value):
+    """
+    Set seed for reproducibility in PyTorch and other libraries.
+    """
+    torch.manual_seed(seed_value)
+    np.random.seed(seed_value)
+    random.seed(seed_value)
+    if torch.cuda.is_available():
+        logger.info('Using GPU for computations.')
+        torch.cuda.manual_seed(seed_value)
+        torch.cuda.manual_seed_all(seed_value)
+    else:
+        logger.info('Using CPU for computations.')
+
+def preprocess_data(adata, params):
+    """
+    Preprocess the AnnData
+    """
+    logger.info('Preprocessing data...')
+    adata.var_names_make_unique()
+
+    sc.pp.filter_genes(adata, min_cells=30)
+    if params.data_layer in adata.layers.keys():
+        adata.X = adata.layers[params.data_layer]
+    else:
+        raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
+
+    if params.data_layer in ['count', 'counts']:
+
+        sc.pp.normalize_total(adata, target_sum=1e4)
+        sc.pp.log1p(adata)
+
+        # Identify highly variable genes
+        sc.pp.highly_variable_genes(
+            adata,
+            flavor="seurat_v3",
+            n_top_genes=params.feat_cell,
+        )
+
+    elif params.data_layer in adata.layers.keys():
+        logger.info(f'Using {params.data_layer} data...')
+        sc.pp.highly_variable_genes(
+            adata,
+            flavor="seurat",
+            n_top_genes=params.feat_cell,
+        )
+
+    return adata
+
+
+class LatentRepresentationFinder:
+    def __init__(self, adata, args: FindLatentRepresentationsConfig):
+        self.params = args
+
+        self.expression_array = adata[:, adata.var.highly_variable].X.copy()
+
+        if self.params.data_layer in ['count', 'counts']:
+            self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
+
+        # Construct the neighboring graph
+        self.graph_dict = construct_adjacency_matrix(adata, self.params)
+
+    def compute_pca(self):
+        self.latent_pca = PCA(n_components=self.params.feat_cell).fit_transform(self.expression_array)
+        return self.latent_pca
+
+    def run_gnn_vae(self, label, verbose='whole ST data'):
+
+        # Use PCA if specified
+        if self.params.input_pca:
+            node_X = self.compute_pca()
+        else:
+            node_X = self.expression_array
+
+        # Update the input shape
+        self.params.n_nodes = node_X.shape[0]
+        self.params.feat_cell = node_X.shape[1]
+
+        # Run GNN
+        logger.info(f'Finding latent representations for {verbose}...')
+        gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
+        gvae.run_train()
+
+        del self.graph_dict
+
+        return gvae.get_latent()
+
+
+def run_find_latent_representation(args: FindLatentRepresentationsConfig):
+    set_seed(2024)
+
+    # Load the ST data
+    logger.info(f'Loading ST data of {args.sample_name}...')
+    adata = sc.read_h5ad(args.input_hdf5_path)
+    logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')
+
+    # Load the cell type annotation
+    if args.annotation is not None:
+        # Remove cells without enough annotations
+        adata = adata[~adata.obs[args.annotation].isnull()]
+        num = adata.obs[args.annotation].value_counts()
+        valid_annotations = num[num >= 30].index.to_list()
+        adata = adata[adata.obs[args.annotation].isin(valid_annotations)]
+
+        le = LabelEncoder()
+        adata.obs['categorical_label'] = le.fit_transform(adata.obs[args.annotation])
+        label = adata.obs['categorical_label'].to_numpy()
+    else:
+        label = None
+
+    # Preprocess data
+    adata = preprocess_data(adata, args)
+
+    latent_rep = LatentRepresentationFinder(adata, args)
+    latent_gvae = latent_rep.run_gnn_vae(label)
+    latent_pca = latent_rep.compute_pca()
+
+    # Add latent representations to the AnnData object
+    logger.info('Adding latent representations...')
+    adata.obsm["latent_GVAE"] = latent_gvae
+    adata.obsm["latent_PCA"] = latent_pca
+
+    # Run UMAP based on latent representations
+    for name in ['latent_GVAE', 'latent_PCA']:
+        sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
+        sc.tl.umap(adata)
+        adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
+
+    # Save the AnnData object
+    logger.info('Saving ST data...')
+    adata.write(args.hdf5_with_latent_path)
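
run_find_latent_representation is the module entry point: it seeds the RNGs, drops annotations with fewer than 30 cells, normalizes and selects highly variable genes, trains the GNN-VAE, and writes the AnnData back with latent_GVAE, latent_PCA, and matching UMAPs in .obsm. A hypothetical driver; the real FindLatentRepresentationsConfig is likely a dataclass with more required fields (working directory, output paths) than shown here:

from gsMap.config import FindLatentRepresentationsConfig
from gsMap.find_latent_representation import run_find_latent_representation

# Field names below appear in the code above; anything else the config
# requires (e.g. a workdir) would need to be filled in as well.
config = FindLatentRepresentationsConfig(
    input_hdf5_path='sample.h5ad',
    sample_name='sample',
    annotation='annotation',
    data_layer='counts',
)
run_find_latent_representation(config)
# Result: an .h5ad at config.hdf5_with_latent_path carrying obsm['latent_GVAE'],
# obsm['latent_PCA'], obsm['X_umap_latent_GVAE'] and obsm['X_umap_latent_PCA'].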

--- gsmap-1.65/src/gsMap/format_sumstats.py
+++ gsmap-1.67/src/gsMap/format_sumstats.py
@@ -150,10 +150,10 @@ def gwas_checkname(gwas, config):
         'Pos': 'SNP positions.'
     }
 
-    print(f'\nIterpreting column names as follows:')
+    logger.info(f'\nIterpreting column names as follows:')
     for key, value in interpreting.items():
         if key in new_name:
-            print(f'{name_dict[key]}: {interpreting[key]}')
+            logger.info(f'{name_dict[key]}: {interpreting[key]}')
 
     return gwas
 
@@ -242,7 +242,7 @@ def gwas_qc(gwas, config):
     Filter out SNPs based on INFO, FRQ, MAF, N, and Genotypes.
     '''
     old = len(gwas)
-    print(f'\nFiltering SNPs as follows:')
+    logger.info(f'\nFiltering SNPs as follows:')
     # filter: SNPs with missing values
     drops = {'NA': 0, 'P': 0, 'INFO': 0, 'FRQ': 0, 'A': 0, 'SNP': 0, 'Dup': 0, 'N': 0}
 
@@ -250,28 +250,28 @@ def gwas_qc(gwas, config):
         lambda x: x != 'INFO', gwas.columns)).reset_index(drop=True)
 
     drops['NA'] = old - len(gwas)
-    print(f'Removed {drops["NA"]} SNPs with missing values.')
+    logger.info(f'Removed {drops["NA"]} SNPs with missing values.')
 
     # filter: SNPs with Info < 0.9
     if 'INFO' in gwas.columns:
         old = len(gwas)
         gwas = gwas.loc[filter_info(gwas['INFO'], config)]
         drops['INFO'] = old - len(gwas)
-        print(f'Removed {drops["INFO"]} SNPs with INFO <= 0.9.')
+        logger.info(f'Removed {drops["INFO"]} SNPs with INFO <= 0.9.')
 
     # filter: SNPs with MAF <= 0.01
     if 'FRQ' in gwas.columns:
         old = len(gwas)
         gwas = gwas.loc[filter_frq(gwas['FRQ'], config)]
         drops['FRQ'] += old - len(gwas)
-        print(f'Removed {drops["FRQ"]} SNPs with MAF <= 0.01.')
+        logger.info(f'Removed {drops["FRQ"]} SNPs with MAF <= 0.01.')
 
     # filter: P-value that out-of-bounds [0,1]
     if 'P' in gwas.columns:
         old = len(gwas)
         gwas = gwas.loc[filter_pvals(gwas['P'], config)]
         drops['P'] += old - len(gwas)
-        print(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')
+        logger.info(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')
 
     # filter: Variants that are strand-ambiguous
     if 'A1' in gwas.columns and 'A2' in gwas.columns:
@@ -279,21 +279,21 @@ def gwas_qc(gwas, config):
         gwas.A2 = gwas.A2.str.upper()
         gwas = gwas.loc[filter_alleles(gwas.A1 + gwas.A2)]
         drops['A'] += old - len(gwas)
-        print(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')
+        logger.info(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')
 
     # filter: Duplicated rs numbers
     if 'SNP' in gwas.columns:
         old = len(gwas)
         gwas = gwas.drop_duplicates(subset='SNP').reset_index(drop=True)
         drops['Dup'] += old - len(gwas)
-        print(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')
+        logger.info(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')
 
     # filter:Sample size
     n_min = gwas.N.quantile(0.9) / 1.5
     old = len(gwas)
     gwas = gwas[gwas.N >= n_min].reset_index(drop=True)
     drops['N'] += old - len(gwas)
-    print(f'Removed {drops["N"]} SNPs with N < {n_min}.')
+    logger.info(f'Removed {drops["N"]} SNPs with N < {n_min}.')
 
     return gwas
 
@@ -302,7 +302,7 @@ def variant_to_rsid(gwas, config):
     '''
     Convert variant id (Chr, Pos) to rsid
     '''
-    print("\nConverting the SNP position to rsid. This process may take some time.")
+    logger.info("\nConverting the SNP position to rsid. This process may take some time.")
    unique_ids = set(gwas['id'])
     chr_format = gwas['Chr'].unique().astype(str)
     chr_format = [re.sub(r'\d+', '', value) for value in chr_format][1]
@@ -347,7 +347,7 @@ def clean_SNP_id(gwas, config):
     gwas = gwas.loc[matching_id.id]
     gwas['SNP'] = matching_id.dbsnp
     num_fail = old - len(gwas)
-    print(f'Removed {num_fail} SNPs that did not convert to rsid.')
+    logger.info(f'Removed {num_fail} SNPs that did not convert to rsid.')
 
     return gwas
 
@@ -356,27 +356,27 @@ def gwas_metadata(gwas, config):
     '''
     Report key features of GWAS data
     '''
-    print('\nMetadata:')
+    logger.info('\nSummary of GWAS data:')
     CHISQ = (gwas.Z ** 2)
     mean_chisq = CHISQ.mean()
-    print('Mean chi^2 = ' + str(round(mean_chisq, 3)))
+    logger.info('Mean chi^2 = ' + str(round(mean_chisq, 3)))
     if mean_chisq < 1.02:
         logger.warning("Mean chi^2 may be too small.")
 
-    print('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
-    print('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
-    print('{N} Genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))
+    logger.info('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
+    logger.info('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
+    logger.info('{N} Genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))
 
 
 def gwas_format(config: FormatSumstatsConfig):
     '''
     Format GWAS data
     '''
-    print(f'------Formating gwas data for {config.sumstats}...')
+    logger.info(f'------Formating gwas data for {config.sumstats}...')
     compression_type = get_compression(config.sumstats)
     gwas = pd.read_csv(config.sumstats, delim_whitespace=True, header=0, compression=compression_type,
                        na_values=['.', 'NA'])
-    print(f'Read {len(gwas)} SNPs from {config.sumstats}.')
+    logger.info(f'Read {len(gwas)} SNPs from {config.sumstats}.')
 
     # Check name and format
     gwas = gwas_checkname(gwas, config)
@@ -402,6 +402,6 @@ def gwas_format(config: FormatSumstatsConfig):
     gwas = gwas[keep]
     out_name = config.out + appendix + '.gz'
 
-    print(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
+    logger.info(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
     gwas.to_csv(out_name, sep="\t", index=False,
                 float_format='%.3f', compression='gzip')
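
The format_sumstats.py changes are mechanical: every print() becomes logger.info(), so progress messages now flow through the standard logging machinery and stay silent until the caller installs a handler. A minimal sketch of the mechanism:

import logging

# Without this, the module-level logger in gsMap.format_sumstats has no
# handler, and logging's last-resort handler shows WARNING and above only,
# so the new logger.info() messages are dropped.
logging.basicConfig(level=logging.INFO, format='%(name)s: %(message)s')

logging.getLogger('gsMap.format_sumstats').info('Removed 0 SNPs with missing values.')
# prints: gsMap.format_sumstats: Removed 0 SNPs with missing values.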