gsMap 1.67__py3-none-any.whl → 1.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/{GNN_VAE → GNN}/__init__.py +0 -0
 - gsMap/{GNN_VAE → GNN}/adjacency_matrix.py +75 -75
 - gsMap/{GNN_VAE → GNN}/model.py +89 -89
 - gsMap/{GNN_VAE → GNN}/train.py +88 -86
 - gsMap/__init__.py +5 -5
 - gsMap/__main__.py +2 -2
 - gsMap/cauchy_combination_test.py +141 -141
 - gsMap/config.py +805 -803
 - gsMap/diagnosis.py +273 -273
 - gsMap/find_latent_representation.py +133 -145
 - gsMap/format_sumstats.py +407 -407
 - gsMap/generate_ldscore.py +618 -618
 - gsMap/latent_to_gene.py +234 -234
 - gsMap/main.py +31 -31
 - gsMap/report.py +160 -160
 - gsMap/run_all_mode.py +194 -194
 - gsMap/setup.py +0 -0
 - gsMap/spatial_ldsc_multiple_sumstats.py +380 -380
 - gsMap/templates/report_template.html +198 -198
 - gsMap/utils/__init__.py +0 -0
 - gsMap/utils/generate_r2_matrix.py +735 -735
 - gsMap/utils/jackknife.py +514 -514
 - gsMap/utils/make_annotations.py +518 -518
 - gsMap/utils/manhattan_plot.py +639 -639
 - gsMap/utils/regression_read.py +294 -294
 - gsMap/visualize.py +198 -198
 - {gsmap-1.67.dist-info → gsmap-1.71.dist-info}/LICENSE +21 -21
 - {gsmap-1.67.dist-info → gsmap-1.71.dist-info}/METADATA +28 -22
 - gsmap-1.71.dist-info/RECORD +31 -0
 - gsmap-1.67.dist-info/RECORD +0 -31
 - {gsmap-1.67.dist-info → gsmap-1.71.dist-info}/WHEEL +0 -0
 - {gsmap-1.67.dist-info → gsmap-1.71.dist-info}/entry_points.txt +0 -0
 
| 
         
            File without changes
         
     | 
| 
         @@ -1,75 +1,75 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       2 
     | 
    
         
            -
            import pandas as pd
         
     | 
| 
       3 
     | 
    
         
            -
            import scipy.sparse as sp
         
     | 
| 
       4 
     | 
    
         
            -
            from sklearn.neighbors import NearestNeighbors
         
     | 
| 
       5 
     | 
    
         
            -
            import torch
         
     | 
| 
       6 
     | 
    
         
            -
             
     | 
| 
       7 
     | 
    
         
            -
            def cal_spatial_net(adata, n_neighbors=5, verbose=True):
                """Build a k-nearest-neighbour edge list from ``adata.obsm['spatial']``.

                Every cell is queried against all cells (itself included) with
                ``NearestNeighbors``; each returned (query, neighbour) pair is a
                candidate edge. Pairs at zero distance are discarded: this removes the
                self-match, but also drops distinct cells that share coordinates.

                Parameters
                ----------
                adata : object exposing ``obsm['spatial']`` and ``obs.index``
                    (presumably an AnnData -- confirm against callers).
                n_neighbors : int
                    Neighbours requested per cell, counting the cell itself.
                verbose : bool
                    Print a progress message when True.

                Returns
                -------
                pandas.DataFrame
                    Columns ``Cell1`` and ``Cell2`` hold ``adata.obs.index`` labels;
                    ``Distance`` holds the distance reported by ``kneighbors``.
                """
                if verbose:
                    print('------Calculating spatial graph...')
                coords = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
                knn = NearestNeighbors(n_neighbors=n_neighbors).fit(coords)
                dist_mat, idx_mat = knn.kneighbors(coords)
                n_rows, k = idx_mat.shape
                row_ids = np.arange(n_rows)
                edges = pd.DataFrame({
                    'Cell1': np.repeat(row_ids, k),
                    'Cell2': idx_mat.flatten(),
                    'Distance': dist_mat.flatten(),
                })
                # Zero distance marks the query point itself (or an exact duplicate).
                edges = edges.loc[edges['Distance'] > 0].copy()
                # Translate positional row numbers back into obs labels.
                label_of = dict(zip(row_ids, coords.index))
                for column in ('Cell1', 'Cell2'):
                    edges[column] = edges[column].map(label_of)
                return edges
         
     | 
| 
       25 
     | 
    
         
            -
             
     | 
| 
       26 
     | 
    
         
            -
            def sparse_mx_to_torch_sparse_tensor(sparse_mx):
                """Convert a SciPy sparse matrix into a torch sparse COO tensor.

                The matrix is converted to COO format and cast to ``float32``. Row and
                column coordinates are stacked into an ``int64`` index tensor of
                shape ``(2, nnz)``.

                Parameters
                ----------
                sparse_mx : scipy.sparse matrix
                    Any sparse matrix supporting ``tocoo()``.

                Returns
                -------
                torch.Tensor
                    Sparse COO tensor with the same shape and stored entries.
                """
                sparse_mx = sparse_mx.tocoo().astype(np.float32)
                indices = torch.from_numpy(
                    np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
                )
                values = torch.from_numpy(sparse_mx.data)
                shape = torch.Size(sparse_mx.shape)
                # Bug fix: a bare ``return`` here handed ``None`` back to callers.
                return torch.sparse_coo_tensor(indices, values, shape)
     | 
| 
       35 
     | 
    
         
            -
             
     | 
| 
       36 
     | 
    
         
            -
            def preprocess_graph(adj):
                """Add self-loops to ``adj`` and apply degree normalisation.

                With ``A_hat = adj + I`` and ``D`` the diagonal matrix of ``A_hat``'s
                row sums, this computes ``(A_hat D^-1/2)^T D^-1/2``, which equals
                ``D^-1/2 A_hat^T D^-1/2``. For a symmetric ``adj`` that is the usual
                ``D^-1/2 A_hat D^-1/2``.

                Returns whatever ``sparse_mx_to_torch_sparse_tensor`` produces for the
                normalised COO matrix.
                """
                adj_coo = sp.coo_matrix(adj)
                with_loops = adj_coo + sp.eye(adj_coo.shape[0])
                degrees = np.array(with_loops.sum(1)).flatten()
                d_inv_sqrt = sp.diags(np.power(degrees, -0.5))
                right_scaled = with_loops.dot(d_inv_sqrt)
                normalized = right_scaled.transpose().dot(d_inv_sqrt).tocoo()
                return sparse_mx_to_torch_sparse_tensor(normalized)
         
     | 
| 
       44 
     | 
    
         
            -
             
     | 
| 
       45 
     | 
    
         
            -
            def construct_adjacency_matrix(adata, params, verbose=True):
                """Build the spatial kNN graph and package it for GNN training.

                Parameters
                ----------
                adata : object with ``obsm['spatial']``, ``obs.index`` and ``n_obs``
                    (presumably an AnnData -- confirm against callers).
                params : object exposing ``n_neighbors`` and ``weighted_adj``.
                verbose : bool
                    Print graph size statistics when True.

                Returns
                -------
                dict
                    ``adj_org``: ``n_obs x n_obs`` ``scipy.sparse.coo_matrix`` of edge
                    weights. When ``params.weighted_adj`` is truthy each weight is
                    ``exp(-0.5 * (d / (max_d + 1)) ** 2)``; otherwise every weight is 1.
                    ``adj_norm``: ``preprocess_graph(adj_org)``.
                    ``norm_value``: ``n**2 / ((n**2 - adj_org.sum()) * 2)``.

                NOTE(review): obs labels are mapped back to positions through a dict,
                so duplicate labels in ``adata.obs.index`` would collapse onto one row.
                """
                spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
                n_cells = adata.n_obs
                if verbose:
                    n_edges = spatial_net.shape[0]
                    print(f'The graph contains {n_edges} edges, {n_cells} cells.')
                    print(f'{n_edges / n_cells:.2f} neighbors per cell on average.')

                # Map obs labels back to positional row indices.
                position_of = {label: pos for pos, label in enumerate(adata.obs.index)}
                rows = spatial_net['Cell1'].map(position_of)
                cols = spatial_net['Cell2'].map(position_of)

                if params.weighted_adj:
                    dist = spatial_net['Distance']
                    # The +1 keeps the divisor positive when every distance is tiny.
                    scaled = dist / (dist.max() + 1)
                    edge_weights = np.exp(-0.5 * scaled ** 2)
                else:
                    edge_weights = np.ones(spatial_net.shape[0])
                adj_org = sp.coo_matrix((edge_weights, (rows, cols)), shape=(n_cells, n_cells))

                n_sq = adj_org.shape[0] ** 2
                return {
                    "adj_org": adj_org,
                    "adj_norm": preprocess_graph(adj_org),
                    "norm_value": n_sq / ((n_sq - adj_org.sum()) * 2),
                }
         
     | 
| 
      
 1 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 2 
     | 
    
         
            +
            import pandas as pd
         
     | 
| 
      
 3 
     | 
    
         
            +
            import scipy.sparse as sp
         
     | 
| 
      
 4 
     | 
    
         
            +
            from sklearn.neighbors import NearestNeighbors
         
     | 
| 
      
 5 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            def cal_spatial_net(adata, n_neighbors=5, verbose=True):
                """Build a k-nearest-neighbour edge list from ``adata.obsm['spatial']``.

                Every cell is queried against all cells (itself included) with
                ``NearestNeighbors``; each returned (query, neighbour) pair is a
                candidate edge. Pairs at zero distance are discarded: this removes the
                self-match, but also drops distinct cells that share coordinates.

                Parameters
                ----------
                adata : object exposing ``obsm['spatial']`` and ``obs.index``
                    (presumably an AnnData -- confirm against callers).
                n_neighbors : int
                    Neighbours requested per cell, counting the cell itself.
                verbose : bool
                    Print a progress message when True.

                Returns
                -------
                pandas.DataFrame
                    Columns ``Cell1`` and ``Cell2`` hold ``adata.obs.index`` labels;
                    ``Distance`` holds the distance reported by ``kneighbors``.
                """
                if verbose:
                    print('------Calculating spatial graph...')
                coords = pd.DataFrame(adata.obsm['spatial'], index=adata.obs.index)
                knn = NearestNeighbors(n_neighbors=n_neighbors).fit(coords)
                dist_mat, idx_mat = knn.kneighbors(coords)
                n_rows, k = idx_mat.shape
                row_ids = np.arange(n_rows)
                edges = pd.DataFrame({
                    'Cell1': np.repeat(row_ids, k),
                    'Cell2': idx_mat.flatten(),
                    'Distance': dist_mat.flatten(),
                })
                # Zero distance marks the query point itself (or an exact duplicate).
                edges = edges.loc[edges['Distance'] > 0].copy()
                # Translate positional row numbers back into obs labels.
                label_of = dict(zip(row_ids, coords.index))
                for column in ('Cell1', 'Cell2'):
                    edges[column] = edges[column].map(label_of)
                return edges
         
     | 
| 
      
 25 
     | 
    
         
            +
             
     | 
| 
      
 26 
     | 
    
         
            +
            def sparse_mx_to_torch_sparse_tensor(sparse_mx):
                """Convert a SciPy sparse matrix into a torch sparse COO tensor.

                The input is converted to COO format with ``float32`` values; the
                row/column coordinates become an ``int64`` index tensor of shape
                ``(2, nnz)``.
                """
                coo = sparse_mx.tocoo().astype(np.float32)
                row_col = np.vstack((coo.row, coo.col)).astype(np.int64)
                return torch.sparse_coo_tensor(
                    torch.from_numpy(row_col),
                    torch.from_numpy(coo.data),
                    torch.Size(coo.shape),
                )
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
            def preprocess_graph(adj):
                """Add self-loops to ``adj`` and apply degree normalisation.

                With ``A_hat = adj + I`` and ``D`` the diagonal matrix of ``A_hat``'s
                row sums, this computes ``(A_hat D^-1/2)^T D^-1/2``, which equals
                ``D^-1/2 A_hat^T D^-1/2``. For a symmetric ``adj`` that is the usual
                ``D^-1/2 A_hat D^-1/2``.

                Returns the normalised matrix as a torch sparse COO tensor, via
                ``sparse_mx_to_torch_sparse_tensor``.
                """
                adj_coo = sp.coo_matrix(adj)
                with_loops = adj_coo + sp.eye(adj_coo.shape[0])
                degrees = np.array(with_loops.sum(1)).flatten()
                d_inv_sqrt = sp.diags(np.power(degrees, -0.5))
                right_scaled = with_loops.dot(d_inv_sqrt)
                normalized = right_scaled.transpose().dot(d_inv_sqrt).tocoo()
                return sparse_mx_to_torch_sparse_tensor(normalized)
         
     | 
| 
      
 44 
     | 
    
         
            +
             
     | 
| 
      
 45 
     | 
    
         
            +
            def construct_adjacency_matrix(adata, params, verbose=True):
                """Build the spatial kNN graph and package it for GNN training.

                Parameters
                ----------
                adata : object with ``obsm['spatial']``, ``obs.index`` and ``n_obs``
                    (presumably an AnnData -- confirm against callers).
                params : object exposing ``n_neighbors`` and ``weighted_adj``.
                verbose : bool
                    Print graph size statistics when True.

                Returns
                -------
                dict
                    ``adj_org``: ``n_obs x n_obs`` ``scipy.sparse.coo_matrix`` of edge
                    weights. When ``params.weighted_adj`` is truthy each weight is
                    ``exp(-0.5 * (d / (max_d + 1)) ** 2)``; otherwise every weight is 1.
                    ``adj_norm``: ``preprocess_graph(adj_org)``.
                    ``norm_value``: ``n**2 / ((n**2 - adj_org.sum()) * 2)``.

                NOTE(review): obs labels are mapped back to positions through a dict,
                so duplicate labels in ``adata.obs.index`` would collapse onto one row.
                """
                spatial_net = cal_spatial_net(adata, n_neighbors=params.n_neighbors, verbose=verbose)
                n_cells = adata.n_obs
                if verbose:
                    n_edges = spatial_net.shape[0]
                    print(f'The graph contains {n_edges} edges, {n_cells} cells.')
                    print(f'{n_edges / n_cells:.2f} neighbors per cell on average.')

                # Map obs labels back to positional row indices.
                position_of = {label: pos for pos, label in enumerate(adata.obs.index)}
                rows = spatial_net['Cell1'].map(position_of)
                cols = spatial_net['Cell2'].map(position_of)

                if params.weighted_adj:
                    dist = spatial_net['Distance']
                    # The +1 keeps the divisor positive when every distance is tiny.
                    scaled = dist / (dist.max() + 1)
                    edge_weights = np.exp(-0.5 * scaled ** 2)
                else:
                    edge_weights = np.ones(spatial_net.shape[0])
                adj_org = sp.coo_matrix((edge_weights, (rows, cols)), shape=(n_cells, n_cells))

                n_sq = adj_org.shape[0] ** 2
                return {
                    "adj_org": adj_org,
                    "adj_norm": preprocess_graph(adj_org),
                    "norm_value": n_sq / ((n_sq - adj_org.sum()) * 2),
                }
         
     | 
    
        gsMap/{GNN_VAE → GNN}/model.py
    RENAMED
    
    | 
         @@ -1,89 +1,89 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            import torch
         
     | 
| 
       2 
     | 
    
         
            -
            import torch.nn as nn
         
     | 
| 
       3 
     | 
    
         
            -
            import torch.nn.functional as F
         
     | 
| 
       4 
     | 
    
         
            -
            from torch_geometric.nn import GATConv
         
     | 
| 
       5 
     | 
    
         
            -
             
     | 
| 
       6 
     | 
    
         
            -
            def full_block(in_features, out_features, p_drop):
                """Return a ``Linear -> BatchNorm1d -> ELU -> Dropout`` stack.

                Parameters
                ----------
                in_features, out_features : int
                    Sizes of the ``nn.Linear`` layer; ``BatchNorm1d`` uses ``out_features``.
                p_drop : float
                    Dropout probability for the final ``nn.Dropout``.
                """
                # Layers are built in the same order as listed so parameter
                # initialisation consumes the RNG identically.
                layers = [
                    nn.Linear(in_features, out_features),
                    nn.BatchNorm1d(out_features),
                    nn.ELU(),
                    nn.Dropout(p=p_drop),
                ]
                return nn.Sequential(*layers)
         
     | 
| 
       13 
     | 
    
         
            -
             
     | 
| 
       14 
     | 
    
         
            -
            class GATModel(nn.Module):
                """Graph-attention autoencoder over per-cell features.

                Pipeline: two ``full_block`` MLP layers, then a multi-head ``GATConv``,
                then a single-head ``GATConv`` producing ``mu`` (plus a second one
                producing ``logvar`` when ``params.var`` is truthy). The latent ``z``
                feeds an MLP decoder that reconstructs the input, and a softmax
                classification head.

                ``params`` must expose ``var``, ``feat_hidden1``, ``feat_hidden2``,
                ``gat_hidden1``, ``gat_hidden2``, ``nheads`` and ``p_drop``.
                """

                def __init__(self, input_dim, params, num_classes=1):
                    super().__init__()
                    self.var = params.var
                    self.num_classes = num_classes
                    self.params = params
                    drop = params.p_drop

                    # Submodules are created in the original order so random
                    # initialisation and state_dict key order stay unchanged.
                    # Encoder: input_dim -> feat_hidden1 -> feat_hidden2.
                    self.encoder = nn.Sequential(
                        full_block(input_dim, params.feat_hidden1, drop),
                        full_block(params.feat_hidden1, params.feat_hidden2, drop),
                    )

                    # Multi-head attention; the next layer expects
                    # gat_hidden1 * nheads inputs, i.e. the heads concatenated.
                    self.gat1 = GATConv(
                        in_channels=params.feat_hidden2,
                        out_channels=params.gat_hidden1,
                        heads=params.nheads,
                        dropout=drop,
                    )
                    # gat2 (mean) and the optional gat3 (log-variance) share one config.
                    head_cfg = dict(
                        in_channels=params.gat_hidden1 * params.nheads,
                        out_channels=params.gat_hidden2,
                        heads=1,
                        concat=False,
                        dropout=drop,
                    )
                    self.gat2 = GATConv(**head_cfg)
                    if self.var:
                        self.gat3 = GATConv(**head_cfg)

                    # Decoder: gat_hidden2 -> feat_hidden2 -> feat_hidden1 -> input_dim.
                    self.decoder = nn.Sequential(
                        full_block(params.gat_hidden2, params.feat_hidden2, drop),
                        full_block(params.feat_hidden2, params.feat_hidden1, drop),
                        nn.Linear(params.feat_hidden1, input_dim),
                    )

                    # Classification head producing num_classes logits.
                    self.cluster = nn.Sequential(
                        full_block(params.gat_hidden2, params.feat_hidden2, drop),
                        nn.Linear(params.feat_hidden2, self.num_classes),
                    )

                def encode(self, x, edge_index):
                    """Return ``(mu, logvar)``; ``logvar`` is None when ``self.var`` is falsy."""
                    hidden = self.gat1(self.encoder(x), edge_index)
                    hidden = F.dropout(F.relu(hidden), p=self.params.p_drop, training=self.training)
                    mu = self.gat2(hidden, edge_index)
                    logvar = self.gat3(hidden, edge_index) if self.var else None
                    return mu, logvar

                def reparameterize(self, mu, logvar):
                    """Sample ``eps * exp(0.5 * logvar) + mu`` in training mode; else return ``mu``."""
                    if not self.training or logvar is None:
                        return mu
                    std = torch.exp(0.5 * logvar)
                    return torch.randn_like(std) * std + mu

                def forward(self, x, edge_index):
                    """Return ``(pred_label, x_reconstructed, z, mu, logvar)``."""
                    mu, logvar = self.encode(x, edge_index)
                    z = self.reparameterize(mu, logvar)
                    # Decoder runs before the cluster head, as before, so dropout
                    # draws happen in the same order.
                    x_reconstructed = self.decoder(z)
                    logits = self.cluster(z)
                    return F.softmax(logits, dim=1), x_reconstructed, z, mu, logvar
         
     | 
| 
      
 1 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 2 
     | 
    
         
            +
            import torch.nn as nn
         
     | 
| 
      
 3 
     | 
    
         
            +
            import torch.nn.functional as F
         
     | 
| 
      
 4 
     | 
    
         
            +
            from torch_geometric.nn import GATConv
         
     | 
| 
      
 5 
     | 
    
         
            +
             
     | 
| 
      
 6 
     | 
    
         
            +
            def full_block(in_features, out_features, p_drop):
                """Return a ``Linear -> BatchNorm1d -> ELU -> Dropout`` stack.

                Parameters
                ----------
                in_features, out_features : int
                    Sizes of the ``nn.Linear`` layer; ``BatchNorm1d`` uses ``out_features``.
                p_drop : float
                    Dropout probability for the final ``nn.Dropout``.
                """
                # Layers are built in the same order as listed so parameter
                # initialisation consumes the RNG identically.
                layers = [
                    nn.Linear(in_features, out_features),
                    nn.BatchNorm1d(out_features),
                    nn.ELU(),
                    nn.Dropout(p=p_drop),
                ]
                return nn.Sequential(*layers)
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
            class GATModel(nn.Module):
         
     | 
| 
      
 15 
     | 
    
         
            +
                def __init__(self, input_dim, params, num_classes=1):
         
     | 
| 
      
 16 
     | 
    
         
            +
                    super().__init__()
         
     | 
| 
      
 17 
     | 
    
         
            +
                    self.var = params.var
         
     | 
| 
      
 18 
     | 
    
         
            +
                    self.num_classes = num_classes
         
     | 
| 
      
 19 
     | 
    
         
            +
                    self.params = params
         
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
                    # Encoder
         
     | 
| 
      
 22 
     | 
    
         
            +
                    self.encoder = nn.Sequential(
         
     | 
| 
      
 23 
     | 
    
         
            +
                        full_block(input_dim, params.feat_hidden1, params.p_drop),
         
     | 
| 
      
 24 
     | 
    
         
            +
                        full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop)
         
     | 
| 
      
 25 
     | 
    
         
            +
                    )
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
                    # GAT Layers
         
     | 
| 
      
 28 
     | 
    
         
            +
                    self.gat1 = GATConv(
         
     | 
| 
      
 29 
     | 
    
         
            +
                        in_channels=params.feat_hidden2,
         
     | 
| 
      
 30 
     | 
    
         
            +
                        out_channels=params.gat_hidden1,
         
     | 
| 
      
 31 
     | 
    
         
            +
                        heads=params.nheads,
         
     | 
| 
      
 32 
     | 
    
         
            +
                        dropout=params.p_drop
         
     | 
| 
      
 33 
     | 
    
         
            +
                    )
         
     | 
| 
      
 34 
     | 
    
         
            +
                    self.gat2 = GATConv(
         
     | 
| 
      
 35 
     | 
    
         
            +
                        in_channels=params.gat_hidden1 * params.nheads,
         
     | 
| 
      
 36 
     | 
    
         
            +
                        out_channels=params.gat_hidden2,
         
     | 
| 
      
 37 
     | 
    
         
            +
                        heads=1,
         
     | 
| 
      
 38 
     | 
    
         
            +
                        concat=False,
         
     | 
| 
      
 39 
     | 
    
         
            +
                        dropout=params.p_drop
         
     | 
| 
      
 40 
     | 
    
         
            +
                    )
         
     | 
| 
      
 41 
     | 
    
         
            +
                    if self.var:
         
     | 
| 
      
 42 
     | 
    
         
            +
                        self.gat3 = GATConv(
         
     | 
| 
      
 43 
     | 
    
         
            +
                            in_channels=params.gat_hidden1 * params.nheads,
         
     | 
| 
      
 44 
     | 
    
         
            +
                            out_channels=params.gat_hidden2,
         
     | 
| 
      
 45 
     | 
    
         
            +
                            heads=1,
         
     | 
| 
      
 46 
     | 
    
         
            +
                            concat=False,
         
     | 
| 
      
 47 
     | 
    
         
            +
                            dropout=params.p_drop
         
     | 
| 
      
 48 
     | 
    
         
            +
                        )
         
     | 
| 
      
 49 
     | 
    
         
            +
             
     | 
| 
      
 50 
     | 
    
         
            +
                    # Decoder
         
     | 
| 
      
 51 
     | 
    
         
            +
                    self.decoder = nn.Sequential(
         
     | 
| 
      
 52 
     | 
    
         
            +
                        full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
         
     | 
| 
      
 53 
     | 
    
         
            +
                        full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop),
         
     | 
| 
      
 54 
     | 
    
         
            +
                        nn.Linear(params.feat_hidden1, input_dim)
         
     | 
| 
      
 55 
     | 
    
         
            +
                    )
         
     | 
| 
      
 56 
     | 
    
         
            +
             
     | 
| 
      
 57 
     | 
    
         
            +
                    # Clustering Layer
         
     | 
| 
      
 58 
     | 
    
         
            +
                    self.cluster = nn.Sequential(
         
     | 
| 
      
 59 
     | 
    
         
            +
                        full_block(params.gat_hidden2, params.feat_hidden2, params.p_drop),
         
     | 
| 
      
 60 
     | 
    
         
            +
                        nn.Linear(params.feat_hidden2, self.num_classes)
         
     | 
| 
      
 61 
     | 
    
         
            +
                    )
         
     | 
| 
      
 62 
     | 
    
         
            +
             
     | 
| 
      
 63 
     | 
    
         
            +
                def encode(self, x, edge_index):
         
     | 
| 
      
 64 
     | 
    
         
            +
                    x = self.encoder(x)
         
     | 
| 
      
 65 
     | 
    
         
            +
                    x = self.gat1(x, edge_index)
         
     | 
| 
      
 66 
     | 
    
         
            +
                    x = F.relu(x)
         
     | 
| 
      
 67 
     | 
    
         
            +
                    x = F.dropout(x, p=self.params.p_drop, training=self.training)
         
     | 
| 
      
 68 
     | 
    
         
            +
             
     | 
| 
      
 69 
     | 
    
         
            +
                    mu = self.gat2(x, edge_index)
         
     | 
| 
      
 70 
     | 
    
         
            +
                    if self.var:
         
     | 
| 
      
 71 
     | 
    
         
            +
                        logvar = self.gat3(x, edge_index)
         
     | 
| 
      
 72 
     | 
    
         
            +
                        return mu, logvar
         
     | 
| 
      
 73 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 74 
     | 
    
         
            +
                        return mu, None
         
     | 
| 
      
 75 
     | 
    
         
            +
             
     | 
| 
      
 76 
     | 
    
         
            +
                def reparameterize(self, mu, logvar):
         
     | 
| 
      
 77 
     | 
    
         
            +
                    if self.training and logvar is not None:
         
     | 
| 
      
 78 
     | 
    
         
            +
                        std = torch.exp(0.5 * logvar)
         
     | 
| 
      
 79 
     | 
    
         
            +
                        eps = torch.randn_like(std)
         
     | 
| 
      
 80 
     | 
    
         
            +
                        return eps * std + mu
         
     | 
| 
      
 81 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 82 
     | 
    
         
            +
                        return mu
         
     | 
| 
      
 83 
     | 
    
         
            +
             
     | 
| 
      
 84 
     | 
    
         
            +
                def forward(self, x, edge_index):
         
     | 
| 
      
 85 
     | 
    
         
            +
                    mu, logvar = self.encode(x, edge_index)
         
     | 
| 
      
 86 
     | 
    
         
            +
                    z = self.reparameterize(mu, logvar)
         
     | 
| 
      
 87 
     | 
    
         
            +
                    x_reconstructed = self.decoder(z)
         
     | 
| 
      
 88 
     | 
    
         
            +
                    pred_label = F.softmax(self.cluster(z), dim=1)
         
     | 
| 
      
 89 
     | 
    
         
            +
                    return pred_label, x_reconstructed, z, mu, logvar
         
     | 
    
        gsMap/{GNN_VAE → GNN}/train.py
    RENAMED
    
    | 
         @@ -1,86 +1,88 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            import logging
         
     | 
| 
       2 
     | 
    
         
            -
            import time
         
     | 
| 
       3 
     | 
    
         
            -
             
     | 
| 
       4 
     | 
    
         
            -
            import torch
         
     | 
| 
       5 
     | 
    
         
            -
            import torch.nn.functional as F
         
     | 
| 
       6 
     | 
    
         
            -
            from  
     | 
| 
       7 
     | 
    
         
            -
             
     | 
| 
       8 
     | 
    
         
            -
            from gsMap. 
     | 
| 
       9 
     | 
    
         
            -
             
     | 
| 
       10 
     | 
    
         
            -
            logger = logging.getLogger(__name__)
         
     | 
| 
       11 
     | 
    
         
            -
             
     | 
| 
       12 
     | 
    
         
            -
             
     | 
| 
       13 
     | 
    
         
            -
            def reconstruction_loss(decoded, x):
         
     | 
| 
       14 
     | 
    
         
            -
                """Compute the mean squared error loss."""
         
     | 
| 
       15 
     | 
    
         
            -
                return F.mse_loss(decoded, x)
         
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
             
     | 
| 
       18 
     | 
    
         
            -
            def label_loss(pred_label, true_label):
         
     | 
| 
       19 
     | 
    
         
            -
                """Compute the cross-entropy loss."""
         
     | 
| 
       20 
     | 
    
         
            -
                return F.cross_entropy(pred_label, true_label)
         
     | 
| 
       21 
     | 
    
         
            -
             
     | 
| 
       22 
     | 
    
         
            -
             
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
             
     | 
| 
       25 
     | 
    
         
            -
                     
     | 
| 
       26 
     | 
    
         
            -
                    self. 
     | 
| 
       27 
     | 
    
         
            -
                    self. 
     | 
| 
       28 
     | 
    
         
            -
                    self. 
     | 
| 
       29 
     | 
    
         
            -
                    self. 
     | 
| 
       30 
     | 
    
         
            -
                    self. 
     | 
| 
       31 
     | 
    
         
            -
                    self. 
     | 
| 
       32 
     | 
    
         
            -
             
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
             
     | 
| 
       35 
     | 
    
         
            -
                        self. 
     | 
| 
       36 
     | 
    
         
            -
             
     | 
| 
       37 
     | 
    
         
            -
             
     | 
| 
       38 
     | 
    
         
            -
                     
     | 
| 
       39 
     | 
    
         
            -
                    self. 
     | 
| 
       40 
     | 
    
         
            -
             
     | 
| 
       41 
     | 
    
         
            -
                         
     | 
| 
       42 
     | 
    
         
            -
                         
     | 
| 
       43 
     | 
    
         
            -
             
     | 
| 
       44 
     | 
    
         
            -
             
     | 
| 
       45 
     | 
    
         
            -
             
     | 
| 
       46 
     | 
    
         
            -
             
     | 
| 
       47 
     | 
    
         
            -
                     
     | 
| 
       48 
     | 
    
         
            -
                     
     | 
| 
       49 
     | 
    
         
            -
                     
     | 
| 
       50 
     | 
    
         
            -
                     
     | 
| 
       51 
     | 
    
         
            -
             
     | 
| 
       52 
     | 
    
         
            -
                     
     | 
| 
       53 
     | 
    
         
            -
             
     | 
| 
       54 
     | 
    
         
            -
                         
     | 
| 
       55 
     | 
    
         
            -
                        self. 
     | 
| 
       56 
     | 
    
         
            -
                         
     | 
| 
       57 
     | 
    
         
            -
             
     | 
| 
       58 
     | 
    
         
            -
             
     | 
| 
       59 
     | 
    
         
            -
             
     | 
| 
       60 
     | 
    
         
            -
                             
     | 
| 
       61 
     | 
    
         
            -
             
     | 
| 
       62 
     | 
    
         
            -
             
     | 
| 
       63 
     | 
    
         
            -
             
     | 
| 
       64 
     | 
    
         
            -
             
     | 
| 
       65 
     | 
    
         
            -
                         
     | 
| 
       66 
     | 
    
         
            -
             
     | 
| 
       67 
     | 
    
         
            -
             
     | 
| 
       68 
     | 
    
         
            -
                         
     | 
| 
       69 
     | 
    
         
            -
             
     | 
| 
       70 
     | 
    
         
            -
             
     | 
| 
       71 
     | 
    
         
            -
                         
     | 
| 
       72 
     | 
    
         
            -
             
     | 
| 
       73 
     | 
    
         
            -
             
     | 
| 
       74 
     | 
    
         
            -
             
     | 
| 
       75 
     | 
    
         
            -
                            logger.info(' 
     | 
| 
       76 
     | 
    
         
            -
                            break
         
     | 
| 
       77 
     | 
    
         
            -
             
     | 
| 
       78 
     | 
    
         
            -
             
     | 
| 
       79 
     | 
    
         
            -
             
     | 
| 
       80 
     | 
    
         
            -
             
     | 
| 
       81 
     | 
    
         
            -
             
     | 
| 
       82 
     | 
    
         
            -
             
     | 
| 
       83 
     | 
    
         
            -
             
     | 
| 
       84 
     | 
    
         
            -
                     
     | 
| 
       85 
     | 
    
         
            -
             
     | 
| 
       86 
     | 
    
         
            -
                     
     | 
| 
      
 1 
     | 
    
         
            +
            import logging
         
     | 
| 
      
 2 
     | 
    
         
            +
            import time
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 5 
     | 
    
         
            +
            import torch.nn.functional as F
         
     | 
| 
      
 6 
     | 
    
         
            +
            from tqdm import tqdm
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
            from gsMap.GNN.model import GATModel
         
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
      
 10 
     | 
    
         
            +
            logger = logging.getLogger(__name__)
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            def reconstruction_loss(decoded, x):
         
     | 
| 
      
 14 
     | 
    
         
            +
                """Compute the mean squared error loss."""
         
     | 
| 
      
 15 
     | 
    
         
            +
                return F.mse_loss(decoded, x)
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
            def label_loss(pred_label, true_label):
         
     | 
| 
      
 19 
     | 
    
         
            +
                """Compute the cross-entropy loss."""
         
     | 
| 
      
 20 
     | 
    
         
            +
                return F.cross_entropy(pred_label, true_label)
         
     | 
| 
      
 21 
     | 
    
         
            +
             
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
            class ModelTrainer:
         
     | 
| 
      
 24 
     | 
    
         
            +
                def __init__(self, node_x, graph_dict, params, label=None):
         
     | 
| 
      
 25 
     | 
    
         
            +
                    """Initialize the ModelTrainer with data and hyperparameters."""
         
     | 
| 
      
 26 
     | 
    
         
            +
                    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         
     | 
| 
      
 27 
     | 
    
         
            +
                    self.params = params
         
     | 
| 
      
 28 
     | 
    
         
            +
                    self.epochs = params.epochs
         
     | 
| 
      
 29 
     | 
    
         
            +
                    self.node_x = torch.FloatTensor(node_x).to(self.device)
         
     | 
| 
      
 30 
     | 
    
         
            +
                    self.adj_norm = graph_dict["adj_norm"].to(self.device).coalesce()
         
     | 
| 
      
 31 
     | 
    
         
            +
                    self.label = label
         
     | 
| 
      
 32 
     | 
    
         
            +
                    self.num_classes = 1
         
     | 
| 
      
 33 
     | 
    
         
            +
             
     | 
| 
      
 34 
     | 
    
         
            +
                    if self.label is not None:
         
     | 
| 
      
 35 
     | 
    
         
            +
                        self.label = torch.tensor(self.label).to(self.device)
         
     | 
| 
      
 36 
     | 
    
         
            +
                        self.num_classes = len(torch.unique(self.label))
         
     | 
| 
      
 37 
     | 
    
         
            +
             
     | 
| 
      
 38 
     | 
    
         
            +
                    # Set up the model
         
     | 
| 
      
 39 
     | 
    
         
            +
                    self.model = GATModel(self.params.feat_cell, self.params, self.num_classes).to(self.device)
         
     | 
| 
      
 40 
     | 
    
         
            +
                    self.optimizer = torch.optim.Adam(
         
     | 
| 
      
 41 
     | 
    
         
            +
                        self.model.parameters(),
         
     | 
| 
      
 42 
     | 
    
         
            +
                        lr=self.params.gat_lr,
         
     | 
| 
      
 43 
     | 
    
         
            +
                        weight_decay=self.params.gcn_decay
         
     | 
| 
      
 44 
     | 
    
         
            +
                    )
         
     | 
| 
      
 45 
     | 
    
         
            +
             
     | 
| 
      
 46 
     | 
    
         
            +
                def run_train(self):
         
     | 
| 
      
 47 
     | 
    
         
            +
                    """Train the model."""
         
     | 
| 
      
 48 
     | 
    
         
            +
                    self.model.train()
         
     | 
| 
      
 49 
     | 
    
         
            +
                    prev_loss = float('inf')
         
     | 
| 
      
 50 
     | 
    
         
            +
                    logger.info('Start training...')
         
     | 
| 
      
 51 
     | 
    
         
            +
                    pbar = tqdm(range(self.epochs), desc='GAT-AE model train:', total=self.epochs)
         
     | 
| 
      
 52 
     | 
    
         
            +
                    for epoch in range(self.epochs):
         
     | 
| 
      
 53 
     | 
    
         
            +
                        start_time = time.time()
         
     | 
| 
      
 54 
     | 
    
         
            +
                        self.optimizer.zero_grad()
         
     | 
| 
      
 55 
     | 
    
         
            +
                        pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_x, self.adj_norm)
         
     | 
| 
      
 56 
     | 
    
         
            +
                        loss_rec = reconstruction_loss(de_feat, self.node_x)
         
     | 
| 
      
 57 
     | 
    
         
            +
             
     | 
| 
      
 58 
     | 
    
         
            +
                        if self.label is not None:
         
     | 
| 
      
 59 
     | 
    
         
            +
                            loss_pre = label_loss(pred_label, self.label)
         
     | 
| 
      
 60 
     | 
    
         
            +
                            loss = self.params.rec_w * loss_rec + self.params.label_w * loss_pre
         
     | 
| 
      
 61 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 62 
     | 
    
         
            +
                            loss = loss_rec
         
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
                        loss.backward()
         
     | 
| 
      
 65 
     | 
    
         
            +
                        self.optimizer.step()
         
     | 
| 
      
 66 
     | 
    
         
            +
             
     | 
| 
      
 67 
     | 
    
         
            +
                        batch_time = time.time() - start_time
         
     | 
| 
      
 68 
     | 
    
         
            +
                        left_time = batch_time * (self.epochs - epoch - 1) / 60  # in minutes
         
     | 
| 
      
 69 
     | 
    
         
            +
             
     | 
| 
      
 70 
     | 
    
         
            +
                        pbar.set_postfix({'Left time': f'{left_time:.2f} mins', 'Loss': f'{loss.item():.4f}'})
         
     | 
| 
      
 71 
     | 
    
         
            +
                        pbar.update(1)
         
     | 
| 
      
 72 
     | 
    
         
            +
             
     | 
| 
      
 73 
     | 
    
         
            +
                        if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
         
     | 
| 
      
 74 
     | 
    
         
            +
                            pbar.close()
         
     | 
| 
      
 75 
     | 
    
         
            +
                            logger.info('Convergence reached. Training stopped.')
         
     | 
| 
      
 76 
     | 
    
         
            +
                            break
         
     | 
| 
      
 77 
     | 
    
         
            +
                        prev_loss = loss.item()
         
     | 
| 
      
 78 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 79 
     | 
    
         
            +
                        pbar.close()
         
     | 
| 
      
 80 
     | 
    
         
            +
                        logger.info('Max epochs reached. Training stopped.')
         
     | 
| 
      
 81 
     | 
    
         
            +
             
     | 
| 
      
 82 
     | 
    
         
            +
             
     | 
| 
      
 83 
     | 
    
         
            +
                def get_latent(self):
         
     | 
| 
      
 84 
     | 
    
         
            +
                    """Retrieve the latent representation from the model."""
         
     | 
| 
      
 85 
     | 
    
         
            +
                    self.model.eval()
         
     | 
| 
      
 86 
     | 
    
         
            +
                    with torch.no_grad():
         
     | 
| 
      
 87 
     | 
    
         
            +
                        _, _, latent_z, _, _ = self.model(self.node_x, self.adj_norm)
         
     | 
| 
      
 88 
     | 
    
         
            +
                    return latent_z.cpu().numpy()
         
     | 
    
        gsMap/__init__.py
    CHANGED
    
    | 
         @@ -1,5 +1,5 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            '''
         
     | 
| 
       2 
     | 
    
         
            -
            Genetics-informed pathogenic spatial mapping
         
     | 
| 
       3 
     | 
    
         
            -
            '''
         
     | 
| 
       4 
     | 
    
         
            -
             
     | 
| 
       5 
     | 
    
         
            -
            __version__ = '1. 
     | 
| 
      
 1 
     | 
    
         
            +
            '''
         
     | 
| 
      
 2 
     | 
    
         
            +
            Genetics-informed pathogenic spatial mapping
         
     | 
| 
      
 3 
     | 
    
         
            +
            '''
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            __version__ = '1.71'
         
     | 
    
        gsMap/__main__.py
    CHANGED
    
    | 
         @@ -1,3 +1,3 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            from .main import main
         
     | 
| 
       2 
     | 
    
         
            -
            if __name__ == '__main__':
         
     | 
| 
      
 1 
     | 
    
         
            +
            from .main import main
         
     | 
| 
      
 2 
     | 
    
         
            +
            if __name__ == '__main__':
         
     | 
| 
       3 
3 
     | 
    
         
             
                main()
         
     |