gsMap 1.67-py3-none-any.whl → 1.70-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry; it is provided for informational purposes only.
- gsMap/{GNN_VAE → GNN}/__init__.py +0 -0
- gsMap/{GNN_VAE → GNN}/adjacency_matrix.py +75 -75
- gsMap/{GNN_VAE → GNN}/model.py +89 -89
- gsMap/{GNN_VAE → GNN}/train.py +88 -86
- gsMap/__init__.py +5 -5
- gsMap/__main__.py +2 -2
- gsMap/cauchy_combination_test.py +141 -141
- gsMap/config.py +805 -803
- gsMap/diagnosis.py +273 -273
- gsMap/find_latent_representation.py +133 -145
- gsMap/format_sumstats.py +407 -407
- gsMap/generate_ldscore.py +618 -618
- gsMap/latent_to_gene.py +234 -234
- gsMap/main.py +31 -31
- gsMap/report.py +160 -160
- gsMap/run_all_mode.py +194 -194
- gsMap/setup.py +0 -0
- gsMap/spatial_ldsc_multiple_sumstats.py +380 -380
- gsMap/templates/report_template.html +198 -198
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +735 -735
- gsMap/utils/jackknife.py +514 -514
- gsMap/utils/make_annotations.py +518 -518
- gsMap/utils/manhattan_plot.py +639 -639
- gsMap/utils/regression_read.py +294 -294
- gsMap/visualize.py +198 -198
- {gsmap-1.67.dist-info → gsmap-1.70.dist-info}/LICENSE +21 -21
- {gsmap-1.67.dist-info → gsmap-1.70.dist-info}/METADATA +28 -22
- gsmap-1.70.dist-info/RECORD +31 -0
- gsmap-1.67.dist-info/RECORD +0 -31
- {gsmap-1.67.dist-info → gsmap-1.70.dist-info}/WHEEL +0 -0
- {gsmap-1.67.dist-info → gsmap-1.70.dist-info}/entry_points.txt +0 -0
gsMap/{GNN_VAE → GNN}/__init__.py
RENAMED

    File without changes

gsMap/{GNN_VAE → GNN}/adjacency_matrix.py
RENAMED
@@ -1,75 +1,75 @@

Every line is marked removed and re-added by the rename. The only substantive change is line 34, the return statement of sparse_mx_to_torch_sparse_tensor: the old return expression is truncated in this rendering, while the new code reads "return torch.sparse_coo_tensor(indices, values, shape)", the currently recommended constructor for sparse COO tensors. Content of the renamed file:

import numpy as np
import pandas as pd
import scipy.sparse as sp
from sklearn.neighbors import NearestNeighbors
import torch

def cal_spatial_net(adata, n_neighbors=5, verbose=True):
    """Construct the spatial neighbor network."""
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
    distances, indices = nbrs.kneighbors(coor)
    n_cells, n_neighbors = indices.shape
    cell_indices = np.arange(n_cells)
    cell1 = np.repeat(cell_indices, n_neighbors)
    cell2 = indices.flatten()
    distance = distances.flatten()
    knn_df = pd.DataFrame({'Cell1': cell1, 'Cell2': cell2, 'Distance': distance})
    knn_df = knn_df[knn_df['Distance'] > 0].copy()
    cell_id_map = dict(zip(cell_indices, coor.index))
    knn_df['Cell1'] = knn_df['Cell1'].map(cell_id_map)
    knn_df['Cell2'] = knn_df['Cell2'].map(cell_id_map)
    return knn_df

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
    )
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse_coo_tensor(indices, values, shape)

def preprocess_graph(adj):
    """Symmetrically normalize the adjacency matrix."""
    adj = sp.coo_matrix(adj)
    adj_ = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj_.sum(1)).flatten()
    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5))
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    return sparse_mx_to_torch_sparse_tensor(adj_normalized)

def construct_adjacency_matrix(adata, params, verbose=True):
    """Construct the adjacency matrix from spatial data."""
    spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
    if verbose:
        num_edges = spatial_net.shape[0]
        num_cells = adata.n_obs
        print(f'The graph contains {num_edges} edges, {num_cells} cells.')
        print(f'{num_edges / num_cells:.2f} neighbors per cell on average.')
    cell_ids = {cell: idx for idx, cell in enumerate(adata.obs.index)}
    spatial_net['Cell1'] = spatial_net['Cell1'].map(cell_ids)
    spatial_net['Cell2'] = spatial_net['Cell2'].map(cell_ids)
    if params.weighted_adj:
        distance_normalized = spatial_net['Distance'] / (spatial_net['Distance'].max() + 1)
        weights = np.exp(-0.5 * distance_normalized ** 2)
        adj_org = sp.coo_matrix(
            (weights, (spatial_net['Cell1'], spatial_net['Cell2'])),
            shape=(adata.n_obs, adata.n_obs)
        )
    else:
        adj_org = sp.coo_matrix(
            (np.ones(spatial_net.shape[0]), (spatial_net['Cell1'], spatial_net['Cell2'])),
            shape=(adata.n_obs, adata.n_obs)
        )
    adj_norm = preprocess_graph(adj_org)
    norm_value = adj_org.shape[0] ** 2 / ((adj_org.shape[0] ** 2 - adj_org.sum()) * 2)
    graph_dict = {
        "adj_org": adj_org,
        "adj_norm": adj_norm,
        "norm_value": norm_value
    }
    return graph_dict
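For context, a minimal sketch (not part of the package) of what the updated helpers produce, using the post-rename module path and an illustrative 3-node graph:

import numpy as np
import scipy.sparse as sp
from gsMap.GNN.adjacency_matrix import preprocess_graph

# Toy symmetric adjacency for a 3-node path graph: edges 0-1 and 1-2.
adj = sp.coo_matrix((np.ones(4), ([0, 1, 1, 2], [1, 0, 2, 1])), shape=(3, 3))

# preprocess_graph computes D^-1/2 (A + I) D^-1/2 and hands the result to
# sparse_mx_to_torch_sparse_tensor, so it now returns a torch sparse COO tensor.
adj_norm = preprocess_graph(adj)
print(adj_norm.to_dense())  # dense 3x3 view of the normalized adjacency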
    
        gsMap/{GNN_VAE → GNN}/model.py
    RENAMED
    
@@ -1,89 +1,89 @@

Every line is marked removed and re-added by the rename, but the file content is otherwise identical (most likely line-ending churn from re-packaging). Content of the renamed file:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

def full_block(in_features, out_features, p_drop):
    return nn.Sequential(
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.ELU(),
        nn.Dropout(p=p_drop)
    )

class GATModel(nn.Module):
    def __init__(self, input_dim, params, num_classes=1):
        super().__init__()
        self.var = params.var
        self.num_classes = num_classes
        self.params = params

        # Encoder
        self.encoder = nn.Sequential(
            full_block(input_dim, params.feat_hidden1, params.p_drop),
            full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop)
        )

        # GAT Layers
        self.gat1 = GATConv(
            in_channels=params.feat_hidden2,
            out_channels=params.gat_hidden1,
            heads=params.nheads,
            dropout=params.p_drop
        )
        self.gat2 = GATConv(
            in_channels=params.gat_hidden1 * params.nheads,
            out_channels=params.gat_hidden2,
            heads=1,
            concat=False,
            dropout=params.p_drop
        )
        if self.var:
            self.gat3 = GATConv(
                in_channels=params.gat_hidden1 * params.nheads,
                out_channels=params.gat_hidden2,
                heads=1,
                concat=False,
                dropout=params.p_drop
            )

        # Decoder
        self.decoder = nn.Sequential(
            full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
            full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop),
            nn.Linear(params.feat_hidden1, input_dim)
        )

        # Clustering Layer
        self.cluster = nn.Sequential(
            full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
            nn.Linear(params.feat_hidden2, self.num_classes)
        )

    def encode(self, x, edge_index):
        x = self.encoder(x)
        x = self.gat1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.params.p_drop, training=self.training)

        mu = self.gat2(x, edge_index)
        if self.var:
            logvar = self.gat3(x, edge_index)
            return mu, logvar
        else:
            return mu, None

    def reparameterize(self, mu, logvar):
        if self.training and logvar is not None:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps * std + mu
        else:
            return mu

    def forward(self, x, edge_index):
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decoder(z)
        pred_label = F.softmax(self.cluster(z), dim=1)
        return pred_label, x_reconstructed, z, mu, logvar
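For context, a minimal sketch (not from the package) of instantiating GATModel and running a forward pass; the fields of params mirror those read in __init__ above, but the values are illustrative rather than gsMap defaults:

from types import SimpleNamespace

import torch

from gsMap.GNN.model import GATModel

params = SimpleNamespace(
    var=True,                            # enable the variational (mu/logvar) branch
    feat_hidden1=256, feat_hidden2=128,  # encoder widths
    gat_hidden1=64, gat_hidden2=32,      # GAT widths; gat_hidden2 is the latent size
    nheads=4, p_drop=0.1,
)
model = GATModel(input_dim=2000, params=params, num_classes=1)
model.eval()  # in eval mode, reparameterize() returns mu deterministically

x = torch.randn(10, 2000)                   # 10 cells, 2000 features
edge_index = torch.randint(0, 10, (2, 40))  # random COO edge list, for illustration only
pred_label, x_rec, z, mu, logvar = model(x, edge_index)
print(z.shape)  # torch.Size([10, 32]): one latent vector per cell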
    
        gsMap/{GNN_VAE → GNN}/train.py
    RENAMED
    
@@ -1,86 +1,88 @@

The old side of this hunk is mostly truncated in this rendering, so a line-by-line comparison is not recoverable; the visible fragments (an import beginning "from gsMap.", consistent with the old GNN_VAE package path, and logging around early stopping) plus the hunk header show the file grew by two lines and now imports GATModel from gsMap.GNN.model and tqdm for progress reporting. Content of the renamed file in 1.70:

import logging
import time

import torch
import torch.nn.functional as F
from tqdm import tqdm

from gsMap.GNN.model import GATModel

logger = logging.getLogger(__name__)


def reconstruction_loss(decoded, x):
    """Compute the mean squared error loss."""
    return F.mse_loss(decoded, x)


def label_loss(pred_label, true_label):
    """Compute the cross-entropy loss."""
    return F.cross_entropy(pred_label, true_label)


class ModelTrainer:
    def __init__(self, node_x, graph_dict, params, label=None):
        """Initialize the ModelTrainer with data and hyperparameters."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.params = params
        self.epochs = params.epochs
        self.node_x = torch.FloatTensor(node_x).to(self.device)
        self.adj_norm = graph_dict["adj_norm"].to(self.device).coalesce()
        self.label = label
        self.num_classes = 1

        if self.label is not None:
            self.label = torch.tensor(self.label).to(self.device)
            self.num_classes = len(torch.unique(self.label))

        # Set up the model
        self.model = GATModel(self.params.feat_cell, self.params, self.num_classes).to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.params.gat_lr,
            weight_decay=self.params.gcn_decay
        )

    def run_train(self):
        """Train the model."""
        self.model.train()
        prev_loss = float('inf')
        logger.info('Start training...')
        pbar = tqdm(range(self.epochs), desc='GAT-AE model train:', total=self.epochs)
        for epoch in range(self.epochs):
            start_time = time.time()
            self.optimizer.zero_grad()
            pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_x, self.adj_norm)
            loss_rec = reconstruction_loss(de_feat, self.node_x)

            if self.label is not None:
                loss_pre = label_loss(pred_label, self.label)
                loss = self.params.rec_w * loss_rec + self.params.label_w * loss_pre
            else:
                loss = loss_rec

            loss.backward()
            self.optimizer.step()

            batch_time = time.time() - start_time
            left_time = batch_time * (self.epochs - epoch - 1) / 60  # in minutes

            pbar.set_postfix({'Left time': f'{left_time:.2f} mins', 'Loss': f'{loss.item():.4f}'})
            pbar.update(1)

            if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
                pbar.close()
                logger.info('Convergence reached. Training stopped.')
                break
            prev_loss = loss.item()
        else:
            pbar.close()
            logger.info('Max epochs reached. Training stopped.')

    def get_latent(self):
        """Retrieve the latent representation from the model."""
        self.model.eval()
        with torch.no_grad():
            _, _, latent_z, _, _ = self.model(self.node_x, self.adj_norm)
        return latent_z.cpu().numpy()
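For context, a minimal end-to-end sketch (not from the package) of how the renamed modules compose; every params field is inferred from the code in this diff, the values are illustrative, and adata is assumed to be an AnnData object with coordinates in obsm['spatial']:

from types import SimpleNamespace

import numpy as np

from gsMap.GNN.adjacency_matrix import construct_adjacency_matrix
from gsMap.GNN.train import ModelTrainer

params = SimpleNamespace(
    # graph construction
    n_neighbors=5, weighted_adj=False,
    # GATModel architecture
    feat_cell=2000, feat_hidden1=256, feat_hidden2=128,
    gat_hidden1=64, gat_hidden2=32, nheads=4, p_drop=0.1, var=True,
    # optimization
    epochs=300, gat_lr=1e-3, gcn_decay=1e-2,
    rec_w=1.0, label_w=1.0, convergence_threshold=1e-4,
)

graph_dict = construct_adjacency_matrix(adata, params)  # adata assumed defined
node_x = np.asarray(adata.X, dtype=np.float32)          # assumes a dense (n_cells, feat_cell) matrix
trainer = ModelTrainer(node_x, graph_dict, params, label=None)  # no labels: reconstruction loss only
trainer.run_train()
latent = trainer.get_latent()  # (n_cells, gat_hidden2) numpy array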
    
        gsMap/__init__.py
    CHANGED
    
    | @@ -1,5 +1,5 @@ | |
| 1 | 
            -
            '''
         | 
| 2 | 
            -
            Genetics-informed pathogenic spatial mapping
         | 
| 3 | 
            -
            '''
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            __version__ = '1. | 
| 1 | 
            +
            '''
         | 
| 2 | 
            +
            Genetics-informed pathogenic spatial mapping
         | 
| 3 | 
            +
            '''
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            __version__ = '1.70'
         | 
    
        gsMap/__main__.py
    CHANGED
    
@@ -1,3 +1,3 @@

Lines 1 and 2 are re-written with no visible content change (most likely line-ending normalization from re-packaging); line 3 is untouched. Content of the file:

from .main import main
if __name__ == '__main__':
    main()
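The shim above makes the package runnable as a module. A sketch of the equivalent invocations (the gsmap console-script name is an assumption here; it is defined in entry_points.txt, whose contents this diff does not show):

# python -m gsMap ...   runs gsMap/__main__.py, i.e. gsMap.main.main()
# gsmap ...             assumed console script from entry_points.txt
import runpy
runpy.run_module('gsMap', run_name='__main__')  # programmatic equivalent of python -m gsMap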