gsmap-1.60-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN_VAE/__init__.py +0 -0
- gsMap/GNN_VAE/adjacency_matrix.py +95 -0
- gsMap/GNN_VAE/model.py +87 -0
- gsMap/GNN_VAE/train.py +97 -0
- gsMap/__init__.py +5 -0
- gsMap/__main__.py +3 -0
- gsMap/cauchy_combination_test.py +163 -0
- gsMap/config.py +734 -0
- gsMap/find_latent_representation.py +209 -0
- gsMap/format_sumstats.py +410 -0
- gsMap/generate_ldscore.py +551 -0
- gsMap/generate_r2_matrix.py +743 -0
- gsMap/jackknife.py +514 -0
- gsMap/latent_to_gene.py +257 -0
- gsMap/main.py +39 -0
- gsMap/make_annotations.py +560 -0
- gsMap/regression_read.py +294 -0
- gsMap/spatial_ldsc_multiple_sumstats.py +307 -0
- gsMap/visualize.py +154 -0
- gsmap-1.60.dist-info/LICENSE +21 -0
- gsmap-1.60.dist-info/METADATA +124 -0
- gsmap-1.60.dist-info/RECORD +24 -0
- gsmap-1.60.dist-info/WHEEL +4 -0
- gsmap-1.60.dist-info/entry_points.txt +3 -0
gsMap/GNN_VAE/__init__.py

File without changes

gsMap/GNN_VAE/adjacency_matrix.py
ADDED

@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul  4 21:31:27 2023

@author: songliyang
"""
import numpy as np
import pandas as pd
import scipy.sparse as sp
import sklearn.neighbors
import torch


def Cal_Spatial_Net(adata, n_neighbors=5, verbose=True):
    """\
    Construct the spatial neighbor networks.
    """
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index
    nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
    distances, indices = nbrs.kneighbors(coor, return_distance=True)
    KNN_list = []
    for it in range(indices.shape[0]):
        KNN_list.append(pd.DataFrame(zip([it] * indices[it].shape[0], indices[it], distances[it])))
    KNN_df = pd.concat(KNN_list)
    KNN_df.columns = ['Cell1', 'Cell2', 'Distance']
    # Drop self-edges (distance 0) and translate integer positions back to cell names
    Spatial_Net = KNN_df.copy()
    Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0,]
    id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index)))
    Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
    Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
    return Spatial_Net


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)


def preprocess_graph(adj):
    # Degree-normalize the self-looped adjacency, D^-1/2 (A + I) D^-1/2, as in GCN/VGAE
    adj = sp.coo_matrix(adj)
    adj_ = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj_.sum(1))
    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    return sparse_mx_to_torch_sparse_tensor(adj_normalized)


def Construct_Adjacency_Matrix(adata, Params, verbose=True):
    # Construct the neighbor graph
    Spatial_Net = Cal_Spatial_Net(adata, n_neighbors=Params.n_neighbors)
    if verbose:
        print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata.n_obs))
    cells = np.array(adata.obs.index)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
    G_df = Spatial_Net.copy()
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)
    # Edge weights: a Gaussian kernel on normalized distances, or plain binary edges
    if Params.weighted_adj:
        distance_normalized = G_df.Distance / (max(G_df.Distance) + 1)
        adj_org = sp.coo_matrix((np.exp(-distance_normalized ** 2 / 2), (G_df['Cell1'], G_df['Cell2'])),
                                shape=(adata.n_obs, adata.n_obs))
    else:
        adj_org = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])),
                                shape=(adata.n_obs, adata.n_obs))
    adj_m1 = adj_org
    adj_norm_m1 = preprocess_graph(adj_m1)
    adj_label_m1 = adj_m1 + sp.eye(adj_m1.shape[0])
    norm_m1 = adj_m1.shape[0] * adj_m1.shape[0] / float((adj_m1.shape[0] * adj_m1.shape[0] - adj_m1.sum()) * 2)
    graph_dict = {
        "adj_org": adj_org,
        "adj_norm": adj_norm_m1,
        "norm_value": norm_m1
    }
    return graph_dict
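For orientation, here is a minimal usage sketch on a toy AnnData object. The SimpleNamespace stands in for the parameter object the released pipeline builds in gsMap.config; only the n_neighbors and weighted_adj fields are read by this module, and the random data is purely illustrative.

from types import SimpleNamespace

import numpy as np
import scanpy as sc

from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix

# Toy AnnData: random expression plus random 2-D spot coordinates
adata = sc.AnnData(X=np.random.rand(50, 10).astype(np.float32))
adata.obsm['spatial'] = np.random.rand(50, 2)

params = SimpleNamespace(n_neighbors=5, weighted_adj=False)  # assumed fields
graph_dict = Construct_Adjacency_Matrix(adata, params)
print(graph_dict['adj_norm'].shape)  # normalized torch sparse tensor, (50, 50)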
    
        gsMap/GNN_VAE/model.py
    ADDED
    
@@ -0,0 +1,87 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jul  3 11:42:44 2023

@author: songliyang
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv


def full_block(in_features, out_features, p_drop):
    return nn.Sequential(nn.Linear(in_features, out_features),
                         nn.BatchNorm1d(out_features),
                         nn.ELU(),
                         nn.Dropout(p=p_drop))


class GNN(nn.Module):
    def __init__(self, in_features, out_features, dr=0, act=F.relu, heads=1):
        super().__init__()
        self.conv1 = GATConv(in_features, out_features, heads)
        self.act = act
        self.dr = dr

    def forward(self, x, edge_index):
        out = self.conv1(x, edge_index)
        out = self.act(out)
        out = F.dropout(out, self.dr, self.training)
        return out


class GNN_VAE_Model(nn.Module):
    def __init__(self, input_dim, params, num_classes=1):
        super(GNN_VAE_Model, self).__init__()
        self.var = params.var
        self.num_classes = num_classes

        # Encoder
        self.encoder = nn.Sequential()
        self.encoder.add_module('encoder_L1', full_block(input_dim, params.feat_hidden1, params.p_drop))
        self.encoder.add_module('encoder_L2', full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop))

        # GNN (GAT)
        self.gn1 = GNN(params.feat_hidden2, params.gcn_hidden1, params.p_drop, act=F.relu, heads=params.nheads)
        self.gn2 = GNN(params.gcn_hidden1 * params.nheads, params.gcn_hidden2, params.p_drop, act=lambda x: x)
        self.gn3 = GNN(params.gcn_hidden1 * params.nheads, params.gcn_hidden2, params.p_drop, act=lambda x: x)

        # Decoder
        self.decoder = nn.Sequential()
        self.decoder.add_module('decoder_L1', full_block(params.gcn_hidden2, params.feat_hidden2, params.p_drop))
        self.decoder.add_module('decoder_L2', full_block(params.feat_hidden2, params.feat_hidden1, params.p_drop))
        self.decoder.add_module('decoder_output', nn.Sequential(nn.Linear(params.feat_hidden1, input_dim)))

        # Cluster head
        self.cluster = nn.Sequential()
        self.cluster.add_module('cluster_L1', full_block(params.gcn_hidden2, params.feat_hidden2, params.p_drop))
        self.cluster.add_module('cluster_output', nn.Linear(params.feat_hidden2, self.num_classes))

    def encode(self, x, adj):
        feat_x = self.encoder(x)
        hidden1 = self.gn1(feat_x, adj)
        mu = self.gn2(hidden1, adj)
        if self.var:
            logvar = self.gn3(hidden1, adj)
            return mu, logvar
        else:
            return mu, None

    def reparameterize(self, mu, logvar):
        if self.training and logvar is not None:
            std = torch.exp(logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, adj):
        mu, logvar = self.encode(x, adj)
        gnn_z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decoder(gnn_z)
        pred_label = F.softmax(self.cluster(gnn_z), dim=1)
        return pred_label, x_reconstructed, gnn_z, mu, logvar
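A minimal sketch of driving the model by hand. The hyper-parameter values and the random graph are assumptions for illustration; the field names match exactly what the class reads above, and eval mode makes reparameterize() return mu directly.

from types import SimpleNamespace

import torch

from gsMap.GNN_VAE.model import GNN_VAE_Model

# Hypothetical hyper-parameters; the released pipeline supplies these via gsMap.config
params = SimpleNamespace(feat_hidden1=64, feat_hidden2=32, gcn_hidden1=16,
                         gcn_hidden2=8, p_drop=0.1, nheads=2, var=True)
model = GNN_VAE_Model(input_dim=100, params=params, num_classes=3)
model.eval()

x = torch.randn(20, 100)                    # 20 cells x 100 input features
edge_index = torch.randint(0, 20, (2, 60))  # random COO edge list for GATConv
pred_label, x_rec, z, mu, logvar = model(x, edge_index)
print(z.shape)  # torch.Size([20, 8]) -- the latent embedding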
    
        gsMap/GNN_VAE/train.py
    ADDED
    
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul  4 19:58:58 2023

@author: songliyang
"""
import time

import torch
from progress.bar import Bar

from gsMap.GNN_VAE.model import GNN_VAE_Model


def reconstruction_loss(decoded, x):
    loss_fn = torch.nn.MSELoss()
    loss = loss_fn(decoded, x)
    return loss


def label_loss(pred_label, true_label):
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(pred_label, true_label)
    return loss


class Model_Train:
    def __init__(self, node_X, graph_dict, params, label=None):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        torch.cuda.empty_cache()

        self.params = params
        self.device = device
        self.epochs = params.epochs
        self.node_X = torch.FloatTensor(node_X.copy()).to(device)
        self.adj_norm = graph_dict["adj_norm"].to(device).coalesce()
        self.label = label
        self.num_classes = 1

        if self.label is not None:
            self.label = torch.tensor(self.label).to(self.device)
            self.num_classes = len(self.label.unique())

        # Set up the model
        self.model = GNN_VAE_Model(self.params.feat_cell, self.params, self.num_classes).to(device)
        self.optimizer = torch.optim.Adam(params=list(self.model.parameters()),
                                          lr=self.params.gcn_lr, weight_decay=self.params.gcn_decay)

    # Train
    def run_train(self):
        self.model.train()
        prev_loss = float('inf')

        bar = Bar('GAT-AE model train:', max=self.epochs)
        bar.check_tty = False
        for epoch in range(self.epochs):
            start_time = time.time()
            self.model.train()
            self.optimizer.zero_grad()
            pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
            loss_rec = reconstruction_loss(de_feat, self.node_X)

            # Check whether annotation was provided
            if self.label is not None:
                loss_pre = label_loss(pred_label, self.label)
                loss = (self.params.rec_w * loss_rec) + (self.params.label_w * loss_pre)
            else:
                loss = loss_rec

            loss.backward()
            self.optimizer.step()

            # Update the progress bar
            end_time = time.time()
            batch_time = end_time - start_time

            bar_str = '{} / {} | Left time: {batch_time:.2f} mins | Loss: {loss:.4f}'
            bar.suffix = bar_str.format(epoch + 1, self.epochs,
                                        batch_time=batch_time * (self.epochs - epoch) / 60, loss=loss.item())
            bar.next()

            # Check convergence
            if abs(loss.item() - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
                print('\nConvergence reached. Training stopped.')
                break

            prev_loss = loss.item()

        bar.finish()

    def get_latent(self):
        self.model.eval()
        pred, de_fea, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
        latent_z = latent_z.data.cpu().numpy()
        return latent_z
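Putting the two GNN_VAE modules together, a hedged end-to-end sketch. Every value below is an assumption for illustration (the released gsMap/find_latent_representation.py drives these classes with parameters from gsMap.config), and it presumes a torch_geometric version that accepts the torch sparse adjacency Model_Train hands to GATConv.

from types import SimpleNamespace

import numpy as np
import scanpy as sc

from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix
from gsMap.GNN_VAE.train import Model_Train

# Illustrative data and hyper-parameters, not the shipped defaults
n_cells, n_feat = 100, 50
adata = sc.AnnData(X=np.random.rand(n_cells, n_feat).astype(np.float32))
adata.obsm['spatial'] = np.random.rand(n_cells, 2)

params = SimpleNamespace(
    n_neighbors=5, weighted_adj=False,        # graph construction
    feat_cell=n_feat, feat_hidden1=32, feat_hidden2=16,
    gcn_hidden1=8, gcn_hidden2=4, p_drop=0.1, nheads=2, var=False,
    epochs=10, gcn_lr=1e-3, gcn_decay=1e-2,   # optimizer
    rec_w=1.0, label_w=1.0, convergence_threshold=1e-4,
)

graph_dict = Construct_Adjacency_Matrix(adata, params)
trainer = Model_Train(adata.X, graph_dict, params, label=None)  # unsupervised: reconstruction loss only
trainer.run_train()
latent = trainer.get_latent()  # numpy array of shape (100, 4)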
    
gsMap/__init__.py
ADDED

gsMap/__main__.py
ADDED

gsMap/cauchy_combination_test.py
ADDED

@@ -0,0 +1,163 @@
import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import scipy as sp

from gsMap.config import CauchyCombinationConfig, add_Cauchy_combination_args

# The function for the Cauchy combination
def acat_test(pvalues, weights=None):
    '''acat_test()
    Aggregated Cauchy Association Test
    A p-value combination method using the Cauchy distribution.

    Inspired by: https://github.com/yaowuliu/ACAT/blob/master/R/ACAT.R
    Inputs:
        pvalues: <list or numpy array>
            The p-values you want to combine.
        weights: <list or numpy array>, default=None
            The weights for each of the p-values. If None, equal weights are used.

    Returns:
        pval: <float>
            The ACAT combined p-value.
    '''
    if any(np.isnan(pvalues)):
        raise Exception("Cannot have NAs in the p-values.")
    if any([(i > 1) | (i < 0) for i in pvalues]):
        raise Exception("P-values must be between 0 and 1.")
    if any([i == 1 for i in pvalues]) & any([i == 0 for i in pvalues]):
        raise Exception("Cannot have both 0 and 1 p-values.")
    if any([i == 0 for i in pvalues]):
        print("Warn: p-values are exactly 0.")
        return 0
    if any([i == 1 for i in pvalues]):
        print("Warn: p-values are exactly 1.")
        return 1
    if weights is None:
        weights = [1 / len(pvalues) for i in pvalues]
    elif len(weights) != len(pvalues):
        raise Exception("Length of weights and p-values differs.")
    elif any([i < 0 for i in weights]):
        raise Exception("All weights must be positive.")
    else:
        weights = [i / len(weights) for i in weights]

    pvalues = np.array(pvalues)
    weights = np.array(weights)

    if not any([i < 1e-16 for i in pvalues]):
        cct_stat = sum(weights * np.tan((0.5 - pvalues) * np.pi))
    else:
        # Very small p-values use the approximation tan((0.5 - p) * pi) ~= 1 / (p * pi)
        is_small = [i < (1e-16) for i in pvalues]
        is_large = [i >= (1e-16) for i in pvalues]
        cct_stat = sum((weights[is_small] / pvalues[is_small]) / np.pi)
        cct_stat += sum(weights[is_large] * np.tan((0.5 - pvalues[is_large]) * np.pi))

    if cct_stat > 1e15:
        pval = (1 / cct_stat) / np.pi
    else:
        pval = 1 - sp.stats.cauchy.cdf(cct_stat)

    return pval


def run_Cauchy_combination(config: CauchyCombinationConfig):
    # Load the LDSC results
    print(f'------Loading LDSC results of {config.input_ldsc_dir}...')
    ldsc_input_file = Path(config.input_ldsc_dir) / f'{config.sample_name}_{config.trait_name}.csv.gz'
    ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
    ldsc.spot = ldsc.spot.astype(str).replace(r'\.0', '', regex=True)
    ldsc.index = ldsc.spot
    if config.meta is None:
        # Load the spatial data
        print(f'------Loading ST data of {config.input_hdf5_path}...')
        spe = sc.read_h5ad(f'{config.input_hdf5_path}')

        common_cell = np.intersect1d(ldsc.index, spe.obs_names)
        spe = spe[common_cell,]
        ldsc = ldsc.loc[common_cell]

        # Add the annotation
        ldsc['annotation'] = spe.obs.loc[ldsc.spot][config.annotation].to_list()

    else:
        # Or load the additional annotation (just for the macaque data at this stage: 2023Nov25)
        print('------Loading additional annotation...')
        meta = pd.read_csv(config.meta, index_col=0)
        meta = meta.loc[meta.slide == config.slide]
        meta.index = meta.cell_id.astype(str).replace(r'\.0', '', regex=True)

        common_cell = np.intersect1d(ldsc.index, meta.index)
        meta = meta.loc[common_cell]
        ldsc = ldsc.loc[common_cell]

        # Add the annotation
        ldsc['annotation'] = meta.loc[ldsc.spot][config.annotation].to_list()

    # Perform the Cauchy combination based on the given annotations
    p_cauchy = []
    p_median = []
    for ct in np.unique(ldsc.annotation):
        p_temp = ldsc.loc[ldsc['annotation'] == ct, 'p']

        # The Cauchy test is sensitive to very small p-values, so extreme outliers are considered
        # for removal to enhance robustness, particularly when spot annotations may be incorrect.
        # p_cauchy_temp = acat_test(p_temp[p_temp != np.min(p_temp)])
        p_temp_log = -np.log10(p_temp)
        median_log = np.median(p_temp_log)
        IQR_log = np.percentile(p_temp_log, 75) - np.percentile(p_temp_log, 25)

        p_use = p_temp[p_temp_log < median_log + 3 * IQR_log]
        n_remove = len(p_temp) - len(p_use)

        # A spot is an outlier if -log10(p) >= median + 3 * IQR; drop them only if fewer than 20
        if 0 < n_remove < 20:
            print(f'Remove {n_remove}/{len(p_temp)} outliers (median + 3IQR) for {ct}.')
            p_cauchy_temp = acat_test(p_use)
        else:
            p_cauchy_temp = acat_test(p_temp)

        p_median_temp = np.median(p_temp)

        p_cauchy.append(p_cauchy_temp)
        p_median.append(p_median_temp)

    data = {'p_cauchy': p_cauchy, 'p_median': p_median, 'annotation': np.unique(ldsc.annotation)}
    p_tissue = pd.DataFrame(data)
    p_tissue.columns = ['p_cauchy', 'p_median', 'annotation']
    # Save the results
    output_dir = Path(config.output_cauchy_dir)
    output_dir.mkdir(parents=True, exist_ok=True, mode=0o755)
    output_file = output_dir / f'{config.sample_name}_{config.trait_name}.Cauchy.csv.gz'
    p_tissue.to_csv(
        output_file,
        compression='gzip',
        index=False,
    )


if __name__ == '__main__':
    TEST = True
    if TEST:
        test_dir = '/storage/yangjianLab/chenwenhao/projects/202312_gsMap/data/gsMap_test/Nature_Neuroscience_2021'
        name = 'Cortex_151507'

        config = CauchyCombinationConfig(
            input_hdf5_path=f'{test_dir}/{name}/hdf5/{name}_add_latent.h5ad',
            input_ldsc_dir=f'{test_dir}/snake_workdir/{name}/ldsc/',
            sample_name=name,
            annotation='layer_guess',
            output_cauchy_dir=f'{test_dir}/snake_workdir/{name}/cauchy/',
            trait_name='adult1_adult2_onset_asthma',
        )
    else:
        parser = argparse.ArgumentParser(description="Run Cauchy Combination Analysis")
        add_Cauchy_combination_args(parser)
        args = parser.parse_args()
        config = CauchyCombinationConfig(**vars(args))
    run_Cauchy_combination(config)
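For reference, acat_test implements the standard ACAT combination. In notation (an editorial summary of the code above):

T = \sum_{i=1}^{k} w_i \tan\big((0.5 - p_i)\,\pi\big), \qquad
p_{\mathrm{ACAT}} = 1 - F_{\mathrm{Cauchy}}(T) = \frac{1}{2} - \frac{\arctan T}{\pi},

where a term with p_i < 1e-16 falls back to the small-p approximation w_i / (p_i \pi), and T > 1e15 triggers the tail approximation p \approx 1 / (T \pi). A quick, self-contained sanity check with arbitrary example values:

from gsMap.cauchy_combination_test import acat_test

pvals = [0.01, 0.20, 0.03, 0.50]
combined = acat_test(pvals)                        # equal weights
weighted = acat_test(pvals, weights=[2, 1, 1, 1])  # up-weight the first p-value
print(f'combined={combined:.4g}, weighted={weighted:.4g}')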