gsMap 1.66__py3-none-any.whl → 1.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,95 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Created on Tue Jul 4 21:31:27 2023
5
-
6
- @author: songliyang
7
- """
8
- import numpy as np
9
- import pandas as pd
10
- import scipy.sparse as sp
11
- import sklearn.neighbors
12
- import torch
13
-
14
-
15
def Cal_Spatial_Net(adata, n_neighbors=5, verbose=True):
    """\
    Construct the spatial neighbor networks.

    Builds a k-nearest-neighbour graph over the coordinates stored in
    ``adata.obsm['spatial']`` and returns it as an edge list.

    Parameters
    ----------
    adata
        Annotated data object providing ``obsm['spatial']`` (one row of
        coordinates per observation) and ``obs.index`` (observation names).
    n_neighbors
        Number of neighbours queried per observation. The query set is the
        fitted set itself, so a point is typically matched with itself at
        distance 0; such zero-distance pairs are dropped below.
    verbose
        Print a progress message when True.

    Returns
    -------
    pandas.DataFrame
        Columns ``Cell1``, ``Cell2`` (observation names taken from
        ``adata.obs.index``) and ``Distance``; one row per directed edge whose
        distance is strictly positive.
    """
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index

    nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(coor)
    distances, indices = nbrs.kneighbors(coor, return_distance=True)

    # Build the whole edge list in one vectorised step instead of creating and
    # concatenating one small DataFrame per observation (slow for many cells).
    n_obs, k = indices.shape
    KNN_df = pd.DataFrame(
        {
            'Cell1': np.repeat(np.arange(n_obs), k),
            'Cell2': indices.ravel(),
            'Distance': distances.ravel(),
        },
        # Same index layout the per-row concatenation produced:
        # 0..k-1 repeated once per observation.
        index=np.tile(np.arange(k), n_obs),
    )

    # Drop zero-distance pairs (self matches and coincident coordinates).
    # .copy() so the column assignments below act on an independent frame.
    Spatial_Net = KNN_df.loc[KNN_df['Distance'] > 0].copy()

    # Translate integer row positions back to observation names.
    id_cell_trans = dict(zip(range(n_obs), np.array(coor.index)))
    Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
    Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
    return Spatial_Net
42
-
43
-
44
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor.

    The input is converted to COO format and cast to float32. The result is a
    torch sparse COO tensor with int64 indices, float32 values and the same
    shape as ``sparse_mx``. It is not explicitly coalesced.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
    )
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # torch.sparse_coo_tensor is the supported equivalent (dtype follows values).
    return torch.sparse_coo_tensor(indices, values, shape)
51
-
52
-
53
def preprocess_graph(adj):
    """Add self-loops, degree-normalise, and return a torch sparse tensor.

    With ``A_hat = adj + I`` and ``D = diag(rowsum(A_hat) ** -0.5)``, the value
    computed is ``((A_hat @ D).T @ D)``. For a symmetric ``adj`` this is the
    usual ``D A_hat D`` normalisation; for a non-symmetric ``adj`` it equals
    ``D A_hat.T D``.
    """
    coo_adj = sp.coo_matrix(adj)
    n_nodes = coo_adj.shape[0]
    adj_hat = coo_adj + sp.eye(n_nodes)

    # Row sums of A_hat, flattened to a 1-D vector before raising to -1/2.
    degrees = np.asarray(adj_hat.sum(1)).flatten()
    d_inv_sqrt = sp.diags(np.power(degrees, -0.5))

    scaled = adj_hat.dot(d_inv_sqrt)
    normalized = scaled.transpose().dot(d_inv_sqrt).tocoo()
    return sparse_mx_to_torch_sparse_tensor(normalized)
60
-
61
-
62
-
63
def Construct_Adjacency_Matrix(adata, Params, verbose=True):
    """Build the spatial KNN graph for ``adata`` and its normalised adjacency.

    Parameters
    ----------
    adata
        Annotated data object; reads ``obsm['spatial']`` (via
        ``Cal_Spatial_Net``), ``obs.index`` and ``n_obs``.
    Params
        Configuration object; reads ``n_neighbors`` (neighbours queried per
        observation) and ``weighted_adj`` (Gaussian distance weights when
        truthy, unit weights otherwise).
    verbose
        Print progress and graph statistics when True. Also forwarded to
        ``Cal_Spatial_Net``, so ``verbose=False`` silences its message too.

    Returns
    -------
    dict
        ``"adj_org"``: ``n_obs x n_obs`` scipy COO adjacency of the KNN graph.
        ``"adj_norm"``: the output of ``preprocess_graph(adj_org)``.
        ``"norm_value"``: ``N*N / ((N*N - adj_org.sum()) * 2)`` with ``N = n_obs``.
    """
    # Construct the neighbor graph.
    Spatial_Net = Cal_Spatial_Net(adata, n_neighbors=Params.n_neighbors, verbose=verbose)

    if verbose:
        print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))
        print('%.2f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))

    # Map observation names back to integer row/column positions.
    cells = np.array(adata.obs.index)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))

    G_df = Spatial_Net.copy()
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)

    n_obs = adata.n_obs
    if Params.weighted_adj:
        # Scale distances by (max + 1) so they fall in [0, 1), then apply a
        # Gaussian kernel: w = exp(-d_norm**2 / 2). Series.max() (unlike the
        # builtin max) does not raise when the edge list is empty.
        distance_normalized = G_df['Distance'] / (G_df['Distance'].max() + 1)
        edge_weights = np.exp(-distance_normalized ** 2 / 2)
    else:
        edge_weights = np.ones(G_df.shape[0])
    adj_org = sp.coo_matrix(
        (edge_weights, (G_df['Cell1'], G_df['Cell2'])), shape=(n_obs, n_obs)
    )

    adj_norm = preprocess_graph(adj_org)
    n_sq = adj_org.shape[0] * adj_org.shape[0]
    norm_value = n_sq / float((n_sq - adj_org.sum()) * 2)

    graph_dict = {
        "adj_org": adj_org,
        "adj_norm": adj_norm,
        "norm_value": norm_value,
    }
    return graph_dict
gsMap/GNN_VAE/model.py DELETED
@@ -1,87 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Created on Mon Jul 3 11:42:44 2023
5
-
6
- @author: songliyang
7
- """
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- from torch_geometric.nn import GATConv
13
-
14
-
15
def full_block(in_features, out_features, p_drop):
    """Return a ``Linear -> BatchNorm1d -> ELU -> Dropout`` stack.

    Parameters
    ----------
    in_features, out_features
        Input and output widths of the linear layer (the batch-norm layer is
        sized to ``out_features``).
    p_drop
        Dropout probability of the final layer.
    """
    layers = [
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.ELU(),
        nn.Dropout(p=p_drop),
    ]
    return nn.Sequential(*layers)
20
-
21
-
22
class GNN(nn.Module):
    """One ``GATConv`` layer followed by an activation and dropout.

    Parameters
    ----------
    in_features, out_features
        Passed to ``GATConv`` as its input/output channel sizes.
    dr
        Dropout probability applied (via ``F.dropout``) after the activation.
    act
        Callable applied to the GAT output; ``F.relu`` by default.
    heads
        Number of attention heads handed to ``GATConv``.
    """

    def __init__(self, in_features, out_features, dr=0, act=F.relu, heads=1):
        super().__init__()
        self.conv1 = GATConv(in_features, out_features, heads=heads)
        self.act = act
        self.dr = dr

    def forward(self, x, edge_index):
        """Apply GAT -> activation -> dropout (dropout only active in training mode)."""
        activated = self.act(self.conv1(x, edge_index))
        return F.dropout(activated, p=self.dr, training=self.training)
34
-
35
-
36
-
37
class GNN_VAE_Model(nn.Module):
    """Feature encoder, GAT layers, and decoder/cluster heads.

    ``forward`` runs the input through two ``full_block`` encoder layers, then
    a multi-head GAT layer (``gn1``), then GAT layers producing ``mu`` (``gn2``)
    and, when ``params.var`` is truthy, ``logvar`` (``gn3``). The latent code
    feeds a reconstruction decoder and a softmax cluster head with
    ``num_classes`` outputs.

    Reads from ``params``: var, feat_hidden1, feat_hidden2, gat_hidden1,
    gat_hidden2, nheads, p_drop.
    """

    def __init__(self, input_dim, params, num_classes=1):
        super().__init__()
        self.var = params.var
        self.num_classes = num_classes

        drop = params.p_drop
        hid1 = params.feat_hidden1
        hid2 = params.feat_hidden2
        # gn1 is built with `nheads` heads; the following layers take
        # gat_hidden1 * nheads input features.
        gat_in = params.gat_hidden1 * params.nheads

        # Encoder
        self.encoder = nn.Sequential()
        self.encoder.add_module('encoder_L1', full_block(input_dim, hid1, drop))
        self.encoder.add_module('encoder_L2', full_block(hid1, hid2, drop))

        # GNN (GAT)
        self.gn1 = GNN(hid2, params.gat_hidden1, drop, act=F.relu, heads=params.nheads)
        self.gn2 = GNN(gat_in, params.gat_hidden2, drop, act=lambda x: x)
        self.gn3 = GNN(gat_in, params.gat_hidden2, drop, act=lambda x: x)

        # Decoder
        self.decoder = nn.Sequential()
        self.decoder.add_module('decoder_L1', full_block(params.gat_hidden2, hid2, drop))
        self.decoder.add_module('decoder_L2', full_block(hid2, hid1, drop))
        self.decoder.add_module('decoder_output', nn.Sequential(nn.Linear(hid1, input_dim)))

        # Cluster
        self.cluster = nn.Sequential()
        self.cluster.add_module('cluster_L1', full_block(params.gat_hidden2, hid2, drop))
        self.cluster.add_module('cluster_output', nn.Linear(hid2, self.num_classes))

    def encode(self, x, adj):
        """Return ``(mu, logvar)``; ``logvar`` is None unless ``self.var`` is truthy."""
        hidden = self.gn1(self.encoder(x), adj)
        mu = self.gn2(hidden, adj)
        logvar = self.gn3(hidden, adj) if self.var else None
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar)`` in training mode; otherwise return ``mu``."""
        if not self.training or logvar is None:
            return mu
        # NOTE(review): with logvar meaning log-variance, the standard deviation
        # would be exp(0.5 * logvar); exp(logvar) is kept here as-is — confirm.
        std = torch.exp(logvar)
        return torch.randn_like(std).mul(std).add_(mu)

    def forward(self, x, adj):
        """Return ``(pred_label, x_reconstructed, gnn_z, mu, logvar)``.

        ``pred_label`` is already softmax-normalised over dim 1.
        NOTE(review): if it is passed to a loss that applies its own
        log-softmax (e.g. torch.nn.CrossEntropyLoss), softmax is applied twice.
        """
        mu, logvar = self.encode(x, adj)
        gnn_z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decoder(gnn_z)
        pred_label = F.softmax(self.cluster(gnn_z), dim=1)
        return pred_label, x_reconstructed, gnn_z, mu, logvar
gsMap/GNN_VAE/train.py DELETED
@@ -1,97 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- Created on Tue Jul 4 19:58:58 2023
5
-
6
- @author: songliyang
7
- """
8
- import time
9
-
10
- import torch
11
- from progress.bar import Bar
12
-
13
- from gsMap.GNN_VAE.model import GNN_VAE_Model
14
-
15
-
16
def reconstruction_loss(decoded, x):
    """Mean-squared error between the reconstruction ``decoded`` and the target ``x``.

    Averaged over all elements (the default ``'mean'`` reduction).
    """
    return torch.nn.functional.mse_loss(decoded, x)
20
-
21
-
22
def label_loss(pred_label, true_label):
    """Mean cross-entropy between class scores ``pred_label`` and integer targets ``true_label``.

    Uses PyTorch's cross-entropy, which applies log-softmax to ``pred_label``
    itself; it therefore expects unnormalised scores rather than probabilities.
    """
    return torch.nn.functional.cross_entropy(pred_label, true_label)
26
-
27
-
28
class Model_Train:
    """Train a ``GNN_VAE_Model`` on node features over a fixed graph.

    Parameters
    ----------
    node_X
        Node feature matrix. It is ``.copy()``-ed and wrapped in
        ``torch.FloatTensor``, so presumably a 2-D NumPy array — confirm.
    graph_dict
        Mapping whose ``"adj_norm"`` entry is a torch sparse tensor. It is
        moved to the device, coalesced, and passed to the model as its graph.
    params
        Configuration; reads epochs, feat_cell, gat_lr, gcn_decay, rec_w,
        label_w and convergence_threshold, and is also handed to
        ``GNN_VAE_Model``.
    label
        Optional per-node class labels. When given, a label-prediction loss is
        added to the reconstruction loss, and the number of distinct labels
        sizes the model's cluster head.
    """

    def __init__(self, node_X, graph_dict, params, label=None):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        torch.cuda.empty_cache()

        self.params = params
        self.device = device
        self.epochs = params.epochs
        self.node_X = torch.FloatTensor(node_X.copy()).to(device)
        self.adj_norm = graph_dict["adj_norm"].to(device).coalesce()
        self.label = label
        self.num_classes = 1

        if self.label is not None:
            self.label = torch.tensor(self.label).to(self.device)
            self.num_classes = len(self.label.unique())

        # Set Model
        self.model = GNN_VAE_Model(self.params.feat_cell, self.params, self.num_classes).to(device)
        self.optimizer = torch.optim.Adam(params=list(self.model.parameters()),
                                          lr=self.params.gat_lr, weight_decay=self.params.gcn_decay)

    # Train
    def run_train(self):
        """Optimise the model full-batch for up to ``params.epochs`` epochs.

        Training stops early when the absolute change in loss between two
        consecutive epochs is <= ``params.convergence_threshold``. This check
        only applies from epoch index 200 onward.
        """
        self.model.train()
        prev_loss = float('inf')

        bar = Bar('GAT-AE model train:', max = self.epochs)
        bar.check_tty = False
        for epoch in range(self.epochs):
            start_time = time.time()
            self.model.train()
            self.optimizer.zero_grad()
            pred_label, de_feat, latent_z, mu, logvar = self.model(self.node_X, self.adj_norm)
            loss_rec = reconstruction_loss(de_feat, self.node_X)

            # Check whether annotation was provided
            if self.label is not None:
                # NOTE(review): GNN_VAE_Model.forward returns pred_label after
                # F.softmax, while label_loss uses cross-entropy, which expects
                # unnormalised scores — softmax is applied twice. Confirm intended.
                loss_pre = label_loss(pred_label, self.label)
                loss = (self.params.rec_w * loss_rec) + (self.params.label_w * loss_pre)
            else:
                loss = loss_rec

            loss.backward()
            self.optimizer.step()
            loss_value = loss.item()

            # Update process: estimate the remaining time from this epoch's
            # duration times the number of epochs still to run.
            batch_time = time.time() - start_time
            remaining_epochs = self.epochs - epoch - 1
            bar_str = '{} / {} | Left time: {batch_time:.2f} mins| Loss: {loss:.4f}'
            bar.suffix = bar_str.format(epoch + 1, self.epochs,
                                        batch_time=batch_time * remaining_epochs / 60,
                                        loss=loss_value)
            bar.next()

            # Check convergence
            if abs(loss_value - prev_loss) <= self.params.convergence_threshold and epoch >= 200:
                print('\nConvergence reached. Training stopped.')
                break

            prev_loss = loss_value

        bar.finish()

    def get_latent(self):
        """Return the latent embedding (``gnn_z``) for all nodes as a NumPy array.

        Switches the model to eval mode (in GNN_VAE_Model this makes
        ``reparameterize`` return ``mu``). It then runs one forward pass without
        building an autograd graph, to save memory.
        """
        self.model.eval()
        with torch.no_grad():
            _, _, latent_z, _, _ = self.model(self.node_X, self.adj_norm)
        return latent_z.detach().cpu().numpy()
@@ -1,31 +0,0 @@
1
- gsMap/__init__.py,sha256=eQ-mfdcGTJtKS2KIu5PEQMqgx_9j9W5KKTBVr-iI4yo,78
2
- gsMap/__main__.py,sha256=jR-HT42Zzfj2f-7kFJy0bkWjNxcV1MyfQHXFpef2nSE,62
3
- gsMap/cauchy_combination_test.py,sha256=zBPR7DOaNkr7rRoua4tAjRZL7ArjCyMRSQlPSUdHNSE,5694
4
- gsMap/config.py,sha256=hMUvlwlKZXeRdTJZfMINz_8DadVhEIT6X6fyJf11M9E,41134
5
- gsMap/diagnosis.py,sha256=pp3ONVaWCOoNCog1_6eud38yicBFxL-XhH7D8iTBgF4,13220
6
- gsMap/find_latent_representation.py,sha256=BVv4dyTolrlciHG3I-vwNDh2ruPpTf9jiT1hMKZnpto,6044
7
- gsMap/format_sumstats.py,sha256=9OBxuunoOLml3LKZvvRsPEEjQvT1Cuqb0w6lqsRIYPw,13714
8
- gsMap/generate_ldscore.py,sha256=2JfQoMWeQ0-B-zRHakmwq8ovkeewlnWHUCnih6od6ZE,29089
9
- gsMap/latent_to_gene.py,sha256=MwoGQd0EDvDmvpuMoVD83SI1EeGJXXzMW8YZp_6wxI8,10082
10
- gsMap/main.py,sha256=skyBtESdjvuXd9HNq5c83OPxQTNgLVErkYhwuJm8tE4,1285
11
- gsMap/report.py,sha256=H0uYAru2L5-d41_LFHPPdoJbtiTzP4f8kX-mirUNAfc,6963
12
- gsMap/run_all_mode.py,sha256=sPEct9fRw7aAQuU7BNChxk-I8YQcXuq--mtBn-2wTTY,8388
13
- gsMap/setup.py,sha256=eOoPalGAHTY06_7a35nvbOaKQhq0SBE5GK4pc4bp3wc,102
14
- gsMap/spatial_ldsc_multiple_sumstats.py,sha256=09j2zG98JUjQvSHPaRIDQVMZZLYDd1JBFaZTkW7tdvY,18470
15
- gsMap/visualize.py,sha256=FLIRHHj1KntLMFDjAhg1jZnJUdvrncR74pCW2Kj5pus,7453
16
- gsMap/GNN_VAE/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- gsMap/GNN_VAE/adjacency_matrix.py,sha256=pmSfK9TTwdrsWTmHvCqVrbRE0PAiq1lvgmxzrdQgpiU,3500
18
- gsMap/GNN_VAE/model.py,sha256=Fixl8-2zN3T5MmMtpMIvxsBYydM3QVR4uC3Hhsg0DzI,3349
19
- gsMap/GNN_VAE/train.py,sha256=KnvZYHImRzTwJl1H0dkZZqWASdZ5VgYTCifQVW8TavM,3389
20
- gsMap/templates/report_template.html,sha256=pdxHFl_W0W351NUzuJTf_Ay_BfKlEbD_fztNabAGmmg,8214
21
- gsMap/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- gsMap/utils/generate_r2_matrix.py,sha256=A1BrUnlTrYjRwEKxK0I1FbZ5SCQzcviWVM-JzFHHRkw,29352
23
- gsMap/utils/jackknife.py,sha256=nEDPVQJOPQ_uqfUCGX_v5cQwokgCqUmJTT_8rVFuIQo,18244
24
- gsMap/utils/make_annotations.py,sha256=lCbtahT27WFOwLgZrEUE5QcNRuMXmAFYUfsFR-cT-m0,22197
25
- gsMap/utils/manhattan_plot.py,sha256=k3n-NNgMsov9-8UQrirVqG560FUfJ4d6wNG8C0OeCjY,26235
26
- gsMap/utils/regression_read.py,sha256=n_hZZzQXHU-CSLvSofXmQM5Jw4Zpufv3U2HoUW344ko,8768
27
- gsmap-1.66.dist-info/entry_points.txt,sha256=s_P2Za22O077tc1FPLKMinbdRVXaN_HTcDBgWMYpqA4,41
28
- gsmap-1.66.dist-info/LICENSE,sha256=Ni2F-lLSv_H1xaVT3CoSrkiKzMvlgwh-dq8PE1esGyI,1094
29
- gsmap-1.66.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
30
- gsmap-1.66.dist-info/METADATA,sha256=HXeRNmaP_UPzG2Qjn5s-jcLBvrfLgPYl7qVGDAKJG5Y,3376
31
- gsmap-1.66.dist-info/RECORD,,
File without changes