RGAST 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- RGAST-0.0.1/LICENSE +21 -0
- RGAST-0.0.1/PKG-INFO +49 -0
- RGAST-0.0.1/README.md +34 -0
- RGAST-0.0.1/RGAST/RGAST.py +26 -0
- RGAST-0.0.1/RGAST/Train_RGAST.py +387 -0
- RGAST-0.0.1/RGAST/__init__.py +13 -0
- RGAST-0.0.1/RGAST/svg.py +252 -0
- RGAST-0.0.1/RGAST/utils.py +315 -0
- RGAST-0.0.1/RGAST.egg-info/PKG-INFO +49 -0
- RGAST-0.0.1/RGAST.egg-info/SOURCES.txt +13 -0
- RGAST-0.0.1/RGAST.egg-info/dependency_links.txt +1 -0
- RGAST-0.0.1/RGAST.egg-info/requires.txt +6 -0
- RGAST-0.0.1/RGAST.egg-info/top_level.txt +1 -0
- RGAST-0.0.1/setup.cfg +4 -0
- RGAST-0.0.1/setup.py +34 -0
RGAST-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023 Yuqiao Gong
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
RGAST-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: RGAST
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Relational Graph Attention Network for Spatial Transcriptome Analysis
|
|
5
|
+
Home-page: https://github.com/GYQ-form/RGAST
|
|
6
|
+
Author: Yuqiao Gong
|
|
7
|
+
Author-email: gyq123@sjtu.edu.cn
|
|
8
|
+
License: MIT
|
|
9
|
+
Keywords: spatial transcriptomic,RGAT,representation learning,spatial domain identification
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
|
|
16
|
+
# RGAST
|
|
17
|
+
|
|
18
|
+
RGAST: Relational Graph Attention Network for Spatial Transcriptome Analysis
|
|
19
|
+
|
|
20
|
+
This document will help you easily go through the RGAST model.
|
|
21
|
+
|
|
22
|
+

|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
## Installation
|
|
26
|
+
|
|
27
|
+
To install our package, run
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install RGAST
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
## Usage
|
|
36
|
+
|
|
37
|
+
RGAST (Relational Graph Attention network for Spatial Transcriptome analysis) constructs a relational graph attention network to learn the representation of each spot in the spatial transcriptome data. In addition to the attention mechanism, RGAST considers both gene expression similarity and spatial neighbor relationships when constructing the graph network, enabling a more comprehensive and flexible representation of the spatial transcriptome data. RGAST can be used in many ST analyses:
|
|
38
|
+
|
|
39
|
+
- spatial domain identification
|
|
40
|
+
- cell trajectory inference
|
|
41
|
+
- spatially variable gene (SVG) detection
|
|
42
|
+
- uncover spatially resolved cell-cell interactions
|
|
43
|
+
- reveal intricate 3D spatial patterns across multiple sections of ST data
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
## Tutorial
|
|
48
|
+
|
|
49
|
+
We have prepared several basic tutorials at https://github.com/GYQ-form/RGAST/tree/main/tutorial. You can quickly get hands-on with RGAST by going through these tutorials. Model parameters trained in our study are also released at https://github.com/GYQ-form/RGAST/tree/main/model_path.
|
RGAST-0.0.1/README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# RGAST
|
|
2
|
+
|
|
3
|
+
RGAST: Relational Graph Attention Network for Spatial Transcriptome Analysis
|
|
4
|
+
|
|
5
|
+
This document will help you easily go through the RGAST model.
|
|
6
|
+
|
|
7
|
+

|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
## Installation
|
|
11
|
+
|
|
12
|
+
To install our package, run
|
|
13
|
+
|
|
14
|
+
```bash
|
|
15
|
+
pip install RGAST
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
## Usage
|
|
21
|
+
|
|
22
|
+
RGAST (Relational Graph Attention network for Spatial Transcriptome analysis) constructs a relational graph attention network to learn the representation of each spot in the spatial transcriptome data. In addition to the attention mechanism, RGAST considers both gene expression similarity and spatial neighbor relationships when constructing the graph network, enabling a more comprehensive and flexible representation of the spatial transcriptome data. RGAST can be used in many ST analyses:
|
|
23
|
+
|
|
24
|
+
- spatial domain identification
|
|
25
|
+
- cell trajectory inference
|
|
26
|
+
- spatially variable gene (SVG) detection
|
|
27
|
+
- uncover spatially resolved cell-cell interactions
|
|
28
|
+
- reveal intricate 3D spatial patterns across multiple sections of ST data
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
## Tutorial
|
|
33
|
+
|
|
34
|
+
We have prepared several basic tutorials at https://github.com/GYQ-form/RGAST/tree/main/tutorial. You can quickly get hands-on with RGAST by going through these tutorials. Model parameters trained in our study are also released at https://github.com/GYQ-form/RGAST/tree/main/model_path.
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import torch.backends.cudnn as cudnn
|
|
5
|
+
cudnn.deterministic = True
|
|
6
|
+
cudnn.benchmark = True
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from torch_geometric.nn.conv.rgat_conv import RGATConv
|
|
9
|
+
|
|
10
|
+
class RGAST(torch.nn.Module):
    """Two-layer relational graph attention encoder with a linear decoder.

    Parameters
    ----------
    hidden_dims
        Sequence ``[in_dim, num_hidden, out_dim]``: input feature size, width
        of the first graph layer and size of the latent embedding.
    """

    def __init__(self, hidden_dims):
        super().__init__()
        in_dim, num_hidden, out_dim = hidden_dims

        # Both graph layers share every setting except their in/out widths.
        # Two relation types are used; presumably the spatial and the
        # expression networks -- confirm against the data builder.
        conv_options = dict(num_relations=2, heads=1, concat=False,
                            dropout=0.3, add_self_loops=False, bias=False)

        # Creation order (conv1, conv2, decoder) is kept so that seeded
        # parameter initialisation is unchanged.
        self.conv1 = RGATConv(in_dim, num_hidden, **conv_options)
        self.conv2 = RGATConv(num_hidden, out_dim, **conv_options)
        # Kept inside nn.Sequential so state-dict keys stay 'decoder.0.*'.
        self.decoder = nn.Sequential(
            nn.Linear(out_dim, in_dim),
        )

    def forward(self, features, edge_index, edge_type):
        """Return ``(latent, reconstruction)`` for the given graph.

        ``latent`` is the ELU-activated output of the second graph layer and
        ``reconstruction`` is the decoder applied to it.
        """
        hidden = F.elu(self.conv1(features, edge_index, edge_type))
        latent = F.elu(self.conv2(hidden, edge_index, edge_type))
        reconstruction = self.decoder(latent)
        return latent, reconstruction
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import os
|
|
3
|
+
import scanpy as sc
|
|
4
|
+
import anndata
|
|
5
|
+
from sklearn.metrics.cluster import adjusted_rand_score
|
|
6
|
+
from sklearn.metrics import silhouette_score
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from .RGAST import RGAST
|
|
10
|
+
from .utils import Transfer_pytorch_Data, res_search_fixed_clus, Batch_Data, Cal_Spatial_Net, Cal_Expression_Net
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.backends.cudnn as cudnn
|
|
14
|
+
cudnn.deterministic = True
|
|
15
|
+
cudnn.benchmark = True
|
|
16
|
+
import torch.nn.functional as F
|
|
17
|
+
from torch_geometric.loader import DataLoader
|
|
18
|
+
|
|
19
|
+
def target_distribution(batch):
    """Sharpen a soft cluster-assignment matrix into a target distribution.

    Each entry is squared and divided by its column sum, then every row is
    rescaled so that it sums to one.

    Parameters
    ----------
    batch
        2-D tensor of soft assignments (rows: samples, columns: clusters).

    Returns
    -------
    Tensor of the same shape whose rows sum to one.
    """
    sharpened = batch.pow(2) / batch.sum(dim=0)
    row_totals = sharpened.sum(dim=1, keepdim=True)
    return sharpened / row_totals
|
|
22
|
+
|
|
23
|
+
class Train_RGAST:
    """Trainer that fits an RGAST auto-encoder to an AnnData object.

    The trainer keeps a reference to the AnnData passed to the constructor and
    writes its results back into it (``obs['leiden']``, ``obsm[key_added]`` and,
    on request, ``uns['RGAST_loss']`` / ``layers['RGAST_ReX']``).
    """

    # Early-stopping settings shared by train_RGAST and train_with_dec.
    _EVAL_INTERVAL = 50  # train_RGAST scores the embedding every 50 epochs
    _PATIENCE = 3        # stop after more than 3 consecutive non-improving scores...
    _MIN_EPOCHS = 300    # ...but never before this epoch

    def __init__(self, adata, batch_data = False, num_batch_x_y = None, spatial_net_arg = None, exp_net_arg = None, verbose=True):

        """\
        Initialization of a RGAST trainer.

        Parameters
        ----------
        adata
            AnnData object of scanpy package. Must already contain
            ``obsm['X_pca']``, ``uns['Spatial_Net']`` and ``uns['Exp_Net']``.
        batch_data
            If True, the slice is split into spatial tiles and each tile is
            trained as one batch.
        num_batch_x_y
            A tuple specifying the number of points at which to segment the spatially transcribed image on the x and y axes.
            Each split is then trained as a batch. This is useful for large scale cases.
            Required when ``batch_data`` is True.
        spatial_net_arg
            A dict passing key-word arguments to calculating spatial network in each batch data. See `Cal_Spatial_Net`.
        exp_net_arg
            A dict passing key-word arguments to calculating expression network in each batch data. See `Cal_Expression_Net`
        verbose
            If True, print the shape of the PCA input.

        Raises
        ------
        ValueError
            If a required entry is missing from ``adata`` or if ``batch_data``
            is True without ``num_batch_x_y``.
        """

        # None replaces the former shared mutable ``{}`` defaults.
        spatial_net_arg = {} if spatial_net_arg is None else spatial_net_arg
        exp_net_arg = {} if exp_net_arg is None else exp_net_arg

        if 'X_pca' not in adata.obsm.keys():
            raise ValueError("PCA has not been done! Run sc.pp.pca first!")
        if verbose:
            print('Size of Input: ', adata.obsm['X_pca'].shape)

        self.batch_data = batch_data
        self.adata = adata
        if 'Spatial_Net' not in adata.uns.keys():
            raise ValueError("Spatial_Net is not existed! Run Cal_Spatial_Net first!")
        if 'Exp_Net' not in adata.uns.keys():
            raise ValueError("Exp_Net is not existed! Run Cal_Expression_Net first!")
        self.data = Transfer_pytorch_Data(adata)

        if batch_data:
            if num_batch_x_y is None:
                raise ValueError("num_batch_x_y must be given when batch_data is True!")
            self.num_batch_x, self.num_batch_y = num_batch_x_y
            # Batch_Data tiles on obs columns, so expose the coordinates there.
            adata.obs['X'] = adata.obsm['spatial'][:,0]
            adata.obs['Y'] = adata.obsm['spatial'][:,1]
            Batch_list = Batch_Data(adata, num_batch_x=self.num_batch_x, num_batch_y=self.num_batch_y,
                                    spatial_key=['X', 'Y'])
            # Each tile gets its own spatial and expression networks.
            for temp_adata in Batch_list:
                Cal_Spatial_Net(temp_adata, **spatial_net_arg)
                Cal_Expression_Net(temp_adata, **exp_net_arg)
            data_list = [Transfer_pytorch_Data(tile) for tile in Batch_list]
            self.loader = DataLoader(data_list, batch_size=1, shuffle=True)

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = None

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    def _encode(self, model, data, on_cpu=False):
        """Run ``model`` in eval mode without gradients; return ``(z, out)``.

        With ``on_cpu`` the model and inputs are moved to the CPU for the
        forward pass and the model is moved back to ``self.device`` afterwards.
        """
        model.eval()
        with torch.no_grad():
            if on_cpu:
                model.to('cpu')
                z, out = model(data.x.cpu(), data.edge_index.cpu(), data.edge_type.cpu())
                model.to(self.device)
            else:
                z, out = model(data.x, data.edge_index, data.edge_type)
        return z, out

    def _score_embedding(self, z, n_clusters, label_key):
        """Cluster embedding ``z`` and score the clustering.

        Returns ``(metric_name, score, labels)``: ARI against
        ``adata.obs[label_key]`` when a label key is given, otherwise the
        silhouette score of the clusters on ``z``.
        """
        adata_RGAST = anndata.AnnData(z)
        adata_RGAST.obs_names = self.adata.obs_names
        sc.pp.neighbors(adata_RGAST)
        sc.tl.umap(adata_RGAST)
        # Expected to fill adata_RGAST.obs['leiden'], which is read below.
        _ = res_search_fixed_clus(adata_RGAST, n_clusters)
        labels = adata_RGAST.obs['leiden']
        if label_key is not None:
            # Spots without a reference label are excluded from the ARI.
            obs_df = adata_RGAST.obs.join(self.adata.obs[label_key]).dropna(subset=[label_key])
            return 'ARI', adjusted_rand_score(obs_df['leiden'], obs_df[label_key]), labels
        return 'SC', silhouette_score(z, labels), labels

    @staticmethod
    def _snapshot(model):
        """Return a detached copy of the model's state dict."""
        return {name: value.detach().clone() for name, value in model.state_dict().items()}

    @staticmethod
    def _soft_assignment(z, cluster_layer):
        """Student-t soft assignment of every row of ``z`` to the centroids."""
        q = 1.0 / (1.0 + torch.sum(torch.pow(z.unsqueeze(1) - cluster_layer, 2), 2))
        return (q.t() / torch.sum(q, 1)).t()

    @staticmethod
    def _reconstruction_step(model, optimizer, graph, gradient_clipping):
        """One optimisation step on the reconstruction MSE; returns the detached loss."""
        model.train()
        optimizer.zero_grad()
        _, out = model(graph.x, graph.edge_index, graph.edge_type)
        loss = F.mse_loss(graph.x, out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
        # Detach so the autograd graph of this step can be freed.
        return loss.detach()

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------

    def train_RGAST(self, early_stopping = True, label_key = None, save_path = '.', n_clusters = 7,
                    hidden_dims=[100, 32], n_epochs=1000, lr=0.001, key_added='RGAST',
                    gradient_clipping=5., weight_decay=0.0001, verbose=True,
                    random_seed=0, save_loss=False, save_reconstrction=False):

        """\
        Training graph attention auto-encoder.

        Parameters
        ----------
        early_stopping
            Using early stopping strategy or not. Default = True.
        label_key
            A key specify the specific column in adata.obs to be treated as reference label.
            If given, early stopping uses ARI against it, otherwise the silhouette score.
        save_path
            directory to save the trained RGAST model and the embedding.
        n_clusters
            number of clusters to set when calculating early stopping criterion.
        hidden_dims
            The dimension of the encoder.
        n_epochs
            Number of total epochs in training.
        lr
            Learning rate for AdamOptimizer.
        key_added
            The latent embeddings are saved in adata.obsm[key_added].
        gradient_clipping
            Gradient Clipping.
        weight_decay
            Weight decay for AdamOptimizer.
        verbose
            If True, print the early-stopping score at every evaluation.
        random_seed
            Seed for python, numpy and torch random generators.
        save_loss
            If True, the final training loss is saved in adata.uns['RGAST_loss'] as a float.
        save_reconstrction
            If True, the reconstructed expression profiles are saved in adata.layers['RGAST_ReX'].

        Returns
        -------
        None. Results are written to the trainer's AnnData, ``self.model``
        and ``save_path``.
        """
        self.save_path = save_path
        self.label_key = label_key
        self.n_clusters = n_clusters
        os.makedirs(save_path, exist_ok=True)

        # seed_everything()
        seed=random_seed
        import random
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)

        if self.model is None:
            model = RGAST(hidden_dims = [self.data.x.shape[1]] + list(hidden_dims)).to(self.device)
        else:
            # Continue training a previously trained / loaded model.
            model = self.model.to(self.device)

        data = self.data.to(self.device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        best_score = 0
        best_state = None  # weights of the best checkpoint written in THIS run
        num_fail = 0
        last_loss = None
        for epoch in tqdm(range(1, n_epochs+1)):

            if early_stopping and epoch % self._EVAL_INTERVAL == 0:
                # In batch mode the full graph is evaluated on the CPU.
                z, _ = self._encode(model, data, on_cpu=self.batch_data)
                z = z.to('cpu').detach().numpy()
                metric, score, labels = self._score_embedding(z, n_clusters, label_key)
                if verbose:
                    print(f'epoch:{epoch},{metric}:{score}')
                if score <= best_score:
                    num_fail += 1
                    if num_fail>self._PATIENCE and epoch>=self._MIN_EPOCHS:
                        break
                else:
                    num_fail = 0
                    best_score = score
                    torch.save(model,f'{save_path}/model.pth')
                    best_state = self._snapshot(model)
                    self.adata.obs['leiden'] = labels

            if self.batch_data:
                for batch in self.loader:
                    batch = batch.to(self.device)
                    last_loss = self._reconstruction_step(model, optimizer, batch, gradient_clipping)
            else:
                last_loss = self._reconstruction_step(model, optimizer, data, gradient_clipping)

        # Restore the best weights of this run. Only checkpoints written above
        # are used, so a stale model.pth left by an earlier run can no longer
        # replace the model that was just trained.
        if best_state is not None:
            model.load_state_dict(best_state)

        z, out = self._encode(model, data, on_cpu=self.batch_data)

        RGAST_rep = z.to('cpu').detach().numpy()
        np.save(f'{save_path}/RGAST_embedding.npy', RGAST_rep)
        self.adata.obsm[key_added] = RGAST_rep

        if save_loss and last_loss is not None:
            # A plain float, not a live tensor, so the AnnData stays serialisable.
            self.adata.uns['RGAST_loss'] = float(last_loss)
        if save_reconstrction:
            # NOTE(review): the reconstruction has the width of the model
            # input; layers need n_vars columns -- confirm they agree.
            ReX = out.to('cpu').detach().numpy()
            self.adata.layers['RGAST_ReX'] = ReX

        self.model = model


    def train_with_dec(self, verbose = True, early_stopping = True, key_added='RGAST', num_epochs=1000, dec_interval=50, dec_tol=0.01):

        """\
        Training graph attention auto-encoder with deep embedding clustering.
        Only call this after call Train_RGAST.train_RGAST() and make sure batch_data = False.

        Parameters
        ----------
        verbose
            If True, print the early-stopping score at every evaluation.
        early_stopping
            Using early stopping strategy or not. Default = True.
        key_added
            The latent embeddings are saved in adata.obsm[key_added].
        num_epochs
            Number of total epochs in training.
        dec_interval
            Evaluate after how many epochs (for early stopping).
        dec_tol
            DEC tol. Training stops once the fraction of spots whose DEC
            assignment changed between two epochs drops below this value.

        Returns
        -------
        None. The trainer's AnnData gets an updated .obsm[key_added].
        """
        if self.model is None:
            raise ValueError("No trained model found! Run train_RGAST first!")

        # initialize cluster parameter
        model = self.model.to(self.device)
        model.eval()
        # The starting embedding is read from 'RGAST' as before; key_added is
        # only a fallback for runs that stored it under a custom key.
        source_key = 'RGAST' if 'RGAST' in self.adata.obsm.keys() else key_added
        test_z = self.adata.obsm[source_key]
        y_pred_last = np.array(self.adata.obs['leiden'],dtype=np.int32).copy()
        counts = len(np.bincount(y_pred_last))
        # One centroid per cluster id: the mean embedding of its spots.
        centroids = np.stack([np.mean(test_z[y_pred_last==i,],axis=0) for i in range(counts)])
        # The centroids are not handed to the optimizer, so they stay fixed.
        cluster_layer = torch.tensor(centroids).to(self.device)
        data = self.data.to(self.device)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

        best_score = 0
        best_state = None  # weights of the best checkpoint written in THIS run
        num_fail = 0
        for epoch_id in tqdm(range(num_epochs)):

            if early_stopping and epoch_id % dec_interval == 0:
                #early stopping
                z, _ = self._encode(model, data)
                z = z.to('cpu').detach().numpy()
                metric, score, labels = self._score_embedding(z, self.n_clusters, self.label_key)
                if verbose:
                    print(f'epoch:{epoch_id},{metric}:{score}')
                if score <= best_score:
                    num_fail += 1
                    if num_fail>self._PATIENCE and epoch_id>=self._MIN_EPOCHS:
                        break
                else:
                    num_fail = 0
                    best_score = score
                    torch.save(model,f'{self.save_path}/model.pth')
                    best_state = self._snapshot(model)
                    self.adata.obs['leiden'] = labels

            #DEC update
            # NOTE(review): this file was recovered from a flattened diff; the
            # update is placed at loop level (every epoch), matching the
            # dec_interval description above -- confirm against upstream.
            # The target is computed in eval mode without gradients so that
            # dropout does not perturb it.
            z, _ = self._encode(model, data)
            # q stays a tensor on its device; wrapping a CUDA tensor in
            # torch.Tensor(...) raises in the legacy constructor.
            tmp_p = target_distribution(self._soft_assignment(z, cluster_layer))
            y_pred = tmp_p.cpu().detach().numpy().argmax(1)
            delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0]
            y_pred_last = np.copy(y_pred)
            if epoch_id > 0 and delta_label < dec_tol:
                print('delta_label {:.4}'.format(delta_label), '< tol', dec_tol)
                print('Reached tolerance threshold. Stopping training.')
                break


            # training model
            model.train()
            optimizer.zero_grad()
            z, reconst = model(data.x, data.edge_index, data.edge_type)
            q = self._soft_assignment(z, cluster_layer)
            loss_rec = F.mse_loss(data.x, reconst)
            # clustering KL loss (default reduction kept to preserve its scale)
            loss_kl = F.kl_div(q.log(), tmp_p)
            loss = loss_kl + loss_rec
            loss.backward()
            optimizer.step()

        # Restore the best weights seen in this run. If no checkpoint was
        # written, keep the current model instead of loading whatever
        # model.pth happens to be on disk.
        if best_state is not None:
            model.load_state_dict(best_state)
        z, _ = self._encode(model, data)

        RGAST_rep = z.to('cpu').detach().numpy()
        np.save(f'{self.save_path}/RGAST_embedding.npy', RGAST_rep)
        self.adata.obsm[key_added] = RGAST_rep
        self.model = model

    def load_model(self, path):
        """Load a whole pickled model (as written by ``save_model``) from ``path``.

        NOTE: the file is unpickled, which can execute arbitrary code; only
        load files from a trusted source.
        """
        try:
            # torch >= 2.6 defaults to weights_only=True, which rejects whole
            # pickled modules; map_location allows loading GPU checkpoints on
            # a CPU-only machine.
            self.model = torch.load(path, map_location=self.device, weights_only=False)
        except TypeError:
            # Older torch versions do not know the weights_only argument.
            self.model = torch.load(path, map_location=self.device)

    def save_model(self, path):
        """Pickle the current model to ``<path>/model.pth``."""
        torch.save(self.model,f'{path}/model.pth')

    def process(self, gdata = None):
        """Run the model in eval mode on ``gdata`` (default: the trainer's data).

        Returns the model output ``(embedding, reconstruction)`` as tensors on
        ``self.device``.
        """
        if gdata is None:
            gdata = self.data
        self.model.to(self.device)
        self.model.eval()
        gdata = gdata.to(self.device)
        return self.model(gdata.x, gdata.edge_index, gdata.edge_type)
|
|
387
|
+
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
"""
|
|
3
|
+
# Author: Yuqiao Gong
|
|
4
|
+
# File Name: __init__.py
|
|
5
|
+
# Description:
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
__author__ = "Yuqiao Gong"
|
|
9
|
+
__email__ = "gyq123@sjtu.edu.cn"
|
|
10
|
+
|
|
11
|
+
from .RGAST import RGAST
|
|
12
|
+
from .Train_RGAST import Train_RGAST
|
|
13
|
+
from .utils import Transfer_pytorch_Data, Cal_Spatial_Net, Cal_Expression_Net, Stats_Spatial_Net, Cal_Spatial_Net_3D, Batch_Data, plot_clustering, refine_spatial_cluster, Cal_Expression_3D
|
RGAST-0.0.1/RGAST/svg.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# code is modified from https://github.com/jianhuupenn/SpaGCN
|
|
2
|
+
import scanpy as sc
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import numpy as np
|
|
5
|
+
import scipy
|
|
6
|
+
from scipy.sparse import issparse
|
|
7
|
+
import numba
|
|
8
|
+
from sklearn.neighbors import NearestNeighbors
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@numba.njit("f4(f4[:], f4[:])")
def euclid_dist(t1,t2):
    """Return the Euclidean distance between two float32 vectors.

    The loop runs over the length of ``t1``; ``t2`` must be at least as long.
    """
    # 'total' instead of 'sum' so the builtin is not shadowed.
    total=0.0
    for i in range(t1.shape[0]):
        total+=(t1[i]-t2[i])**2
    return np.sqrt(total)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@numba.njit("f4[:,:](f4[:,:])", parallel=True, nogil=True)
def pairwise_distance(X):
    """Return the full n x n float32 matrix of Euclidean distances between rows of ``X``."""
    num_rows=X.shape[0]
    dist=np.empty((num_rows, num_rows), dtype=np.float32)
    # Only the outer loop is parallelised; numba runs a nested prange as a
    # plain range, so the inner loop is written as one.
    for row in numba.prange(num_rows):
        for col in range(num_rows):
            dist[row, col]=euclid_dist(X[row], X[col])
    return dist
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def calculate_adj_matrix(x, y):
    """Return the pairwise Euclidean distance matrix of the points ``(x[i], y[i])``.

    The coordinates are stacked into an (n, 2) float32 array and passed to
    ``pairwise_distance``.
    """
    coords = np.array([x, y])
    points = coords.T.astype(np.float32)
    return pairwise_distance(points)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def count_nbr(target_cluster,cell_id, x, y, pred, adj_2d, radius):
    """Average number of spots within ``radius`` of each spot of a cluster.

    Parameters
    ----------
    target_cluster
        Label of the cluster whose spots are used as centres.
    cell_id, x, y, pred
        Per-spot identifiers, coordinates and cluster labels (same length).
    adj_2d
        Unused; kept for interface compatibility.
    radius
        Neighbourhood radius. Spots exactly at ``radius`` are included and
        every centre counts itself.

    Returns
    -------
    Mean neighbour count over the target cluster's spots (``nan`` if the
    cluster has no spots).
    """
    spots = pd.DataFrame(data={'cell_id': cell_id, 'x': x, "y":y, "pred":pred})
    spots.index=spots['cell_id']
    targets=spots[spots["pred"]==target_cluster]

    all_x=spots["x"].to_numpy()
    all_y=spots["y"].to_numpy()
    radius_sq=radius**2
    # One vectorised distance test per centre instead of DataFrame.iterrows.
    num_nbr=[
        int(np.count_nonzero(((all_x-cx)**2+(all_y-cy)**2)<=radius_sq))
        for cx, cy in zip(targets["x"].to_numpy(), targets["y"].to_numpy())
    ]
    return np.mean(num_nbr)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def search_radius(target_cluster,cell_id, x, y, pred, adj_2d, start, end, num_min=8, num_max=15, max_run=100):
    """Bisect ``[start, end]`` for a radius giving a suitable neighbour count.

    A radius is accepted when ``count_nbr`` returns an average between
    ``num_min`` and ``num_max`` (inclusive) for ``target_cluster``.

    Returns
    -------
    The accepted radius; the last midpoint tried if ``max_run`` iterations
    were exhausted; or None if the interval cannot contain a solution (a
    hint is printed) or no midpoint was ever evaluated.
    """
    run=0
    # Stays None until a midpoint is evaluated, so that giving up before the
    # first bisection step returns None instead of raising.
    mid=None
    num_low=count_nbr(target_cluster,cell_id, x, y, pred, adj_2d, start)
    num_high=count_nbr(target_cluster,cell_id, x, y, pred, adj_2d, end)
    if num_min<=num_low<=num_max:
        print("recommended radius = ", str(start))
        return start
    elif num_min<=num_high<=num_max:
        print("recommended radius = ", str(end))
        return end
    elif num_low>num_max:
        print("Try smaller start.")
        return None
    elif num_high<num_min:
        print("Try bigger end.")
        return None
    while (num_low<num_min) and (num_high>num_min):
        run+=1
        print("Run "+str(run)+": radius ["+str(start)+", "+str(end)+"], num_nbr ["+str(num_low)+", "+str(num_high)+"]")
        if run >max_run:
            print("Exact radius not found, closest values are:\n"+"radius="+str(start)+": "+"num_nbr="+str(num_low)+"\nradius="+str(end)+": "+"num_nbr="+str(num_high))
            return mid
        mid=(start+end)/2
        num_mid=count_nbr(target_cluster,cell_id, x, y, pred, adj_2d, mid)
        if num_min<=num_mid<=num_max:
            print("recommended radius = ", str(mid), "num_nbr="+str(num_mid))
            return mid
        if num_mid<num_min:
            start=mid
            num_low=num_mid
        elif num_mid>num_max:
            end=mid
            num_high=num_mid
    # Reached only when the counts are not comparable (e.g. nan for an empty cluster).
    return None
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def rank_genes_groups(input_adata, target_cluster,nbr_list, label_col, adj_nbr=True, log=False):
    """Differential expression of ``target_cluster`` against the other spots.

    Parameters
    ----------
    input_adata
        AnnData with cluster labels in ``obs[label_col]``; it is not modified.
    target_cluster
        Label of the cluster of interest.
    nbr_list
        Labels of neighbouring clusters; used only when ``adj_nbr`` is True.
    label_col
        Column of ``obs`` holding the cluster labels.
    adj_nbr
        If True, compare only against spots in ``nbr_list``; otherwise
        against all other spots.
    log
        Set to True if the expression values are already log-transformed;
        the fold change is then the exponential of the mean difference.

    Returns
    -------
    DataFrame with one row per gene: in/out-group expressing fraction and
    their ratio, in/out-group mean expression, fold change and adjusted
    p-value.
    """
    if adj_nbr:
        nbr_list=nbr_list+[target_cluster]
        # Explicit copy: the subset is modified below and must not be a view.
        adata=input_adata[input_adata.obs[label_col].isin(nbr_list)].copy()
    else:
        adata=input_adata.copy()
    adata.var_names_make_unique()
    # 1 = target cluster, 0 = everything else.
    adata.obs["target"]=((adata.obs[label_col]==target_cluster)*1).astype('category')
    sc.tl.rank_genes_groups(adata, use_raw=False, groupby="target",reference="rest", n_genes=adata.shape[1],method='wilcoxon')
    # Each result record has one field per group, and every group is ranked
    # in its own gene order. Genes and p-values must therefore come from the
    # same group (index 1, the target); taking the p-values from group 0
    # paired them with the wrong genes.
    pvals_adj=[rec[1] for rec in adata.uns['rank_genes_groups']["pvals_adj"]]
    genes=[rec[1] for rec in adata.uns['rank_genes_groups']["names"]]
    if issparse(adata.X):
        # .toarray() instead of .A, which newer SciPy releases removed.
        obs_tidy=pd.DataFrame(adata.X.toarray())
    else:
        obs_tidy=pd.DataFrame(adata.X)
    obs_tidy.index=adata.obs["target"].tolist()
    obs_tidy.columns=adata.var.index.tolist()
    obs_tidy=obs_tidy.loc[:,genes]
    # 1. compute mean value
    mean_obs = obs_tidy.groupby(level=0).mean()
    # 2. compute fraction of cells having value >0
    obs_bool = obs_tidy.astype(bool)
    fraction_obs = obs_bool.groupby(level=0).sum() / obs_bool.groupby(level=0).count()
    # compute fold change.
    if log: #The adata already logged
        fold_change=np.exp((mean_obs.loc[1] - mean_obs.loc[0]).values)
    else:
        # Small constant avoids division by zero.
        fold_change = (mean_obs.loc[1] / (mean_obs.loc[0]+ 1e-9)).values
    df = {'genes': genes, 'in_group_fraction': fraction_obs.loc[1].tolist(), "out_group_fraction":fraction_obs.loc[0].tolist(),"in_out_group_ratio":(fraction_obs.loc[1]/fraction_obs.loc[0]).tolist(),"in_group_mean_exp": mean_obs.loc[1].tolist(), "out_group_mean_exp": mean_obs.loc[0].tolist(),"fold_change":fold_change.tolist(), "pvals_adj":pvals_adj}
    df = pd.DataFrame(data=df)
    return df
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def find_neighbor_clusters(target_cluster,cell_id, x, y, pred,radius, ratio=1/2):
    """Find the clusters that neighbour ``target_cluster`` in space.

    For every spot of the target cluster, all spots within ``radius`` are
    collected and counted per cluster (a spot is counted once for each target
    spot it is close to). A cluster is reported when its count exceeds
    ``ratio`` times its own size.

    Returns
    -------
    List of neighbouring cluster labels sorted by decreasing count, or None
    (after printing a hint) when no cluster qualifies.
    """
    # Size of every cluster.
    cluster_num = dict()
    for label in pred:
        cluster_num[label] = cluster_num.get(label, 0) + 1

    spots = pd.DataFrame(data={'cell_id': cell_id, 'x': x, "y":y, "pred":pred})
    spots.index=spots['cell_id']
    targets=spots[spots["pred"]==target_cluster]

    all_x=spots["x"].to_numpy()
    all_y=spots["y"].to_numpy()
    all_pred=spots["pred"].tolist()
    radius_sq=radius**2

    nbr_num={}
    num_nbr=[]
    for cx, cy in zip(targets["x"].to_numpy(), targets["y"].to_numpy()):
        near=np.flatnonzero(((all_x-cx)**2+(all_y-cy)**2)<=radius_sq)
        num_nbr.append(len(near))
        for pos in near:
            label=all_pred[pos]
            nbr_num[label]=nbr_num.get(label,0)+1

    # The target is not its own neighbour. pop() tolerates a target cluster
    # without spots, where a plain ``del`` raised KeyError.
    nbr_num.pop(target_cluster, None)
    nbr_num=[(k, v) for k, v in nbr_num.items() if v>(ratio*cluster_num[k])]
    nbr_num.sort(key=lambda item: -item[1])
    print("radius=", radius, "average number of neighbors for each spot is", np.mean(num_nbr))
    print(" Cluster",target_cluster, "has neighbors:")
    for t in nbr_num:
        print("Dmain ", t[0], ": ",t[1])
    ret=[t[0] for t in nbr_num]
    if len(ret)==0:
        print("No neighbor domain found, try bigger radius or smaller ratio.")
        return None
    return ret
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _gene_expression(adata, gene):
    """Return the expression of ``gene`` as a dense 1-D array (sparse or dense ``X``)."""
    column=adata.X[:,adata.var.index==gene]
    if issparse(column):
        column=column.toarray()
    return np.asarray(column).ravel()


def find_meta_gene(input_adata,
                   pred,
                   target_domain,
                   start_gene,
                   mean_diff=0,
                   early_stop=True,
                   max_iter=5):
    """Iteratively build a meta gene that marks ``target_domain``.

    Starting from ``start_gene``, each round selects the spots that belong to
    the target domain or whose meta expression exceeds the target-domain mean,
    runs a Wilcoxon test between target and non-target spots in that subset,
    adds the top gene of the target group and subtracts the top gene of the
    other group. The result is shifted so that its minimum is zero.

    Parameters
    ----------
    input_adata
        AnnData with the expression matrix in ``X`` (sparse or dense); it is
        not modified.
    pred
        Domain label of every spot.
    target_domain
        Label of the domain to mark.
    start_gene
        Name of the gene used as the initial meta gene.
    mean_diff
        Minimum difference between target and non-target mean meta expression
        that a new round must reach to be accepted.
    early_stop
        If True, stop as soon as a round does not improve on the previous one
        (not fewer selected non-target spots, or smaller mean difference).
    max_iter
        Maximum number of rounds.

    Returns
    -------
    ``(meta_name, meta_values)``: the meta gene formula as a string and the
    per-spot meta expression as a list.
    """
    meta_name=start_gene
    adata=input_adata.copy()
    # The same accessor is used for every gene so sparse and dense X both work.
    adata.obs["meta"]=_gene_expression(adata, start_gene)
    adata.obs["pred"]=pred
    num_non_target=adata.shape[0]
    for i in range(max_iter):
        #Select cells
        in_target=adata.obs["pred"]==target_domain
        target_mean=np.mean(adata.obs[in_target]["meta"])
        # Copy: the subset gets its own obs column and DE results below.
        tmp=adata[(adata.obs["meta"]>target_mean)|in_target].copy()
        tmp.obs["target"]=((tmp.obs["pred"]==target_domain)*1).astype('category')
        # Both groups need at least 5 spots for the test.
        if (len(set(tmp.obs["target"]))<2) or (np.min(tmp.obs["target"].value_counts().values)<5):
            print("Meta gene is: ", meta_name)
            return meta_name, adata.obs["meta"].tolist()
        #DE
        sc.tl.rank_genes_groups(tmp, groupby="target",reference="rest", n_genes=1,method='wilcoxon')
        top_genes=tmp.uns['rank_genes_groups']["names"][0]
        adj_g=top_genes[0]  # top gene of the non-target group (0)
        add_g=top_genes[1]  # top gene of the target group (1)
        meta_name_cur=meta_name+"+"+add_g+"-"+adj_g
        print("Add gene: ", add_g)
        print("Minus gene: ", adj_g)
        #Meta gene
        meta_cur=adata.obs["meta"].to_numpy()+_gene_expression(adata, add_g)-_gene_expression(adata, adj_g)
        adata.obs["meta_cur"]=meta_cur-np.min(meta_cur)
        in_target=adata.obs["pred"]==target_domain
        mean_diff_cur=np.mean(adata.obs["meta_cur"][in_target])-np.mean(adata.obs["meta_cur"][~in_target])
        num_non_target_cur=np.sum(tmp.obs["target"]==0)
        if (not early_stop) or ((num_non_target>=num_non_target_cur) and (mean_diff<=mean_diff_cur)):
            num_non_target=num_non_target_cur
            mean_diff=mean_diff_cur
            print("Absolute mean change:", mean_diff)
            print("Number of non-target spots reduced to:",num_non_target)
        else:
            print("Stopped!", "Previous Number of non-target spots",num_non_target, num_non_target_cur, mean_diff,mean_diff_cur)
            print("Previous Number of non-target spots",num_non_target, num_non_target_cur, mean_diff,mean_diff_cur)
            print("Previous Number of non-target spots",num_non_target)
            print("Current Number of non-target spots",num_non_target_cur)
            print("Absolute mean change", mean_diff)
            print("===========================================================================")
            print("Meta gene: ", meta_name)
            print("===========================================================================")
            return meta_name, adata.obs["meta"].tolist()
        meta_name=meta_name_cur
        adata.obs["meta"]=adata.obs["meta_cur"]
        print("===========================================================================")
        print("Meta gene is: ", meta_name)
        print("===========================================================================")
    return meta_name, adata.obs["meta"].tolist()
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def Moran_I(genes_exp,x, y, k=5, knn=True):
    """Compute Moran's I for every column of ``genes_exp``.

    ``genes_exp`` is indexed by column name (one column per gene, one row per
    spot); ``x`` and ``y`` are the spot coordinates. With ``knn=True`` the
    weight matrix is binary: row ``i`` has ones at the indices returned by a
    ``k``-nearest-neighbour query on the coordinates, with the diagonal
    cleared. Otherwise the weights come from ``calculate_adj_matrix``.

    Returns a float64 ``pd.Series`` indexed by the columns of ``genes_exp``.
    """
    coords = pd.DataFrame({"x": x, "y": y})
    if knn:
        n_spots = genes_exp.shape[0]
        finder = NearestNeighbors(n_neighbors=k, algorithm='auto', metric='euclidean').fit(coords)
        _, neighbor_idx = finder.kneighbors(coords)
        W = np.zeros((n_spots, n_spots))
        for row in range(n_spots):
            W[row, neighbor_idx[row, :]] = 1
        # A spot is never its own neighbour.
        np.fill_diagonal(W, 0)
    else:
        W = calculate_adj_matrix(x=x, y=y, histology=False)

    weight_total = np.sum(W)
    I = pd.Series(index=genes_exp.columns, dtype="float64")
    for gene in genes_exp.columns:
        expr = genes_exp[gene]
        centered = np.array(expr - np.mean(expr))
        centered = centered.reshape(len(centered), 1)
        # Weighted sum of cross-products over all spot pairs.
        Nom = np.sum(np.multiply(W, centered @ centered.T))
        Den = np.sum(centered * centered)
        I[gene] = (len(genes_exp[gene]) / weight_total) * (Nom / Den)
    return I
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def Geary_C(genes_exp,x, y, k=5, knn=True):
    """Compute Geary's C for every column of ``genes_exp``.

    ``genes_exp`` is indexed by column name (one column per gene, one row per
    spot); ``x`` and ``y`` are the spot coordinates. With ``knn=True`` the
    weight matrix is binary: row ``i`` has ones at the indices returned by a
    ``k``-nearest-neighbour query on the coordinates, with the diagonal
    cleared. Otherwise the weights come from ``calculate_adj_matrix``.

    Returns a float64 ``pd.Series`` indexed by the columns of ``genes_exp``.
    """
    coords = pd.DataFrame({"x": x, "y": y})
    if knn:
        n_spots = genes_exp.shape[0]
        finder = NearestNeighbors(n_neighbors=k, algorithm='auto', metric='euclidean').fit(coords)
        _, neighbor_idx = finder.kneighbors(coords)
        W = np.zeros((n_spots, n_spots))
        for row in range(n_spots):
            W[row, neighbor_idx[row, :]] = 1
        # A spot is never its own neighbour.
        np.fill_diagonal(W, 0)
    else:
        W = calculate_adj_matrix(x=x, y=y, histology=False)

    C = pd.Series(index=genes_exp.columns, dtype="float64")
    for gene in genes_exp.columns:
        values = np.array(genes_exp[gene])
        centered = values - np.mean(values)
        centered = centered.reshape(len(centered), 1)
        # pair_diff[i, j] == values[i] - values[j]
        pair_diff = values[:, None] - values[None, :]
        Nom = np.sum(np.multiply(W, np.multiply(pair_diff, pair_diff)))
        Den = np.sum(np.multiply(centered, centered))
        C[gene] = (len(genes_exp[gene]) / (2 * np.sum(W))) * (Nom / Den)
    return C
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import sklearn.neighbors
|
|
4
|
+
import scipy.sparse as sp
|
|
5
|
+
import seaborn as sns
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import scanpy as sc
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch_geometric.data import Data
|
|
11
|
+
|
|
12
|
+
def refine_spatial_cluster(adata, pred, shape="hexagon"):
    """Smooth cluster labels by a majority vote over each spot's neighbours.

    For every spot the ``num_nbs`` closest neighbours recorded in
    ``adata.uns['Spatial_Net']`` are looked up. If fewer than half of them
    share the spot's own label and some other label is held by more than
    half of them, the spot is relabelled to that majority label.

    Parameters
    ----------
    adata
        Object exposing ``uns['Spatial_Net']`` (columns ``Cell1``, ``Cell2``,
        ``Distance``), ``obs_names`` and ``n_obs``.
    pred
        Per-spot labels, in the same order as ``adata.obs_names``.
    shape
        ``'hexagon'`` (6 neighbours, Visium data) or ``'square'``
        (4 neighbours, ST data).

    Returns
    -------
    list
        Refined label for every spot.

    Raises
    ------
    ValueError
        If ``shape`` is neither ``'hexagon'`` nor ``'square'``.
    """
    if shape == "hexagon":
        num_nbs = 6
    elif shape == "square":
        num_nbs = 4
    else:
        # Previously this only printed a message and then failed later with a
        # NameError on ``num_nbs``.
        raise ValueError(
            "Shape %r not recognized, shape='hexagon' for Visium data, "
            "'square' for ST data." % (shape,)
        )

    G_df = adata.uns['Spatial_Net'].copy()
    cells = np.array(adata.obs_names)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
    rows = G_df['Cell1'].map(cells_id_tran).to_numpy()
    cols = G_df['Cell2'].map(cells_id_tran).to_numpy()
    # Keep the graph sparse; only one row is materialised at a time.
    G = sp.coo_matrix(
        (G_df['Distance'].to_numpy(), (rows, cols)),
        shape=(adata.n_obs, adata.n_obs),
    ).tocsr()
    G.sort_indices()

    # Positional labels, independent of whatever index ``pred`` carries.
    labels = pd.DataFrame({"pred": pred})["pred"].reset_index(drop=True)

    refined_pred = []
    for i in range(labels.shape[0]):
        self_pred = labels.iloc[i]
        row = G.getrow(i)
        dis_tmp = pd.Series(row.data, index=row.indices)
        dis_tmp = dis_tmp[dis_tmp > 0].sort_values(ascending=True, kind="mergesort")
        nbs = dis_tmp.iloc[:num_nbs]
        if nbs.empty:
            # Isolated spot: nothing to vote with.
            refined_pred.append(self_pred)
            continue
        v_c = labels.iloc[nbs.index.to_numpy()].value_counts()
        # The spot itself is not among its neighbours, so its label may be
        # absent from the counts; treat that as zero votes.
        own_votes = v_c.loc[self_pred] if self_pred in v_c.index else 0
        if (own_votes < num_nbs / 2) and (v_c.max() > num_nbs / 2):
            refined_pred.append(v_c.idxmax())
        else:
            refined_pred.append(self_pred)
    return refined_pred
|
|
42
|
+
|
|
43
|
+
def plot_clustering(adata, colors, title = None, savepath = None):
    """Scatter-plot the spots at their spatial coordinates, coloured by ``colors``.

    The first two columns of ``adata.obsm['spatial']`` are written to
    ``adata.obs['x_pixel']`` and ``adata.obs['y_pixel']`` as a side effect.
    The plot uses a 7-colour ``plasma`` palette, equal aspect ratio, hidden
    axes and an inverted y axis. If ``savepath`` is given the figure is saved
    there.
    """
    spatial = adata.obsm['spatial']
    adata.obs['x_pixel'] = spatial[:, 0]
    adata.obs['y_pixel'] = spatial[:, 1]

    fig, axis = plt.subplots(figsize=(5, 5))
    sc.pl.scatter(
        adata,
        alpha=1,
        x="x_pixel",
        y="y_pixel",
        color=colors,
        title=title,
        palette=sns.color_palette('plasma', 7),
        show=False,
        ax=axis,
    )

    axis.set_aspect('equal', 'box')
    axis.axis('off')
    axis.invert_yaxis()
    if savepath is None:
        return
    fig.savefig(savepath, bbox_inches='tight')
|
|
57
|
+
|
|
58
|
+
def Transfer_pytorch_Data(adata):
    """Pack the expression and spatial graphs of ``adata`` into one ``Data`` object.

    Edges are read from ``adata.uns['Exp_Net']`` and
    ``adata.uns['Spatial_Net']``; a self-loop is added for every spot in both
    graphs. The returned object has ``edge_index`` (expression edges first,
    then spatial edges), ``x`` (a float tensor of ``adata.obsm['X_pca']``)
    and ``edge_type`` (0 for expression edges, 1 for spatial edges).
    """
    def _edges_with_self_loops(net_key):
        # Translate cell names to row positions and return the non-zero
        # (row, col) pairs of the adjacency matrix plus identity.
        net = adata.uns[net_key].copy()
        cell_names = np.array(adata.obs_names)
        name_to_pos = dict(zip(cell_names, range(cell_names.shape[0])))
        net['Cell1'] = net['Cell1'].map(name_to_pos)
        net['Cell2'] = net['Cell2'].map(name_to_pos)
        adjacency = sp.coo_matrix(
            (np.ones(net.shape[0]), (net['Cell1'], net['Cell2'])),
            shape=(adata.n_obs, adata.n_obs),
        )
        adjacency = adjacency + sp.eye(adjacency.shape[0])
        return np.nonzero(adjacency)

    # Expression edge
    exp_edge = _edges_with_self_loops('Exp_Net')
    # Spatial edge
    spatial_edge = _edges_with_self_loops('Spatial_Net')

    sources = np.concatenate((exp_edge[0], spatial_edge[0]))
    targets = np.concatenate((exp_edge[1], spatial_edge[1]))
    data = Data(
        edge_index=torch.LongTensor(np.array([sources, targets])).contiguous(),
        x=torch.FloatTensor(adata.obsm['X_pca'].copy()),
    )

    n_exp = exp_edge[0].shape[0]
    n_spatial = spatial_edge[0].shape[0]
    edge_type = torch.zeros(n_exp + n_spatial, dtype=torch.int64)
    edge_type[n_exp:] = 1
    data.edge_type = edge_type

    return data
|
|
88
|
+
|
|
89
|
+
def Batch_Data(adata, num_batch_x, num_batch_y, spatial_key=['X', 'Y'], plot_Stats=False):
    """Split ``adata`` into a ``num_batch_x`` by ``num_batch_y`` grid of subsets.

    Grid boundaries are evenly spaced percentiles of the two coordinate
    columns of ``adata.obs`` named by ``spatial_key``. Both ends of every
    interval are inclusive, so a spot lying exactly on a boundary appears in
    each adjacent batch. Batches are returned in x-major order. With
    ``plot_Stats`` the number of spots per batch is drawn as a box/strip plot.
    """
    coords = np.array(adata.obs.loc[:, spatial_key].copy())
    x_bounds = [np.percentile(coords[:, 0], (1/num_batch_x)*step*100) for step in range(num_batch_x+1)]
    y_bounds = [np.percentile(coords[:, 1], (1/num_batch_y)*step*100) for step in range(num_batch_y+1)]

    x_key, y_key = spatial_key[0], spatial_key[1]
    Batch_list = []
    for min_x, max_x in zip(x_bounds[:-1], x_bounds[1:]):
        for min_y, max_y in zip(y_bounds[:-1], y_bounds[1:]):
            subset = adata.copy()
            x_values = subset.obs[x_key]
            subset = subset[(x_values >= min_x) & (x_values <= max_x)]
            y_values = subset.obs[y_key]
            subset = subset[(y_values >= min_y) & (y_values <= max_y)]
            Batch_list.append(subset)

    if plot_Stats:
        f, ax = plt.subplots(figsize=(1, 3))
        plot_df = pd.DataFrame([batch.shape[0] for batch in Batch_list], columns=['#spot/batch'])
        sns.boxplot(y='#spot/batch', data=plot_df, ax=ax)
        sns.stripplot(y='#spot/batch', data=plot_df, ax=ax, color='red', size=5)
    return Batch_list
|
|
112
|
+
|
|
113
|
+
def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=6, model='KNN', verbose=True):
    """Construct the spatial neighbor network.

    Parameters
    ----------
    adata
        AnnData object of scanpy package. Coordinates are read from
        ``adata.obsm['spatial']``.
    rad_cutoff
        Radius cutoff when model='Radius'.
    k_cutoff
        The number of nearest neighbors when model='KNN'.
    model
        The network construction model. When model=='Radius', a spot is
        connected to spots whose distance is less than rad_cutoff. When
        model=='KNN', a spot is connected to its first k_cutoff nearest
        neighbors.
    verbose
        Whether to print progress and graph statistics.

    Returns
    -------
    None. The spatial network is saved in ``adata.uns['Spatial_Net']`` as a
    DataFrame with columns ``Cell1``, ``Cell2`` and ``Distance``; the index
    is the rank of the neighbor within its source spot's neighbor list.
    Zero-distance pairs (including each spot with itself) are dropped.

    Raises
    ------
    ValueError
        If ``model`` is not 'Radius' or 'KNN', or if model=='Radius' and
        ``rad_cutoff`` is None.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if model not in ('Radius', 'KNN'):
        raise ValueError("model must be 'Radius' or 'KNN', got %r" % (model,))
    if model == 'Radius' and rad_cutoff is None:
        raise ValueError("rad_cutoff must be given when model='Radius'")

    if verbose:
        print('------Calculating spatial graph...')
    # Plain array: no fixed column count, so 3-D coordinates also work.
    coords = np.asarray(adata.obsm['spatial'])
    cell_names = np.array(adata.obs.index)
    n_spots = coords.shape[0]

    if model == 'Radius':
        nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coords)
        distances, indices = nbrs.radius_neighbors(coords, return_distance=True)
        # Each spot has its own number of neighbors in radius mode.
        counts = np.array([len(found) for found in indices], dtype=int)
        source = np.repeat(np.arange(n_spots), counts)
        target = np.concatenate(list(indices))
        dist = np.concatenate(list(distances))
        rank = np.concatenate([np.arange(count) for count in counts])
    else:
        # k_cutoff + 1 because the query set equals the fitted set, so the
        # result normally includes the spot itself.
        nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(coords)
        distances, indices = nbrs.kneighbors(coords)
        width = indices.shape[1]
        source = np.repeat(np.arange(n_spots), width)
        target = indices.ravel()
        dist = distances.ravel()
        rank = np.tile(np.arange(width), n_spots)

    Spatial_Net = pd.DataFrame(
        {'Cell1': source, 'Cell2': target, 'Distance': dist}, index=rank)
    # Copy so the column assignments below act on an independent frame.
    Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0].copy()
    id_cell_trans = dict(zip(range(n_spots), cell_names))
    Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
    Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
    if verbose:
        print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))

    adata.uns['Spatial_Net'] = Spatial_Net
|
|
167
|
+
|
|
168
|
+
def Cal_Expression_Net(adata, k_cutoff=6):
    """Construct the expression neighbor network.

    Each spot is linked to its nearest neighbors in ``adata.obsm['X_pca']``
    (``k_cutoff + 1`` neighbors are queried, then zero-distance pairs are
    removed). The result, a DataFrame with columns ``Cell1``, ``Cell2`` and
    ``Distance`` using spot names from ``adata.obs.index``, is saved in
    ``adata.uns['Exp_Net']``.
    """
    embedding = pd.DataFrame(adata.obsm['X_pca'])
    embedding.index = adata.obs.index

    finder = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(embedding)
    distances, indices = finder.kneighbors(embedding)

    n_found = indices.shape[1]
    per_spot = [
        pd.DataFrame(zip([spot]*n_found, indices[spot, :], distances[spot, :]))
        for spot in range(indices.shape[0])
    ]
    edges = pd.concat(per_spot)
    edges.columns = ['Cell1', 'Cell2', 'Distance']

    exp_Net = edges.copy()
    keep = exp_Net['Distance'] > 0
    exp_Net = exp_Net.loc[keep,]

    # Replace positional ids with the spot names.
    id_to_name = dict(enumerate(np.array(embedding.index)))
    for column in ('Cell1', 'Cell2'):
        exp_Net[column] = exp_Net[column].map(id_to_name)

    adata.uns['Exp_Net'] = exp_Net
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def res_search_fixed_clus(adata, fixed_clus_count, increment=0.02):
    '''
    Search for a Leiden resolution giving at most ``fixed_clus_count`` clusters.

    Resolutions are tried from 2.5 downwards in steps of ``increment``; the
    search stops at the first one whose clustering has no more than
    ``fixed_clus_count`` distinct labels. ``adata.obs['leiden']`` is
    overwritten on every try.

    arg1(adata)[AnnData matrix]
    arg2(fixed_clus_count)[int]

    return:
        resolution[float] -- the resolution at which the search stopped, or
        the last one tried if the target was never reached
    '''
    candidates = np.arange(2.5, 0.0, -increment)
    for res in candidates:
        sc.tl.leiden(adata, random_state=0, resolution=res)
        n_clusters = len(adata.obs['leiden'].unique())
        if not n_clusters > fixed_clus_count:
            break
    return res
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def Cal_Spatial_Net_3D(adata, rad_cutoff_2D, rad_cutoff_Zaxis,
                       key_section='Section_id', section_order=None, verbose=True):
    """Construct the 3D spatial neighbor network across sections.

    Parameters
    ----------
    adata
        AnnData object of scanpy package.
    rad_cutoff_2D
        Radius cutoff for the 2D SNN built inside each section.
    rad_cutoff_Zaxis
        Radius cutoff for the SNN built between adjacent sections.
    key_section
        The column name of the section id in adata.obs.
    section_order
        The order of sections. The SNNs between adjacent sections are
        constructed according to this order. Required when there is more
        than one section.
    verbose
        Whether to print progress and graph statistics.

    Returns
    -------
    None. The per-section networks are saved in
    ``adata.uns['Spatial_Net_2D']``, the between-section networks in
    ``adata.uns['Spatial_Net_Zaxis']`` and their concatenation in
    ``adata.uns['Spatial_Net']``.

    Raises
    ------
    ValueError
        If there are several sections and ``section_order`` is missing or
        shorter than the number of sections.
    """
    adata.uns['Spatial_Net_2D'] = pd.DataFrame()
    adata.uns['Spatial_Net_Zaxis'] = pd.DataFrame()
    sections = np.unique(adata.obs[key_section])
    num_section = sections.shape[0]
    if num_section > 1 and (section_order is None or len(section_order) < num_section):
        raise ValueError(
            "section_order must list all %d sections to link adjacent sections."
            % num_section)
    if verbose:
        print('Radius used for 2D SNN:', rad_cutoff_2D)
        print('Radius used for SNN between sections:', rad_cutoff_Zaxis)

    for temp_section in sections:
        if verbose:
            print('------Calculating 2D SNN of section ', temp_section)
        temp_adata = adata[adata.obs[key_section] == temp_section, ].copy()
        # model='Radius' is required: the default 'KNN' ignores rad_cutoff.
        Cal_Spatial_Net(
            temp_adata, rad_cutoff=rad_cutoff_2D, model='Radius', verbose=False)
        section_net = temp_adata.uns['Spatial_Net'].copy()
        section_net['SNN'] = temp_section
        if verbose:
            print('This graph contains %d edges, %d cells.' %
                  (section_net.shape[0], temp_adata.n_obs))
            print('%.4f neighbors per cell on average.' %
                  (section_net.shape[0]/temp_adata.n_obs))
        adata.uns['Spatial_Net_2D'] = pd.concat(
            [adata.uns['Spatial_Net_2D'], section_net])

    for it in range(num_section-1):
        section_1 = section_order[it]
        section_2 = section_order[it+1]
        if verbose:
            print('------Calculating SNN between adjacent section %s and %s.' %
                  (section_1, section_2))
        # str() so numeric section ids work too.
        Z_Net_ID = str(section_1) + '-' + str(section_2)
        temp_adata = adata[adata.obs[key_section].isin(
            [section_1, section_2]), ].copy()
        Cal_Spatial_Net(
            temp_adata, rad_cutoff=rad_cutoff_Zaxis, model='Radius', verbose=False)
        z_net = temp_adata.uns['Spatial_Net']
        spot_section_trans = dict(
            zip(temp_adata.obs.index, temp_adata.obs[key_section]))
        section_of_cell1 = z_net['Cell1'].map(spot_section_trans).to_numpy()
        section_of_cell2 = z_net['Cell2'].map(spot_section_trans).to_numpy()
        # Keep only edges that cross from one section to the other.
        used_edge = section_of_cell1 != section_of_cell2
        z_net = z_net.loc[used_edge, ['Cell1', 'Cell2', 'Distance']].copy()
        z_net['SNN'] = Z_Net_ID
        if verbose:
            print('This graph contains %d edges, %d cells.' %
                  (z_net.shape[0], temp_adata.n_obs))
            print('%.4f neighbors per cell on average.' %
                  (z_net.shape[0]/temp_adata.n_obs))
        adata.uns['Spatial_Net_Zaxis'] = pd.concat(
            [adata.uns['Spatial_Net_Zaxis'], z_net])

    adata.uns['Spatial_Net'] = pd.concat(
        [adata.uns['Spatial_Net_2D'], adata.uns['Spatial_Net_Zaxis']])
    if verbose:
        print('3D SNN contains %d edges, %d cells.' %
              (adata.uns['Spatial_Net'].shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' %
              (adata.uns['Spatial_Net'].shape[0]/adata.n_obs))
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def Cal_Expression_3D(adata, k_cutoff=6, key_section='Section_id', verbose=True):
    """Construct an expression neighbor network separately for each section.

    For every distinct value of ``adata.obs[key_section]`` a copy of that
    section is gene-filtered (``min_cells=5``), normalized, scaled and
    reduced to 100 principal components; its expression network is then
    built with ``Cal_Expression_Net`` and tagged with the section id in an
    ``SNN`` column. All section networks are concatenated into
    ``adata.uns['Exp_Net']``.
    """
    adata.uns['Exp_Net'] = pd.DataFrame()
    for section_id in np.unique(adata.obs[key_section]):
        if verbose:
            print('------Calculating Expression Network of section ', section_id)
        in_section = adata.obs[key_section] == section_id
        section_adata = adata[in_section, ].copy()

        # Per-section preprocessing on the copy only.
        sc.pp.filter_genes(section_adata, min_cells=5)
        sc.pp.normalize_total(section_adata, target_sum=1, exclude_highly_expressed=True)
        sc.pp.scale(section_adata)
        sc.pp.pca(section_adata, n_comps=100)

        Cal_Expression_Net(section_adata, k_cutoff=k_cutoff)
        section_adata.uns['Exp_Net']['SNN'] = section_id
        adata.uns['Exp_Net'] = pd.concat(
            [adata.uns['Exp_Net'], section_adata.uns['Exp_Net']])
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def Stats_Spatial_Net(adata):
    """Plot the distribution of the number of neighbors per spot.

    Counts how often each spot appears as ``Cell1`` in
    ``adata.uns['Spatial_Net']``, then draws a bar chart of how many spots
    have each neighbor count, as a fraction of ``adata.shape[0]``. The mean
    number of edges per spot is shown in the title.
    """
    import matplotlib.pyplot as plt
    source_cells = adata.uns['Spatial_Net']['Cell1']
    n_spots = adata.shape[0]
    Mean_edge = source_cells.shape[0] / n_spots
    # Neighbors per spot, then how many spots share each neighbor count.
    neighbors_per_spot = source_cells.value_counts()
    plot_df = neighbors_per_spot.value_counts() / n_spots
    fig, ax = plt.subplots(figsize=[3,2])
    plt.ylabel('Percentage')
    plt.xlabel('')
    plt.title('Number of Neighbors (Mean=%.2f)'%Mean_edge)
    ax.bar(plot_df.index, plot_df)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: RGAST
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Relational Graph Attention Network for Spatial Transcriptome Analysis
|
|
5
|
+
Home-page: https://github.com/GYQ-form/RGAST
|
|
6
|
+
Author: Yuqiao Gong
|
|
7
|
+
Author-email: gyq123@sjtu.edu.cn
|
|
8
|
+
License: MIT
|
|
9
|
+
Keywords: spatial transcriptomic,RGAT,representation learning,spatial domain identification
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
|
|
16
|
+
# RGAST
|
|
17
|
+
|
|
18
|
+
RGAST: Relational Graph Attention Network for Spatial Transcriptome Analysis
|
|
19
|
+
|
|
20
|
+
This document will help you easily go through the RGAST model.
|
|
21
|
+
|
|
22
|
+

|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
## Installation
|
|
26
|
+
|
|
27
|
+
To install our package, run
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install RGAST
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
## Usage
|
|
36
|
+
|
|
37
|
+
RGAST (Relational Graph Attention network for Spatial Transcriptome analysis) constructs a relational graph attention network to learn the representation of each spot in the spatial transcriptome data. In addition to the attention mechanism, RGAST considers both gene expression similarity and spatial neighbor relationships in constructing the graph network, enabling a more comprehensive and flexible representation of the spatial transcriptome data. RGAST can be used in many ST analyses:
|
|
38
|
+
|
|
39
|
+
- spatial domain identification
|
|
40
|
+
- cell trajectory inference
|
|
41
|
+
- spatially variable gene (SVG) detection
|
|
42
|
+
- uncover spatially resolved cell-cell interactions
|
|
43
|
+
- reveal intricate 3D spatial patterns across multiple sections of ST data
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
## Tutorial
|
|
48
|
+
|
|
49
|
+
We have prepared several basic tutorials at https://github.com/GYQ-form/RGAST/tree/main/tutorial. You can quickly get hands-on with RGAST by going through these tutorials. Model parameters trained in our study are also released at https://github.com/GYQ-form/RGAST/tree/main/model_path.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
setup.py
|
|
4
|
+
RGAST/RGAST.py
|
|
5
|
+
RGAST/Train_RGAST.py
|
|
6
|
+
RGAST/__init__.py
|
|
7
|
+
RGAST/svg.py
|
|
8
|
+
RGAST/utils.py
|
|
9
|
+
RGAST.egg-info/PKG-INFO
|
|
10
|
+
RGAST.egg-info/SOURCES.txt
|
|
11
|
+
RGAST.egg-info/dependency_links.txt
|
|
12
|
+
RGAST.egg-info/requires.txt
|
|
13
|
+
RGAST.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
RGAST
|
RGAST-0.0.1/setup.cfg
ADDED
RGAST-0.0.1/setup.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from setuptools import setup, find_packages

__version__ = "0.0.1"

# The README doubles as the long description shown on the package index.
with open("README.md", "r", encoding='utf-8') as fh:
    long_description = fh.read()

setup(
    name="RGAST",
    version=__version__,
    packages=find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    include_package_data=True,
    install_requires=[
        'torch',
        'scanpy',
        # The distribution is named "scikit-learn"; "sklearn" on PyPI is a
        # deprecated placeholder that refuses to install.
        'scikit-learn',
        'torch_geometric',
        'scipy',
        'numba',
        # Imported directly by RGAST/utils.py.
        'numpy',
        'pandas',
        'matplotlib',
        'seaborn',
    ],
    author="Yuqiao Gong",
    author_email="gyq123@sjtu.edu.cn",
    keywords=["spatial transcriptomic", "RGAT", "representation learning", "spatial domain identification"],
    description="Relational Graph Attention Network for Spatial Transcriptome Analysis",
    license="MIT",
    url='https://github.com/GYQ-form/RGAST',
    long_description_content_type='text/markdown',
    long_description=long_description,
)
|