RGAST 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
RGAST-0.0.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Yuqiao Gong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
RGAST-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,49 @@
1
+ Metadata-Version: 2.1
2
+ Name: RGAST
3
+ Version: 0.0.1
4
+ Summary: Relational Graph Attention Network for Spatial Transcriptome Analysis
5
+ Home-page: https://github.com/GYQ-form/RGAST
6
+ Author: Yuqiao Gong
7
+ Author-email: gyq123@sjtu.edu.cn
8
+ License: MIT
9
+ Keywords: spatial transcriptomic,RGAT,representation learning,spatial domain identification
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+
16
+ # RGAST
17
+
18
+ RGAST: Relational Graph Attention Network for Spatial Transcriptome Analysis
19
+
20
+ This document will help you get started with the RGAST model.
21
+
22
+ ![fig1_00](https://github.com/GYQ-form/RGAST/assets/79566479/fe0655dc-2318-44e0-92bf-0aea3aad7163)
23
+
24
+
25
+ ## Installation
26
+
27
+ To install our package, run
28
+
29
+ ```bash
30
+ pip install RGAST
31
+ ```
32
+
33
+
34
+
35
+ ## Usage
36
+
37
+ RGAST (Relational Graph Attention network for Spatial Transcriptome analysis) constructs a relational graph attention network to learn the representation of each spot in spatial transcriptome (ST) data. In addition to the attention mechanism, RGAST considers both gene expression similarity and spatial neighbor relationships when constructing the graph network, enabling a more comprehensive and flexible representation of the spatial transcriptome data. RGAST can be used in many ST analyses:
38
+
39
+ - spatial domain identification
40
+ - cell trajectory inference
41
+ - spatially variable gene (SVG) detection
42
+ - uncover spatially resolved cell-cell interactions
43
+ - reveal intricate 3D spatial patterns across multiple sections of ST data
44
+
45
+
46
+
47
+ ## Tutorial
48
+
49
+ We have prepared several basic tutorials at https://github.com/GYQ-form/RGAST/tree/main/tutorial. You can quickly get hands-on experience with RGAST by going through these tutorials. Model parameters trained in our study are also released at https://github.com/GYQ-form/RGAST/tree/main/model_path.
RGAST-0.0.1/README.md ADDED
@@ -0,0 +1,34 @@
1
+ # RGAST
2
+
3
+ RGAST: Relational Graph Attention Network for Spatial Transcriptome Analysis
4
+
5
+ This document will help you get started with the RGAST model.
6
+
7
+ ![fig1_00](https://github.com/GYQ-form/RGAST/assets/79566479/fe0655dc-2318-44e0-92bf-0aea3aad7163)
8
+
9
+
10
+ ## Installation
11
+
12
+ To install our package, run
13
+
14
+ ```bash
15
+ pip install RGAST
16
+ ```
17
+
18
+
19
+
20
+ ## Usage
21
+
22
+ RGAST (Relational Graph Attention network for Spatial Transcriptome analysis) constructs a relational graph attention network to learn the representation of each spot in spatial transcriptome (ST) data. In addition to the attention mechanism, RGAST considers both gene expression similarity and spatial neighbor relationships when constructing the graph network, enabling a more comprehensive and flexible representation of the spatial transcriptome data. RGAST can be used in many ST analyses:
23
+
24
+ - spatial domain identification
25
+ - cell trajectory inference
26
+ - spatially variable gene (SVG) detection
27
+ - uncover spatially resolved cell-cell interactions
28
+ - reveal intricate 3D spatial patterns across multiple sections of ST data
29
+
30
+
31
+
32
+ ## Tutorial
33
+
34
+ We have prepared several basic tutorials at https://github.com/GYQ-form/RGAST/tree/main/tutorial. You can quickly get hands-on experience with RGAST by going through these tutorials. Model parameters trained in our study are also released at https://github.com/GYQ-form/RGAST/tree/main/model_path.
@@ -0,0 +1,26 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.backends.cudnn as cudnn
5
+ cudnn.deterministic = True
6
+ cudnn.benchmark = True
7
+ import torch.nn.functional as F
8
+ from torch_geometric.nn.conv.rgat_conv import RGATConv
9
+
10
class RGAST(torch.nn.Module):
    """Relational graph attention auto-encoder.

    Two ``RGATConv`` layers (two relation types, single attention head,
    dropout 0.3, no bias) encode the node features into a latent embedding;
    a single linear layer decodes the embedding back to the input dimension.

    Parameters
    ----------
    hidden_dims
        Sequence ``[in_dim, num_hidden, out_dim]`` giving the input, hidden
        and latent dimensions.
    """

    def __init__(self, hidden_dims):
        super().__init__()
        in_dim, num_hidden, out_dim = hidden_dims
        # Both encoder layers share the same attention configuration.
        conv_options = dict(num_relations=2, heads=1, concat=False,
                            dropout=0.3, add_self_loops=False, bias=False)
        self.conv1 = RGATConv(in_dim, num_hidden, **conv_options)
        self.conv2 = RGATConv(num_hidden, out_dim, **conv_options)
        self.decoder = nn.Sequential(nn.Linear(out_dim, in_dim))

    def forward(self, features, edge_index, edge_type):
        """Return ``(latent, reconstruction)`` for the given graph.

        ``edge_type`` selects, per edge, which of the two relations the edge
        belongs to; it is forwarded unchanged to both ``RGATConv`` layers.
        """
        hidden = F.elu(self.conv1(features, edge_index, edge_type))
        latent = F.elu(self.conv2(hidden, edge_index, edge_type))
        reconstruction = self.decoder(latent)
        return latent, reconstruction
@@ -0,0 +1,387 @@
1
+ import numpy as np
2
+ import os
3
+ import scanpy as sc
4
+ import anndata
5
+ from sklearn.metrics.cluster import adjusted_rand_score
6
+ from sklearn.metrics import silhouette_score
7
+ from tqdm import tqdm
8
+
9
+ from .RGAST import RGAST
10
+ from .utils import Transfer_pytorch_Data, res_search_fixed_clus, Batch_Data, Cal_Spatial_Net, Cal_Expression_Net
11
+
12
+ import torch
13
+ import torch.backends.cudnn as cudnn
14
+ cudnn.deterministic = True
15
+ cudnn.benchmark = True
16
+ import torch.nn.functional as F
17
+ from torch_geometric.loader import DataLoader
18
+
19
def target_distribution(batch):
    """Sharpen a soft cluster-assignment matrix into a target distribution.

    Each entry is squared and divided by its column sum (``torch.sum(batch, 0)``),
    then every row is renormalised so that it sums to one.

    Parameters
    ----------
    batch
        2-D tensor of soft assignments, one row per sample and one column
        per cluster.

    Returns
    -------
    Tensor of the same shape whose rows sum to one.
    """
    sharpened = batch.pow(2) / batch.sum(dim=0)
    return sharpened / sharpened.sum(dim=1, keepdim=True)
22
+
23
class Train_RGAST:
    """Trainer that fits an RGAST auto-encoder on an AnnData object.

    ``adata`` must already contain ``obsm['X_pca']``, ``uns['Spatial_Net']``
    and ``uns['Exp_Net']``. The graph built by ``Transfer_pytorch_Data`` is
    kept in ``self.data``; training runs on ``cuda:0`` when available and on
    the CPU otherwise.
    """

    def __init__(self, adata, batch_data = False, num_batch_x_y = None, spatial_net_arg = None, exp_net_arg = None, verbose=True):
        """Initialization of a RGAST trainer.

        Parameters
        ----------
        adata
            AnnData object of scanpy package.
        batch_data
            If True, the slide is split into spatial tiles and every tile is
            trained as its own mini-batch.
        num_batch_x_y
            A tuple specifying the number of points at which to segment the spatially transcribed image on the x and y axes.
            Each split is then trained as a batch. This is useful for large scale cases.
            Required when ``batch_data`` is True.
        spatial_net_arg
            A dict passing key-word arguments to calculating spatial network in each batch data. See `Cal_Spatial_Net`.
        exp_net_arg
            A dict passing key-word arguments to calculating expression network in each batch data. See `Cal_Expression_Net`
        verbose
            If True, print the shape of the PCA input.

        Raises
        ------
        ValueError
            If PCA or one of the two networks is missing from ``adata``, or if
            ``batch_data`` is True but ``num_batch_x_y`` is not given.
        """
        # ``None`` sentinels instead of shared mutable ``{}`` defaults.
        spatial_net_arg = {} if spatial_net_arg is None else spatial_net_arg
        exp_net_arg = {} if exp_net_arg is None else exp_net_arg

        if 'X_pca' not in adata.obsm.keys():
            raise ValueError("PCA has not been done! Run sc.pp.pca first!")
        if verbose:
            print('Size of Input: ', adata.obsm['X_pca'].shape)

        self.batch_data = batch_data
        self.adata = adata
        if 'Spatial_Net' not in adata.uns.keys():
            raise ValueError("Spatial_Net is not existed! Run Cal_Spatial_Net first!")
        if 'Exp_Net' not in adata.uns.keys():
            raise ValueError("Exp_Net is not existed! Run Cal_Expression_Net first!")
        self.data = Transfer_pytorch_Data(adata)

        if batch_data:
            if num_batch_x_y is None:
                raise ValueError("num_batch_x_y=(num_batch_x, num_batch_y) is required when batch_data=True.")
            self.num_batch_x, self.num_batch_y = num_batch_x_y
            adata.obs['X'] = adata.obsm['spatial'][:,0]
            adata.obs['Y'] = adata.obsm['spatial'][:,1]
            Batch_list = Batch_Data(adata, num_batch_x=self.num_batch_x, num_batch_y=self.num_batch_y,
                                    spatial_key=['X', 'Y'])
            # Every tile gets its own spatial / expression network.
            for temp_adata in Batch_list:
                Cal_Spatial_Net(temp_adata, **spatial_net_arg)
                Cal_Expression_Net(temp_adata, **exp_net_arg)
            data_list = [Transfer_pytorch_Data(batch_adata) for batch_adata in Batch_list]
            self.loader = DataLoader(data_list, batch_size=1, shuffle=True)

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = None

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _forward_eval(self, model, data):
        """Run ``model`` in eval mode on the full graph and return ``(z, out)``.

        In batch mode the full graph is evaluated on the CPU (presumably
        because the whole slide may not fit in GPU memory) and the model is
        moved back to ``self.device`` afterwards.
        """
        model.eval()
        with torch.no_grad():
            if self.batch_data:
                model.to('cpu')
                result = model(data.x.cpu(), data.edge_index.cpu(), data.edge_type.cpu())
                model.to(self.device)
            else:
                result = model(data.x, data.edge_index, data.edge_type)
        return result

    def _cluster_score(self, z):
        """Cluster embedding ``z`` and return ``(metric_name, score, leiden_labels)``.

        The score is the ARI against ``self.adata.obs[self.label_key]`` when a
        label key was given, otherwise the silhouette score of the clustering.
        """
        adata_RGAST = anndata.AnnData(z)
        adata_RGAST.obs_names = self.adata.obs_names
        sc.pp.neighbors(adata_RGAST)
        sc.tl.umap(adata_RGAST)
        # The clustering is read back from obs['leiden'] below.
        _ = res_search_fixed_clus(adata_RGAST, self.n_clusters)
        leiden = adata_RGAST.obs['leiden']
        if self.label_key is not None:
            obs_df = adata_RGAST.obs.join(self.adata.obs[self.label_key]).dropna(subset=[self.label_key])
            return 'ARI', adjusted_rand_score(obs_df['leiden'], obs_df[self.label_key]), leiden
        return 'SC', silhouette_score(z, leiden), leiden

    def _early_stopping_check(self, model, data, epoch, tracker, verbose):
        """Evaluate the model, checkpoint on improvement, return True to stop.

        ``tracker`` holds ``best`` (best score so far), ``fails`` (consecutive
        evaluations without improvement) and ``state`` (best weights).
        Training stops after more than three failures once ``epoch >= 300``.
        """
        z, _ = self._forward_eval(model, data)
        z = z.to('cpu').detach().numpy()
        metric, score, leiden = self._cluster_score(z)
        if verbose:
            print(f'epoch:{epoch},{metric}:{score}')
        if score <= tracker['best']:
            tracker['fails'] += 1
            return tracker['fails'] > 3 and epoch >= 300
        tracker['fails'] = 0
        tracker['best'] = score
        # Keep the best weights in memory so a stale model.pth is never reloaded.
        tracker['state'] = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save(model, f'{self.save_path}/model.pth')
        self.adata.obs['leiden'] = leiden
        return False

    @staticmethod
    def _reconstruction_step(model, optimizer, graph, gradient_clipping):
        """Do one optimisation step on the MSE reconstruction loss; return the loss as a float."""
        model.train()
        optimizer.zero_grad()
        _, out = model(graph.x, graph.edge_index, graph.edge_type)
        loss = F.mse_loss(graph.x, out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
        # .item() detaches the value so the autograd graph is not kept alive.
        return loss.item()

    @staticmethod
    def _soft_assign(z, cluster_layer):
        """Row-normalised ``1 / (1 + squared distance)`` similarity of ``z`` to each centroid."""
        q = 1.0 / (1.0 + torch.sum(torch.pow(z.unsqueeze(1) - cluster_layer, 2), 2))
        return (q.t() / torch.sum(q, 1)).t()

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def train_RGAST(self, early_stopping = True, label_key = None, save_path = '.', n_clusters = 7,
                    hidden_dims=(100, 32), n_epochs=1000, lr=0.001, key_added='RGAST',
                    gradient_clipping=5., weight_decay=0.0001, verbose=True,
                    random_seed=0, save_loss=False, save_reconstrction=False):
        """Training graph attention auto-encoder.

        Parameters
        ----------
        early_stopping
            Using early stopping strategy or not. Default = True.
            The model is evaluated every 50 epochs.
        label_key
            A key specifying the column in adata.obs to be treated as reference label.
            When given, ARI is the early-stopping score; otherwise the silhouette score is used.
        save_path
            Directory to save the trained RGAST model and the embedding.
        n_clusters
            Number of clusters to set when calculating early stopping criterion.
        hidden_dims
            The dimension of the encoder (hidden size, latent size).
        n_epochs
            Number of total epochs in training.
        lr
            Learning rate for the Adam optimizer.
        key_added
            The latent embeddings are saved in adata.obsm[key_added].
        gradient_clipping
            Maximum gradient norm.
        weight_decay
            Weight decay for the Adam optimizer.
        verbose
            If True, print the score at every evaluation.
        random_seed
            Seed for ``random``, ``numpy`` and ``torch``.
        save_loss
            If True, the per-step training losses are saved in adata.uns['RGAST_loss'].
        save_reconstrction
            If True, the reconstructed expression profiles are saved in adata.layers['RGAST_ReX'].

        Returns
        -------
        None. Results are written to ``self.adata``, ``self.model`` and ``save_path``.
        """
        self.save_path = save_path
        self.label_key = label_key
        self.n_clusters = n_clusters
        # Remembered so train_with_dec reads the embedding from the same key.
        self._embedding_key = key_added

        # seed_everything()
        seed=random_seed
        import random
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)

        if self.model is None:
            model = RGAST(hidden_dims = [self.data.x.shape[1]] + list(hidden_dims)).to(self.device)
        else:
            model = self.model.to(self.device)

        data = self.data.to(self.device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        loss_list = []
        tracker = {'best': 0, 'fails': 0, 'state': None}
        for epoch in tqdm(range(1, n_epochs+1)):

            if early_stopping and epoch % 50 == 0:
                if self._early_stopping_check(model, data, epoch, tracker, verbose):
                    break

            if self.batch_data:
                for batch in self.loader:
                    batch = batch.to(self.device)
                    loss_list.append(self._reconstruction_step(model, optimizer, batch, gradient_clipping))
            else:
                loss_list.append(self._reconstruction_step(model, optimizer, data, gradient_clipping))

        # Restore the best checkpoint of THIS run (if any) rather than whatever
        # model.pth happens to exist in save_path.
        if tracker['state'] is not None:
            model.load_state_dict(tracker['state'])

        z, out = self._forward_eval(model, data)

        RGAST_rep = z.to('cpu').detach().numpy()
        np.save(f'{save_path}/RGAST_embedding.npy', RGAST_rep)
        self.adata.obsm[key_added] = RGAST_rep

        if save_loss:
            self.adata.uns['RGAST_loss'] = np.asarray(loss_list, dtype=np.float32)
        if save_reconstrction:
            # NOTE(review): layers must match adata's (n_obs, n_vars) shape; confirm
            # that the reconstruction has n_vars columns for the intended inputs.
            ReX = out.to('cpu').detach().numpy()
            self.adata.layers['RGAST_ReX'] = ReX

        self.model = model

    def train_with_dec(self, verbose = True, early_stopping = True, key_added='RGAST', num_epochs=1000, dec_interval=50, dec_tol=0.01,
                       lr=0.001, weight_decay=0.0001):
        """Training graph attention auto-encoder with deep embedding clustering.
        Only call this after call Train_RGAST.train_RGAST() and make sure batch_data = False.

        Cluster centroids are initialised from ``adata.obs['leiden']`` and the
        embedding stored by ``train_RGAST``.

        Parameters
        ----------
        verbose
            If True, print the score at every evaluation.
        early_stopping
            Using early stopping strategy or not. Default = True.
        key_added
            The latent embeddings are saved in adata.obsm[key_added].
        num_epochs
            Number of total epochs in training.
        dec_interval
            Number of epochs between two evaluations / target-distribution updates.
        dec_tol
            Training stops when the fraction of spots whose DEC label changed
            since the previous update falls below this value.
        lr
            Learning rate for the Adam optimizer.
        weight_decay
            Weight decay for the Adam optimizer.

        Returns
        -------
        None. ``adata.obsm[key_added]`` and ``self.model`` are updated in place.
        """
        if self.model is None:
            raise RuntimeError("No model available. Call train_RGAST() first.")

        # initialize cluster parameter
        model = self.model.to(self.device)
        model.eval()
        test_z = self.adata.obsm[getattr(self, '_embedding_key', 'RGAST')]
        y_pred_last = np.array(self.adata.obs['leiden'],dtype=np.int32).copy()
        counts = len(np.bincount(y_pred_last))
        # One centroid per cluster id: the mean embedding of its members.
        centroids = np.stack([np.mean(test_z[y_pred_last == i], axis=0) for i in range(counts)])
        cluster_layer = torch.tensor(centroids).to(self.device)
        data = self.data.to(self.device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

        tracker = {'best': 0, 'fails': 0, 'state': None}
        tmp_p = None
        for epoch_id in tqdm(range(num_epochs)):

            if epoch_id % dec_interval == 0:
                #early stopping
                if early_stopping and self._early_stopping_check(model, data, epoch_id, tracker, verbose):
                    break

                #DEC update (runs at epoch 0, so tmp_p is always defined before training)
                model.eval()
                with torch.no_grad():
                    z, _ = model(data.x, data.edge_index, data.edge_type)
                    tmp_p = target_distribution(self._soft_assign(z, cluster_layer))
                y_pred = tmp_p.cpu().numpy().argmax(1)
                delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0]
                y_pred_last = np.copy(y_pred)
                if epoch_id > 0 and delta_label < dec_tol:
                    print('delta_label {:.4}'.format(delta_label), '< tol', dec_tol)
                    print('Reached tolerance threshold. Stopping training.')
                    break

            # training model
            model.train()
            optimizer.zero_grad()
            z, reconst = model(data.x, data.edge_index, data.edge_type)
            q = self._soft_assign(z, cluster_layer)
            loss_rec = F.mse_loss(data.x, reconst)
            # clustering KL loss (default reduction kept so the loss scale is unchanged)
            loss_kl = F.kl_div(q.log(), tmp_p)
            loss = loss_kl + loss_rec
            loss.backward()
            optimizer.step()

        if tracker['state'] is not None:
            model.load_state_dict(tracker['state'])
        model.eval()
        with torch.no_grad():
            z, _ = model(data.x, data.edge_index, data.edge_type)

        RGAST_rep = z.to('cpu').detach().numpy()
        np.save(f'{self.save_path}/RGAST_embedding.npy', RGAST_rep)
        self.adata.obsm[key_added] = RGAST_rep
        self.model = model

    def load_model(self, path):
        """Load a model previously written by ``save_model`` / ``train_RGAST``.

        The file is a fully pickled module, so only load files you trust.
        """
        try:
            # torch >= 2.6 defaults to weights_only=True, which rejects pickled modules.
            self.model = torch.load(path, map_location=self.device, weights_only=False)
        except TypeError:
            # Older torch versions do not know the weights_only argument.
            self.model = torch.load(path, map_location=self.device)

    def save_model(self, path):
        """Pickle the current model to ``<path>/model.pth``."""
        torch.save(self.model,f'{path}/model.pth')

    def process(self, gdata = None):
        """Run the model in eval mode and return ``(embedding, reconstruction)``.

        Parameters
        ----------
        gdata
            Graph data object with ``x``, ``edge_index`` and ``edge_type``;
            defaults to the graph built in ``__init__``.
        """
        if gdata is None:
            gdata = self.data
        self.model.to(self.device)
        self.model.eval()
        gdata = gdata.to(self.device)
        return self.model(gdata.x, gdata.edge_index, gdata.edge_type)
387
+
@@ -0,0 +1,13 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ # Author: Yuqiao Gong
4
+ # File Name: __init__.py
5
+ # Description:
6
+ """
7
+
8
+ __author__ = "Yuqiao Gong"
9
+ __email__ = "gyq123@sjtu.edu.cn"
10
+
11
+ from .RGAST import RGAST
12
+ from .Train_RGAST import Train_RGAST
13
+ from .utils import Transfer_pytorch_Data, Cal_Spatial_Net, Cal_Expression_Net, Stats_Spatial_Net, Cal_Spatial_Net_3D, Batch_Data, plot_clustering, refine_spatial_cluster, Cal_Expression_3D
@@ -0,0 +1,252 @@
1
+ # code is modified from https://github.com/jianhuupenn/SpaGCN
2
+ import scanpy as sc
3
+ import pandas as pd
4
+ import numpy as np
5
+ import scipy
6
+ from scipy.sparse import issparse
7
+ import numba
8
+ from sklearn.neighbors import NearestNeighbors
9
+
10
+
11
@numba.njit("f4(f4[:], f4[:])")
def euclid_dist(t1,t2):
    """Return the Euclidean distance between two float32 vectors of equal length."""
    squared_total = 0
    for idx in range(t1.shape[0]):
        squared_total += (t1[idx] - t2[idx]) ** 2
    return np.sqrt(squared_total)
17
+
18
+
19
@numba.njit("f4[:,:](f4[:,:])", parallel=True, nogil=True)
def pairwise_distance(X):
    """Return the (n, n) float32 matrix of Euclidean distances between the rows of ``X``."""
    num_points = X.shape[0]
    dist = np.empty((num_points, num_points), dtype=np.float32)
    for row in numba.prange(num_points):
        for col in numba.prange(num_points):
            dist[row][col] = euclid_dist(X[row], X[col])
    return dist
27
+
28
+
29
def calculate_adj_matrix(x, y):
    """Return the pairwise Euclidean distance matrix of the 2-D points ``(x[i], y[i])``.

    The coordinates are cast to float32 before being handed to
    ``pairwise_distance``.
    """
    coords = np.array([x, y]).T
    coords = coords.astype(np.float32)
    return pairwise_distance(coords)
33
+
34
+
35
def count_nbr(target_cluster,cell_id, x, y, pred, adj_2d, radius):
    """Return the mean number of spots within ``radius`` of each target-cluster spot.

    Parameters
    ----------
    target_cluster
        Cluster label whose spots are used as circle centres.
    cell_id, x, y, pred
        Per-spot identifiers, coordinates and cluster labels (equal length).
    adj_2d
        Unused; kept for interface compatibility.
    radius
        Euclidean radius (inclusive) of the neighbourhood.

    Returns
    -------
    Mean neighbour count (the centre spot itself is counted). ``nan`` when the
    target cluster has no spots.
    """
    spots = pd.DataFrame({'cell_id': cell_id, 'x': x, "y": y, "pred": pred})
    spots.index = spots['cell_id']
    centres = spots[spots["pred"] == target_cluster]
    all_x = spots["x"].to_numpy()
    all_y = spots["y"].to_numpy()
    radius_sq = radius ** 2
    # One vectorised distance test per centre instead of a row-wise DataFrame filter.
    num_nbr = [
        int(np.count_nonzero((all_x - cx) ** 2 + (all_y - cy) ** 2 <= radius_sq))
        for cx, cy in zip(centres["x"].to_numpy(), centres["y"].to_numpy())
    ]
    return np.mean(num_nbr)
50
+
51
+
52
def search_radius(target_cluster,cell_id, x, y, pred, adj_2d, start, end, num_min=8, num_max=15, max_run=100):
    """Bisect ``[start, end]`` for a radius giving ``num_min..num_max`` neighbours on average.

    The average neighbour count is computed with ``count_nbr`` for the spots
    of ``target_cluster``.

    Returns
    -------
    The recommended radius; the last tested midpoint when ``max_run``
    iterations are exhausted; ``None`` when ``start``/``end`` do not bracket
    the requested range (a hint is printed).
    """
    run=0
    num_low=count_nbr(target_cluster,cell_id, x, y, pred, adj_2d, start)
    num_high=count_nbr(target_cluster,cell_id, x, y, pred, adj_2d, end)
    if num_min<=num_low<=num_max:
        print("recommended radius = ", str(start))
        return start
    elif num_min<=num_high<=num_max:
        print("recommended radius = ", str(end))
        return end
    elif num_low>num_max:
        print("Try smaller start.")
        return None
    elif num_high<num_min:
        print("Try bigger end.")
        return None
    # Defined up front so the max_run exit below can never hit an unbound name.
    mid=(start+end)/2
    while (num_low<num_min) and (num_high>num_min):
        run+=1
        print("Run "+str(run)+": radius ["+str(start)+", "+str(end)+"], num_nbr ["+str(num_low)+", "+str(num_high)+"]")
        if run >max_run:
            print("Exact radius not found, closest values are:\n"+"radius="+str(start)+": "+"num_nbr="+str(num_low)+"\nradius="+str(end)+": "+"num_nbr="+str(num_high))
            return mid
        mid=(start+end)/2
        num_mid=count_nbr(target_cluster,cell_id, x, y, pred, adj_2d, mid)
        if num_min<=num_mid<=num_max:
            print("recommended radius = ", str(mid), "num_nbr="+str(num_mid))
            return mid
        if num_mid<num_min:
            start=mid
            num_low=num_mid
        elif num_mid>num_max:
            end=mid
            num_high=num_mid
    # Reached only when the counts are not comparable (e.g. NaN for an empty cluster).
    return None
85
+
86
+
87
def rank_genes_groups(input_adata, target_cluster,nbr_list, label_col, adj_nbr=True, log=False):
    """Differential expression of ``target_cluster`` against the remaining spots.

    Parameters
    ----------
    input_adata
        AnnData with expression in ``X`` and cluster labels in ``obs[label_col]``.
    target_cluster
        Label of the cluster of interest.
    nbr_list
        Labels of neighbouring clusters; used only when ``adj_nbr`` is True.
    label_col
        Column of ``obs`` holding the cluster labels.
    adj_nbr
        If True, compare only against spots of the clusters in ``nbr_list``;
        otherwise against all other spots.
    log
        If True, ``X`` is treated as already log-transformed when computing
        the fold change.

    Returns
    -------
    DataFrame with one row per gene: in/out group expressed fraction, their
    ratio, in/out group mean expression, fold change and adjusted p-value.
    """
    if adj_nbr:
        nbr_list=nbr_list+[target_cluster]
        # Explicit copy: the subset is modified below and must not be a view.
        adata=input_adata[input_adata.obs[label_col].isin(nbr_list)].copy()
    else:
        adata=input_adata.copy()
    adata.var_names_make_unique()
    adata.obs["target"]=((adata.obs[label_col]==target_cluster)*1).astype('category')
    sc.tl.rank_genes_groups(adata, use_raw=False, groupby="target",reference="rest", n_genes=adata.shape[1],method='wilcoxon')
    # NOTE(review): p-values are read from the first group field and gene names
    # from the second; confirm that both fields are ordered consistently.
    pvals_adj=[i[0] for i in adata.uns['rank_genes_groups']["pvals_adj"]]
    genes=[i[1] for i in adata.uns['rank_genes_groups']["names"]]
    if issparse(adata.X):
        # .toarray() replaces the .A shorthand, which newer SciPy releases removed.
        obs_tidy=pd.DataFrame(adata.X.toarray())
    else:
        obs_tidy=pd.DataFrame(adata.X)
    obs_tidy.index=adata.obs["target"].tolist()
    obs_tidy.columns=adata.var.index.tolist()
    obs_tidy=obs_tidy.loc[:,genes]
    # 1. compute mean value
    mean_obs = obs_tidy.groupby(level=0).mean()
    # 2. compute fraction of cells having value >0
    obs_bool = obs_tidy.astype(bool)
    fraction_obs = obs_bool.groupby(level=0).sum() / obs_bool.groupby(level=0).count()
    # compute fold change.
    if log: #The adata already logged
        fold_change=np.exp((mean_obs.loc[1] - mean_obs.loc[0]).values)
    else:
        # small constant avoids division by zero for genes absent outside the target
        fold_change = (mean_obs.loc[1] / (mean_obs.loc[0]+ 1e-9)).values
    df = {'genes': genes, 'in_group_fraction': fraction_obs.loc[1].tolist(), "out_group_fraction":fraction_obs.loc[0].tolist(),"in_out_group_ratio":(fraction_obs.loc[1]/fraction_obs.loc[0]).tolist(),"in_group_mean_exp": mean_obs.loc[1].tolist(), "out_group_mean_exp": mean_obs.loc[0].tolist(),"fold_change":fold_change.tolist(), "pvals_adj":pvals_adj}
    df = pd.DataFrame(data=df)
    return df
118
+
119
+
120
def find_neighbor_clusters(target_cluster,cell_id, x, y, pred,radius, ratio=1/2):
    """Find the clusters that spatially neighbour ``target_cluster``.

    For every spot of the target cluster, all spots within ``radius`` are
    collected and their cluster labels tallied. A cluster counts as a
    neighbour when its tally exceeds ``ratio`` times its own size.

    Parameters
    ----------
    target_cluster
        Label of the cluster of interest.
    cell_id, x, y, pred
        Per-spot identifiers, coordinates and cluster labels (equal length).
    radius
        Euclidean radius (inclusive) of the neighbourhood.
    ratio
        Minimum tally, as a fraction of the candidate cluster's size.

    Returns
    -------
    List of neighbouring cluster labels ordered by decreasing tally, or
    ``None`` (after printing a hint) when no cluster qualifies.
    """
    cluster_num = dict()
    for i in pred:
        cluster_num[i] = cluster_num.get(i, 0) + 1
    df = pd.DataFrame(data={'cell_id': cell_id, 'x': x, "y": y, "pred": pred})
    df.index=df['cell_id']
    target_df=df[df["pred"]==target_cluster]
    nbr_num={}
    num_nbr=[]
    for _, row in target_df.iterrows():
        spot_x=row["x"]
        spot_y=row["y"]
        tmp_nbr=df[((df["x"]-spot_x)**2+(df["y"]-spot_y)**2)<=(radius**2)]
        num_nbr.append(tmp_nbr.shape[0])
        for p in tmp_nbr["pred"]:
            nbr_num[p]=nbr_num.get(p,0)+1
    # pop() instead of del: the target is absent when it has no spots at all.
    nbr_num.pop(target_cluster, None)
    nbr_num=[(k, v) for k, v in nbr_num.items() if v>(ratio*cluster_num[k])]
    nbr_num.sort(key=lambda item: -item[1])
    avg_nbr = np.mean(num_nbr) if num_nbr else float("nan")
    print("radius=", radius, "average number of neighbors for each spot is", avg_nbr)
    print(" Cluster",target_cluster, "has neighbors:")
    for t in nbr_num:
        print("Dmain ", t[0], ": ",t[1])
    ret=[t[0] for t in nbr_num]
    if len(ret)==0:
        print("No neighbor domain found, try bigger radius or smaller ratio.")
        return None
    return ret
151
+
152
+
153
def find_meta_gene(input_adata,
                    pred,
                    target_domain,
                    start_gene,
                    mean_diff=0,
                    early_stop=True,
                    max_iter=5):
    """Iteratively build a meta gene that marks ``target_domain``.

    Starting from ``start_gene``, each iteration runs a Wilcoxon test between
    target and non-target spots among those whose meta expression exceeds the
    target mean, then adds the top gene of the target group and subtracts the
    top gene of the other group.

    Parameters
    ----------
    input_adata
        AnnData with expression in ``X`` (sparse or dense); not modified.
    pred
        Per-spot domain labels.
    target_domain
        Label of the domain of interest.
    start_gene
        Name of the gene the meta gene starts from.
    mean_diff
        Minimum target/non-target mean difference an update must reach.
    early_stop
        If True, stop as soon as an update does not improve the meta gene.
    max_iter
        Maximum number of update iterations.

    Returns
    -------
    ``(meta_name, meta_values)``: the meta gene formula as a string and its
    per-spot values as a list.
    """
    meta_name=start_gene
    adata=input_adata.copy()
    # Densify once so gene columns can be sliced the same way for sparse and dense X.
    dense_X=adata.X.toarray() if issparse(adata.X) else np.asarray(adata.X)
    gene_names=adata.var.index

    def _gene_values(gene):
        # 1-D expression vector of a single gene across all spots
        return dense_X[:, gene_names==gene].ravel()

    adata.obs["meta"]=_gene_values(start_gene)
    adata.obs["pred"]=pred
    num_non_target=adata.shape[0]
    for i in range(max_iter):
        #Select cells
        target_mean=np.mean(adata.obs[adata.obs["pred"]==target_domain]["meta"])
        tmp=adata[((adata.obs["meta"]>target_mean)|(adata.obs["pred"]==target_domain))].copy()
        tmp.obs["target"]=((tmp.obs["pred"]==target_domain)*1).astype('category').copy()
        if (len(set(tmp.obs["target"]))<2) or (np.min(tmp.obs["target"].value_counts().values)<5):
            print("Meta gene is: ", meta_name)
            return meta_name, adata.obs["meta"].tolist()
        #DE
        sc.tl.rank_genes_groups(tmp, groupby="target",reference="rest", n_genes=1,method='wilcoxon')
        adj_g=tmp.uns['rank_genes_groups']["names"][0][0]
        add_g=tmp.uns['rank_genes_groups']["names"][0][1]
        meta_name_cur=meta_name+"+"+add_g+"-"+adj_g
        print("Add gene: ", add_g)
        print("Minus gene: ", adj_g)
        #Meta gene
        adata.obs[add_g]=_gene_values(add_g)
        adata.obs[adj_g]=_gene_values(adj_g)
        adata.obs["meta_cur"]=(adata.obs["meta"]+adata.obs[add_g]-adata.obs[adj_g])
        # shift so the minimum is zero
        adata.obs["meta_cur"]=adata.obs["meta_cur"]-np.min(adata.obs["meta_cur"])
        mean_diff_cur=np.mean(adata.obs["meta_cur"][adata.obs["pred"]==target_domain])-np.mean(adata.obs["meta_cur"][adata.obs["pred"]!=target_domain])
        num_non_target_cur=np.sum(tmp.obs["target"]==0)
        if (not early_stop) or ((num_non_target>=num_non_target_cur) and (mean_diff<=mean_diff_cur)):
            num_non_target=num_non_target_cur
            mean_diff=mean_diff_cur
            print("Absolute mean change:", mean_diff)
            print("Number of non-target spots reduced to:",num_non_target)
        else:
            print("Stopped!", "Previous Number of non-target spots",num_non_target, num_non_target_cur, mean_diff,mean_diff_cur)
            print("Previous Number of non-target spots",num_non_target, num_non_target_cur, mean_diff,mean_diff_cur)
            print("Previous Number of non-target spots",num_non_target)
            print("Current Number of non-target spots",num_non_target_cur)
            print("Absolute mean change", mean_diff)
            print("===========================================================================")
            print("Meta gene: ", meta_name)
            print("===========================================================================")
            return meta_name, adata.obs["meta"].tolist()
        meta_name=meta_name_cur
        adata.obs["meta"]=adata.obs["meta_cur"]
        print("===========================================================================")
        print("Meta gene is: ", meta_name)
        print("===========================================================================")
    return meta_name, adata.obs["meta"].tolist()
207
+
208
+
209
def Moran_I(genes_exp,x, y, k=5, knn=True):
    """Compute Moran's I spatial autocorrelation for every gene.

    Parameters
    ----------
    genes_exp
        DataFrame of expression values, one row per spot and one column per gene.
    x, y
        Spot coordinates, in the same order as the rows of ``genes_exp``.
    k
        Number of nearest neighbours requested from ``NearestNeighbors`` when
        ``knn`` is True (the spot itself is removed from its neighbour set).
    knn
        If True, use a binary k-nearest-neighbour weight matrix; otherwise use
        the matrix returned by ``calculate_adj_matrix`` as weights.

    Returns
    -------
    Series of Moran's I values indexed by gene.
    """
    XYmap=pd.DataFrame({"x": x, "y":y})
    n_spots=genes_exp.shape[0]
    if knn:
        XYnbrs = NearestNeighbors(n_neighbors=k, algorithm='auto',metric = 'euclidean').fit(XYmap)
        _, XYindices = XYnbrs.kneighbors(XYmap)
        W = np.zeros((n_spots,n_spots))
        W[np.arange(n_spots)[:, None], XYindices]=1
        np.fill_diagonal(W, 0)
    else:
        # calculate_adj_matrix only accepts x and y.
        W=calculate_adj_matrix(x=x,y=y)
    W_total=np.sum(W)
    I = pd.Series(index=genes_exp.columns, dtype="float64")
    for gene in genes_exp.columns:
        centered = np.array(genes_exp[gene] - np.mean(genes_exp[gene]))
        # sum_ij W_ij * c_i * c_j as a quadratic form (no n-by-n temporary)
        Nom = centered @ W @ centered
        Den = np.sum(centered * centered)
        I[gene] = (len(genes_exp[gene])/W_total)*(Nom/Den)
    return I
229
+
230
+
231
def Geary_C(genes_exp,x, y, k=5, knn=True):
    """Compute Geary's C spatial autocorrelation for every gene.

    Parameters
    ----------
    genes_exp
        DataFrame of expression values, one row per spot and one column per gene.
    x, y
        Spot coordinates, in the same order as the rows of ``genes_exp``.
    k
        Number of nearest neighbours requested from ``NearestNeighbors`` when
        ``knn`` is True (the spot itself is removed from its neighbour set).
    knn
        If True, use a binary k-nearest-neighbour weight matrix; otherwise use
        the matrix returned by ``calculate_adj_matrix`` as weights.

    Returns
    -------
    Series of Geary's C values indexed by gene.
    """
    XYmap=pd.DataFrame({"x": x, "y":y})
    n_spots=genes_exp.shape[0]
    if knn:
        XYnbrs = NearestNeighbors(n_neighbors=k, algorithm='auto',metric = 'euclidean').fit(XYmap)
        _, XYindices = XYnbrs.kneighbors(XYmap)
        W = np.zeros((n_spots,n_spots))
        W[np.arange(n_spots)[:, None], XYindices]=1
        np.fill_diagonal(W, 0)
    else:
        # calculate_adj_matrix only accepts x and y.
        W=calculate_adj_matrix(x=x,y=y)
    W_total=np.sum(W)
    C = pd.Series(index=genes_exp.columns, dtype="float64")
    for gene in genes_exp.columns:
        X=np.array(genes_exp[gene])
        X_minus_mean = X - np.mean(X)
        # Xij[i, j] = X[i] - X[j], built by broadcasting
        Xij=X[:, None]-X[None, :]
        Nom = np.sum(np.multiply(W,np.multiply(Xij,Xij)))
        Den = np.sum(np.multiply(X_minus_mean,X_minus_mean))
        C[gene] = (len(genes_exp[gene])/(2*W_total))*(Nom/Den)
    return C
@@ -0,0 +1,315 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import sklearn.neighbors
4
+ import scipy.sparse as sp
5
+ import seaborn as sns
6
+ import matplotlib.pyplot as plt
7
+ import scanpy as sc
8
+
9
+ import torch
10
+ from torch_geometric.data import Data
11
+
12
def refine_spatial_cluster(adata, pred, shape="hexagon"):
    """Smooth cluster labels by majority vote among each spot's nearest neighbours.

    For every spot, the ``num_nbs`` closest neighbours listed in
    ``adata.uns['Spatial_Net']`` (edges with ``Distance > 0``) are collected.
    The spot is relabelled to the most frequent neighbour label when fewer
    than half of ``num_nbs`` neighbours share the spot's own label and some
    other label is held by more than half of ``num_nbs``.

    Parameters
    ----------
    adata
        Object providing ``uns['Spatial_Net']`` (columns ``Cell1``, ``Cell2``,
        ``Distance``), ``obs_names`` and ``n_obs``.
    pred
        Cluster labels, one per spot, in the order of ``adata.obs_names``.
    shape
        ``'hexagon'`` (6 neighbours) or ``'square'`` (4 neighbours).

    Returns
    -------
    list
        The refined labels, in the same order as ``pred``.

    Raises
    ------
    ValueError
        If ``shape`` is neither ``'hexagon'`` nor ``'square'``.
    """
    if shape == "hexagon":
        num_nbs = 6
    elif shape == "square":
        num_nbs = 4
    else:
        # Previously this only printed a message and then failed later on the
        # undefined neighbour count.
        raise ValueError(
            "Shape not recognized, shape='hexagon' for Visium data, 'square' for ST data.")

    G_df = adata.uns['Spatial_Net'].copy()
    cells = np.array(adata.obs_names)
    cells_id_tran = dict(zip(cells, range(cells.shape[0])))
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)
    # CSR gives cheap per-row access without a dense n_obs x n_obs table.
    G = sp.coo_matrix((G_df['Distance'], (G_df['Cell1'], G_df['Cell2'])),
                      shape=(adata.n_obs, adata.n_obs)).tocsr()

    # Positional index so neighbour positions can be used with ``iloc``.
    labels = pd.DataFrame({"pred": pred})["pred"].reset_index(drop=True)

    refined_pred = []
    for i in range(labels.shape[0]):
        self_pred = labels.iloc[i]
        row = slice(G.indptr[i], G.indptr[i + 1])
        dist, nbr = G.data[row], G.indices[row]
        positive = dist > 0
        dist, nbr = dist[positive], nbr[positive]
        nearest = nbr[np.argsort(dist, kind="stable")[:num_nbs]]
        v_c = labels.iloc[nearest].value_counts()
        if v_c.empty:
            # No neighbours: nothing to vote with, keep the original label.
            refined_pred.append(self_pred)
            continue
        # The spot itself is not among its neighbours, so its label may be
        # absent from the counts; treat that as zero instead of raising.
        same_label = v_c.get(self_pred, 0)
        if (same_label < num_nbs / 2) and (v_c.max() > num_nbs / 2):
            refined_pred.append(v_c.idxmax())
        else:
            refined_pred.append(self_pred)
    return refined_pred
42
+
43
def plot_clustering(adata, colors, title = None, savepath = None):
    """Draw spots at their spatial coordinates, coloured by ``colors``.

    The first two columns of ``adata.obsm['spatial']`` are written to
    ``adata.obs['x_pixel']`` and ``adata.obs['y_pixel']`` and plotted with
    ``sc.pl.scatter`` on a 5x5 figure using a 7-colour 'plasma' palette.
    The axes are hidden, drawn with equal aspect and a flipped y axis.

    Parameters
    ----------
    adata
        AnnData object with ``obsm['spatial']``.
    colors
        Forwarded to ``sc.pl.scatter`` as ``color``.
    title
        Optional plot title.
    savepath
        If given, the figure is saved there with ``bbox_inches='tight'``.
    """
    spatial = adata.obsm['spatial']
    adata.obs['x_pixel'] = spatial[:, 0]
    adata.obs['y_pixel'] = spatial[:, 1]

    fig, axis = plt.subplots(figsize=(5, 5))
    palette = sns.color_palette('plasma', 7)
    sc.pl.scatter(adata, alpha=1, x="x_pixel", y="y_pixel", color=colors, title=title,
                  palette=palette, show=False, ax=axis)

    axis.set_aspect('equal', 'box')
    axis.axis('off')
    axis.invert_yaxis()
    if savepath is None:
        return
    fig.savefig(savepath, bbox_inches='tight')
57
+
58
def _edges_with_self_loops(adata, net_key):
    """Return the (row, col) index arrays of ``adata.uns[net_key]`` plus self-loops."""
    net = adata.uns[net_key].copy()
    names = np.array(adata.obs_names)
    name_to_pos = dict(zip(names, range(names.shape[0])))
    net['Cell1'] = net['Cell1'].map(name_to_pos)
    net['Cell2'] = net['Cell2'].map(name_to_pos)
    adjacency = sp.coo_matrix((np.ones(net.shape[0]), (net['Cell1'], net['Cell2'])),
                              shape=(adata.n_obs, adata.n_obs))
    # Adding the identity gives every cell an edge to itself.
    adjacency = adjacency + sp.eye(adjacency.shape[0])
    return np.nonzero(adjacency)


def Transfer_pytorch_Data(adata):
    """Pack the expression and spatial graphs of ``adata`` into a ``Data`` object.

    Edges from ``adata.uns['Exp_Net']`` and ``adata.uns['Spatial_Net']`` (each
    with self-loops added) are concatenated into one ``edge_index``;
    expression edges come first. ``edge_type`` is 0 for expression edges and
    1 for spatial edges. Node features are ``adata.obsm['X_pca']``.

    Returns
    -------
    Data
        Object with ``edge_index``, ``x`` and ``edge_type`` set.
    """
    exp_src, exp_dst = _edges_with_self_loops(adata, 'Exp_Net')
    spa_src, spa_dst = _edges_with_self_loops(adata, 'Spatial_Net')

    stacked = np.array([np.concatenate((exp_src, spa_src)),
                        np.concatenate((exp_dst, spa_dst))])
    data = Data(edge_index=torch.LongTensor(stacked).contiguous(),
                x=torch.FloatTensor(adata.obsm['X_pca'].copy()))

    n_exp = exp_src.shape[0]
    edge_type = torch.zeros(n_exp + spa_src.shape[0], dtype=torch.int64)
    edge_type[n_exp:] = 1
    data.edge_type = edge_type

    return data
88
+
89
def Batch_Data(adata, num_batch_x, num_batch_y, spatial_key=['X', 'Y'], plot_Stats=False):
    """Split ``adata`` into a ``num_batch_x`` x ``num_batch_y`` grid of spatial tiles.

    Tile borders are percentiles of the two coordinate columns, so tiles hold
    roughly equal numbers of spots. Borders are inclusive on both sides, so a
    spot lying exactly on a border belongs to both adjacent tiles.

    Parameters
    ----------
    adata
        AnnData-like object; ``adata.obs`` must contain the ``spatial_key``
        columns.
    num_batch_x, num_batch_y
        Number of tiles along the first and second coordinate.
    spatial_key
        Names of the two coordinate columns in ``adata.obs`` (not modified).
    plot_Stats
        If true, draw a box/strip plot of the number of spots per tile.

    Returns
    -------
    list
        One independent copy of the matching subset per tile, ordered with
        the y tile index varying fastest.
    """
    coords = np.array(adata.obs.loc[:, spatial_key])
    xs, ys = coords[:, 0], coords[:, 1]
    # ``100.0 * i / n`` is exactly 100 for i == n, whereas ``(1/n)*i*100`` can
    # round just below 100 (e.g. n == 49) and drop the outermost spot.
    batch_x_coor = [np.percentile(xs, 100.0 * i / num_batch_x) for i in range(num_batch_x + 1)]
    batch_y_coor = [np.percentile(ys, 100.0 * i / num_batch_y) for i in range(num_batch_y + 1)]

    Batch_list = []
    for it_x in range(num_batch_x):
        in_x = (xs >= batch_x_coor[it_x]) & (xs <= batch_x_coor[it_x + 1])
        for it_y in range(num_batch_y):
            in_y = (ys >= batch_y_coor[it_y]) & (ys <= batch_y_coor[it_y + 1])
            # Copy only the selected spots instead of copying the whole
            # object once per tile.
            Batch_list.append(adata[in_x & in_y].copy())
    if plot_Stats:
        f, ax = plt.subplots(figsize=(1, 3))
        plot_df = pd.DataFrame([x.shape[0] for x in Batch_list], columns=['#spot/batch'])
        sns.boxplot(y='#spot/batch', data=plot_df, ax=ax)
        sns.stripplot(y='#spot/batch', data=plot_df, ax=ax, color='red', size=5)
    return Batch_list
112
+
113
def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=6, model='KNN', verbose=True):
    """\
    Construct the spatial neighbor networks.

    Parameters
    ----------
    adata
        AnnData object of scanpy package. ``adata.obsm['spatial']`` must hold
        two coordinate columns.
    rad_cutoff
        radius cutoff when model='Radius'
    k_cutoff
        The number of nearest neighbors when model='KNN'
    model
        The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors.
    verbose
        If true, print progress and summary statistics.

    Returns
    -------
    The spatial networks are saved in adata.uns['Spatial_Net'] as a DataFrame
    with columns 'Cell1', 'Cell2' and 'Distance' and a fresh 0..n-1 row index.
    Pairs at distance 0 (including a spot matched with itself) are dropped.

    Raises
    ------
    ValueError
        If ``model`` is not 'Radius' or 'KNN', or if model='Radius' is used
        without ``rad_cutoff``.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if model not in ('Radius', 'KNN'):
        raise ValueError("model must be 'Radius' or 'KNN', got %r" % (model,))
    if model == 'Radius' and rad_cutoff is None:
        raise ValueError("rad_cutoff must be provided when model='Radius'")

    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index
    coor.columns = ['imagerow', 'imagecol']

    # Build flat source/target/distance arrays once rather than one small
    # DataFrame per spot.
    if model == 'Radius':
        nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)
        distances, indices = nbrs.radius_neighbors(coor, return_distance=True)
        counts = [len(ix) for ix in indices]
        source = np.repeat(np.arange(len(indices)), counts)
        target = np.concatenate(list(indices)).astype(int)
        dist = np.concatenate(list(distances))
    else:
        # One extra neighbour is requested because the query points are the
        # fitted points; zero-distance matches are removed below.
        nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(coor)
        distances, indices = nbrs.kneighbors(coor)
        source = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
        target = indices.ravel()
        dist = distances.ravel()

    keep = dist > 0
    cell_names = np.array(coor.index)
    Spatial_Net = pd.DataFrame({
        'Cell1': cell_names[source[keep]],
        'Cell2': cell_names[target[keep]],
        'Distance': dist[keep],
    })
    if verbose:
        print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))

    adata.uns['Spatial_Net'] = Spatial_Net
167
+
168
def Cal_Expression_Net(adata, k_cutoff=6):
    """Construct the expression-similarity neighbor network.

    Each cell is linked to its nearest neighbours in ``adata.obsm['X_pca']``.
    ``k_cutoff + 1`` neighbours are requested because the query points are
    the fitted points; pairs at distance 0 are then dropped.

    Parameters
    ----------
    adata
        AnnData object with ``obsm['X_pca']``.
    k_cutoff
        The number of nearest neighbors to keep per cell.

    Returns
    -------
    The network is saved in adata.uns['Exp_Net'] as a DataFrame with columns
    'Cell1', 'Cell2' and 'Distance' and a fresh 0..n-1 row index.
    """
    coor = pd.DataFrame(adata.obsm['X_pca'])
    coor.index = adata.obs.index
    nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff+1).fit(coor)
    distances, indices = nbrs.kneighbors(coor)

    # Flatten the (n_cells, k_cutoff + 1) result into an edge list in one
    # step rather than building one DataFrame per cell.
    source = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
    target = indices.ravel()
    dist = distances.ravel()

    keep = dist > 0
    cell_names = np.array(coor.index)
    exp_Net = pd.DataFrame({
        'Cell1': cell_names[source[keep]],
        'Cell2': cell_names[target[keep]],
        'Distance': dist[keep],
    })

    adata.uns['Exp_Net'] = exp_Net
188
+
189
+
190
def res_search_fixed_clus(adata, fixed_clus_count, increment=0.02):
    '''
    Search for a Leiden resolution giving at most ``fixed_clus_count`` clusters.

    Resolutions are tried from 2.5 downwards in steps of ``increment``;
    ``sc.tl.leiden`` is re-run on ``adata`` for each one (overwriting
    ``adata.obs['leiden']``).

    arg1(adata)[AnnData matrix]
    arg2(fixed_clus_count)[int]
    arg3(increment)[float] step between tried resolutions

    return:
        resolution[float] -- the first resolution whose cluster count is
        <= fixed_clus_count, or the last resolution tried if none qualifies
    '''
    resolution_grid = np.arange(2.5, 0.0, -increment)
    for res in resolution_grid:
        sc.tl.leiden(adata, random_state=0, resolution=res)
        n_clusters = len(adata.obs['leiden'].unique())
        if n_clusters <= fixed_clus_count:
            return res
    return res
204
+
205
+
206
def Cal_Spatial_Net_3D(adata, rad_cutoff_2D, rad_cutoff_Zaxis,
                       key_section='Section_id', section_order=None, verbose=True):
    """\
    Construct the 3D spatial neighbor networks.

    A radius-based network is built inside every section, and another one
    between each pair of consecutive sections in ``section_order`` (only
    edges that join spots of different sections are kept for the latter).

    Parameters
    ----------
    adata
        AnnData object of scanpy package.
    rad_cutoff_2D
        radius cutoff for 2D SNN construction.
    rad_cutoff_Zaxis
        radius cutoff for constructing SNNs between adjacent sections.
    key_section
        The columns names of section_ID in adata.obs.
    section_order
        The order of sections. The SNNs between adjacent sections are constructed according to this order.
        If None, the sorted unique values of ``adata.obs[key_section]`` are used.
    verbose
        If true, print progress and summary statistics.

    Returns
    -------
    The 3D spatial networks are saved in adata.uns['Spatial_Net']. The
    within-section and between-section parts are also saved in
    adata.uns['Spatial_Net_2D'] and adata.uns['Spatial_Net_Zaxis']. Each has
    an extra 'SNN' column naming the section (or 'section1-section2' pair).
    """
    sections = np.unique(adata.obs[key_section])
    num_section = sections.shape[0]
    if section_order is None:
        # Without an explicit order, fall back to the sorted section ids
        # instead of failing on ``None[it]``.
        section_order = list(sections)
    if verbose:
        print('Radius used for 2D SNN:', rad_cutoff_2D)
        print('Radius used for SNN between sections:', rad_cutoff_Zaxis)

    nets_2D = []
    for temp_section in sections:
        if verbose:
            print('------Calculating 2D SNN of section ', temp_section)
        # Work on a real copy: the helper writes into ``.uns``.
        temp_adata = adata[adata.obs[key_section] == temp_section, ].copy()
        # model='Radius' must be explicit; Cal_Spatial_Net defaults to 'KNN',
        # which would silently ignore the radius cutoff.
        Cal_Spatial_Net(
            temp_adata, rad_cutoff=rad_cutoff_2D, model='Radius', verbose=False)
        net = temp_adata.uns['Spatial_Net']
        net['SNN'] = temp_section
        if verbose:
            print('This graph contains %d edges, %d cells.' %
                  (net.shape[0], temp_adata.n_obs))
            print('%.4f neighbors per cell on average.' %
                  (net.shape[0]/temp_adata.n_obs))
        nets_2D.append(net)
    adata.uns['Spatial_Net_2D'] = pd.concat(nets_2D) if nets_2D else pd.DataFrame()

    nets_Z = []
    for it in range(num_section-1):
        section_1 = section_order[it]
        section_2 = section_order[it+1]
        if verbose:
            print('------Calculating SNN between adjacent section %s and %s.' %
                  (section_1, section_2))
        # Formatting (rather than ``+``) also works for non-string section ids.
        Z_Net_ID = '%s-%s' % (section_1, section_2)
        temp_adata = adata[adata.obs[key_section].isin(
            [section_1, section_2]), ].copy()
        Cal_Spatial_Net(
            temp_adata, rad_cutoff=rad_cutoff_Zaxis, model='Radius', verbose=False)
        spot_section_trans = dict(
            zip(temp_adata.obs.index, temp_adata.obs[key_section]))
        net = temp_adata.uns['Spatial_Net']
        # Keep only edges whose endpoints lie in different sections.
        used_edge = (net['Cell1'].map(spot_section_trans).to_numpy()
                     != net['Cell2'].map(spot_section_trans).to_numpy())
        net = net.loc[used_edge, ['Cell1', 'Cell2', 'Distance']].copy()
        net['SNN'] = Z_Net_ID
        if verbose:
            print('This graph contains %d edges, %d cells.' %
                  (net.shape[0], temp_adata.n_obs))
            print('%.4f neighbors per cell on average.' %
                  (net.shape[0]/temp_adata.n_obs))
        nets_Z.append(net)
    adata.uns['Spatial_Net_Zaxis'] = pd.concat(nets_Z) if nets_Z else pd.DataFrame()

    adata.uns['Spatial_Net'] = pd.concat(
        [adata.uns['Spatial_Net_2D'], adata.uns['Spatial_Net_Zaxis']])
    if verbose:
        print('3D SNN contains %d edges, %d cells.' %
              (adata.uns['Spatial_Net'].shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' %
              (adata.uns['Spatial_Net'].shape[0]/adata.n_obs))
285
+
286
+
287
def Cal_Expression_3D(adata, k_cutoff=6, key_section='Section_id', verbose=True):
    """Build per-section expression networks and stack them in ``adata.uns['Exp_Net']``.

    Each section (unique value of ``adata.obs[key_section]``) is copied,
    gene-filtered (min_cells=5), normalised, scaled and reduced to 100
    principal components before ``Cal_Expression_Net`` is run on it. The
    resulting edge tables get an 'SNN' column holding the section id and are
    concatenated in section order.

    Parameters
    ----------
    adata
        AnnData object of scanpy package.
    k_cutoff
        Forwarded to ``Cal_Expression_Net``.
    key_section
        Column of ``adata.obs`` holding the section id.
    verbose
        If true, print one progress line per section.
    """
    adata.uns['Exp_Net'] = pd.DataFrame()
    section_ids = np.unique(adata.obs[key_section])
    for section_id in section_ids:
        if verbose:
            print('------Calculating Expression Network of section ', section_id)
        in_section = adata.obs[key_section] == section_id
        section_adata = adata[in_section, ].copy()

        # Per-section preprocessing ahead of the neighbour search.
        sc.pp.filter_genes(section_adata, min_cells=5)
        sc.pp.normalize_total(section_adata, target_sum=1, exclude_highly_expressed=True)
        sc.pp.scale(section_adata)
        sc.pp.pca(section_adata, n_comps=100)

        Cal_Expression_Net(section_adata, k_cutoff=k_cutoff)
        section_net = section_adata.uns['Exp_Net']
        section_net['SNN'] = section_id
        adata.uns['Exp_Net'] = pd.concat([adata.uns['Exp_Net'], section_net])
303
+
304
+
305
def Stats_Spatial_Net(adata):
    """Plot how many outgoing edges spots have in ``adata.uns['Spatial_Net']``.

    Draws a bar chart whose x axis is the number of edges a spot has as
    'Cell1' and whose bar height is the number of such spots divided by
    ``adata.shape[0]``. The title shows the mean number of edges per spot.
    """
    import matplotlib.pyplot as plt
    source_cells = adata.uns['Spatial_Net']['Cell1']
    Num_edge = source_cells.shape[0]
    Mean_edge = Num_edge/adata.shape[0]
    # Series.value_counts replaces the deprecated top-level pd.value_counts:
    # first edges per spot, then how many spots share each edge count.
    plot_df = source_cells.value_counts().value_counts()
    # NOTE: the values are fractions of all spots even though the axis label
    # reads 'Percentage'.
    plot_df = plot_df/adata.shape[0]
    fig, ax = plt.subplots(figsize=[3,2])
    ax.set_ylabel('Percentage')
    ax.set_xlabel('')
    ax.set_title('Number of Neighbors (Mean=%.2f)'%Mean_edge)
    ax.bar(plot_df.index, plot_df)
@@ -0,0 +1,49 @@
1
+ Metadata-Version: 2.1
2
+ Name: RGAST
3
+ Version: 0.0.1
4
+ Summary: Relational Graph Attention Network for Spatial Transcriptome Analysis
5
+ Home-page: https://github.com/GYQ-form/RGAST
6
+ Author: Yuqiao Gong
7
+ Author-email: gyq123@sjtu.edu.cn
8
+ License: MIT
9
+ Keywords: spatial transcriptomic,RGAT,representation learning,spatial domain identification
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+
16
+ # RGAST
17
+
18
+ RGAST: Relational Graph Attention Network for Spatial Transcriptome Analysis
19
+
20
+ This document will help you easily go through the RGAST model.
21
+
22
+ ![fig1_00](https://github.com/GYQ-form/RGAST/assets/79566479/fe0655dc-2318-44e0-92bf-0aea3aad7163)
23
+
24
+
25
+ ## Installation
26
+
27
+ To install our package, run
28
+
29
+ ```bash
30
+ pip install RGAST
31
+ ```
32
+
33
+
34
+
35
+ ## Usage
36
+
37
+ RGAST (Relational Graph Attention network for Spatial Transcriptome analysis) constructs a relational graph attention network to learn the representation of each spot in the spatial transcriptome data. In addition to the attention mechanism, RGAST considers both gene expression similarity and spatial neighbor relationships in constructing the graph network, enabling a more comprehensive and flexible representation of the spatial transcriptome data. RGAST can be used in many ST analyses:
38
+
39
+ - spatial domain identification
40
+ - cell trajectory inference
41
+ - spatially variable gene (SVG) detection
42
+ - uncover spatially resolved cell-cell interactions
43
+ - reveal intricate 3D spatial patterns across multiple sections of ST data
44
+
45
+
46
+
47
+ ## Tutorial
48
+
49
+ We have prepared several basic tutorials at https://github.com/GYQ-form/RGAST/tree/main/tutorial. You can quickly get hands-on with RGAST by going through these tutorials. Model parameters trained in our study are also released at https://github.com/GYQ-form/RGAST/tree/main/model_path.
@@ -0,0 +1,13 @@
1
+ LICENSE
2
+ README.md
3
+ setup.py
4
+ RGAST/RGAST.py
5
+ RGAST/Train_RGAST.py
6
+ RGAST/__init__.py
7
+ RGAST/svg.py
8
+ RGAST/utils.py
9
+ RGAST.egg-info/PKG-INFO
10
+ RGAST.egg-info/SOURCES.txt
11
+ RGAST.egg-info/dependency_links.txt
12
+ RGAST.egg-info/requires.txt
13
+ RGAST.egg-info/top_level.txt
@@ -0,0 +1,6 @@
1
+ torch
2
+ scanpy
3
+ sklearn
4
+ torch_geometric
5
+ scipy
6
+ numba
@@ -0,0 +1 @@
1
+ RGAST
RGAST-0.0.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
RGAST-0.0.1/setup.py ADDED
@@ -0,0 +1,34 @@
1
"""Packaging script for RGAST."""
from setuptools import setup, find_packages

__version__ = "0.0.1"

# The README doubles as the long description shown on the package index.
with open("README.md", "r", encoding='utf-8') as fh:
    long_description = fh.read()

setup(
    name="RGAST",
    version=__version__,
    packages=find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    include_package_data=True,
    install_requires=[
        'torch',
        'scanpy',
        # The distribution is named 'scikit-learn'; the 'sklearn' name on
        # PyPI is a deprecated placeholder.
        'scikit-learn',
        'torch_geometric',
        'scipy',
        'numba',
        # Imported directly by the package modules.
        'numpy',
        'pandas',
        'seaborn',
        'matplotlib',
    ],
    author="Yuqiao Gong",
    author_email="gyq123@sjtu.edu.cn",
    keywords=["spatial transcriptomic", "RGAT", "representation learning", "spatial domain identification"],
    description="Relational Graph Attention Network for Spatial Transcriptome Analysis",
    license="MIT",
    url='https://github.com/GYQ-form/RGAST',
    long_description_content_type='text/markdown',
    long_description=long_description
)