spatialformer 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spatialformer/GraphSAGE.py +765 -0
- spatialformer/__init__.py +23 -0
- spatialformer/data_loader.py +499 -0
- spatialformer/graphsage.py +362 -0
- spatialformer/processor.py +107 -0
- spatialformer/tools/__init__.py +2 -0
- spatialformer/tools/get_embeddings.py +404 -0
- spatialformer/train.py +234 -0
- spatialformer/utils.py +947 -0
- spatialformer-0.0.6.dist-info/LICENSE +21 -0
- spatialformer-0.0.6.dist-info/METADATA +39 -0
- spatialformer-0.0.6.dist-info/RECORD +14 -0
- spatialformer-0.0.6.dist-info/WHEEL +5 -0
- spatialformer-0.0.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,765 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.optim as optim
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
# import pdb; pdb.set_trace()
|
|
8
|
+
from torch_geometric.loader import NeighborLoader, DataLoader
|
|
9
|
+
from torch_geometric.nn import SAGEConv
|
|
10
|
+
from torch_geometric.data import Data
|
|
11
|
+
from torch_geometric.utils import k_hop_subgraph, negative_sampling, add_self_loops
|
|
12
|
+
from sklearn.preprocessing import OneHotEncoder
|
|
13
|
+
from torch_geometric.utils import to_networkx, from_networkx
|
|
14
|
+
from scipy.spatial import KDTree
|
|
15
|
+
from torch_geometric.nn import global_mean_pool
|
|
16
|
+
from torchmetrics.classification import BinaryAccuracy
|
|
17
|
+
from torchmetrics.classification import MulticlassAccuracy
|
|
18
|
+
from pytorch_lightning.loggers import WandbLogger, CSVLogger
|
|
19
|
+
import networkx as nx
|
|
20
|
+
import logging
|
|
21
|
+
from tqdm import tqdm
|
|
22
|
+
import json
|
|
23
|
+
import random
|
|
24
|
+
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
|
25
|
+
import argparse
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from datetime import datetime
|
|
28
|
+
import pytorch_lightning as pl
|
|
29
|
+
import pickle
|
|
30
|
+
import os
|
|
31
|
+
# Run-level context captured at import time. None of these names is read by
# the code in this module; they are kept for any caller that imports them.
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")  # e.g. "20240131_235959"

path = Path.cwd()
parent_dir = path.parent

model_path = str(parent_dir / "output" / "GraphSAGE_model")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Step 1: Load and preprocess data
|
|
43
|
+
def load_and_preprocess_data(filepath, max_cells=20000, random_state=42):
    """Load a transcript table and keep the rows of a random subset of cells.

    Reads ``filepath`` as CSV, falling back to ``f"{filepath}.gz"`` when the
    plain file does not exist. If a ``qv`` column is present, only rows with
    ``qv >= 20`` are kept. Rows whose ``feature_name`` starts with ``Neg``,
    ``BLANK`` or ``Unassigned`` (or is missing) are removed. If at least
    ``max_cells`` distinct ``cell_id`` values remain, ``max_cells`` of them are
    sampled with ``random_state``; otherwise every remaining cell is kept.

    Args:
        filepath: Path (str or ``os.PathLike``) to ``transcripts.csv``.
        max_cells: Number of distinct cells to sample (default 20000);
            ``None`` disables sampling.
        random_state: Seed passed to ``pandas.Series.sample``.

    Returns:
        pandas.DataFrame with the filtered rows of the selected cells.

    Raises:
        FileNotFoundError: if neither ``filepath`` nor its ``.gz`` sibling exists.
    """
    print(f"loading {filepath}")
    try:
        dataset = pd.read_csv(filepath)
    except FileNotFoundError:
        # Only a missing file should trigger the gzip fallback; parse or
        # permission errors must surface instead of being silently retried.
        dataset = pd.read_csv(f"{filepath}.gz")
    print("after loading the data")

    if "qv" in dataset.columns:
        dataset = dataset[dataset["qv"] >= 20]

    # na=False keeps the mask boolean when feature_name has missing values
    # (a NaN in the mask would make ``~`` fail).
    feature_names = dataset["feature_name"]
    is_control = (
        feature_names.str.startswith("Neg", na=False)
        | feature_names.str.startswith("BLANK", na=False)
        | feature_names.str.startswith("Unassigned", na=False)
    )
    dataset = dataset[~is_control]

    # Sample whole cells (not rows) so every kept cell retains all its transcripts.
    cell_ids = pd.Series(dataset["cell_id"].unique())
    if max_cells is not None and len(cell_ids) >= max_cells:
        cell_ids = cell_ids.sample(n=max_cells, random_state=random_state)
    return dataset[dataset["cell_id"].isin(cell_ids)]
|
|
67
|
+
|
|
68
|
+
def index2gene(filepath, num_special_tokens=27):
    """Return the gene names of a token vocabulary file.

    The file is a JSON object whose keys are token strings. The first
    ``num_special_tokens`` keys (in file order; 27 by default) are dropped,
    and the remaining keys are returned in file order.
    NOTE(review): the leading 27 entries are presumably special / non-gene
    tokens — confirm against the tokenizer definition.

    Args:
        filepath: Path to the vocabulary JSON file.
        num_special_tokens: Number of leading keys to skip.

    Returns:
        list[str] of gene names.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    return list(vocab)[num_special_tokens:]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def build_graph_for_sample(data, threshold=3.0, batch_size=100, sample_id=None,
                           num_root_nodes=5000, num_hops=3, vocab=None, out_dir=None):
    """Build a transcript proximity graph for one sample, sample a k-hop subgraph and save it.

    Every row of ``data`` becomes a node whose feature is the one-hot encoding
    of its ``feature_name`` over ``vocab``. Two nodes are connected when their
    ``(x_location, y_location, z_location)`` points are within ``threshold``
    of each other (``KDTree.query_ball_point``; same units as the
    coordinates). Each connection is stored in both directions, which is how
    PyG represents an undirected graph. Up to ``num_root_nodes`` random nodes
    are expanded by ``num_hops`` hops with ``k_hop_subgraph``. The relabelled
    subgraph is saved to ``{out_dir}/subgraph_data_{sample_id}.pt``.

    Args:
        data: DataFrame with ``x_location``, ``y_location``, ``z_location``
            and ``feature_name`` columns.
        threshold: Neighbourhood radius.
        batch_size: Number of points queried against the KD-tree at once.
        sample_id: Used in the output file name.
        num_root_nodes: Maximum number of random seed nodes (default 5000).
        num_hops: Hops expanded around the seeds (default 3).
        vocab: One-hot categories; defaults to the module-level ``vocab_list``
            set by the ``__main__`` block.
        out_dir: Output directory; defaults to the module-level ``data_dir``
            set by the ``__main__`` block.

    Returns:
        The saved ``Data(x, edge_index)`` subgraph.

    Raises:
        ValueError: from ``OneHotEncoder`` if a ``feature_name`` is not in ``vocab``.
    """
    if vocab is None:
        vocab = vocab_list
    if out_dir is None:
        out_dir = data_dir

    coords = np.asarray(data[["x_location", "y_location", "z_location"]])
    gene_labels = data["feature_name"].values

    # Fixed categories keep the one-hot columns aligned with the vocabulary,
    # so graphs from different samples share one feature space.
    encoder = OneHotEncoder(categories=[vocab], sparse_output=False)
    one_hot_labels = encoder.fit_transform(gene_labels.reshape(-1, 1))
    x = torch.tensor(one_hot_labels, dtype=torch.float)

    num_nodes = len(coords)
    kdtree = KDTree(coords)

    print("building graph in batch")
    edges = []
    for start in tqdm(range(0, num_nodes, batch_size)):
        stop = min(start + batch_size, num_nodes)
        # One vectorised query per batch; keeping only higher-indexed
        # neighbours records each pair once and drops self-matches.
        neighbour_lists = kdtree.query_ball_point(coords[start:stop], threshold)
        for i, neighbours in enumerate(neighbour_lists, start=start):
            edges.extend((i, j) for j in neighbours if i < j)

    if edges:
        one_way = torch.tensor(edges, dtype=torch.long).t()
        # Add reverse edges: with i<j edges only, messages and k-hop
        # expansion would flow from lower- to higher-indexed nodes only.
        edge_index = torch.cat([one_way, one_way.flip(0)], dim=1).contiguous()
    else:
        # A well-formed (2, 0) tensor instead of a 1-D empty tensor.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    root_nodes = torch.tensor(
        random.sample(range(num_nodes), min(num_root_nodes, num_nodes)),
        dtype=torch.long,
    )

    print("creating the subgraph...")
    subgraph_nodes, subgraph_edge_index, _, _ = k_hop_subgraph(
        node_idx=root_nodes,
        num_hops=num_hops,
        edge_index=edge_index,
        relabel_nodes=True,
        # Explicit node count: inferring it from edge_index would be too small
        # when the highest-indexed nodes are isolated roots.
        num_nodes=num_nodes,
    )
    subgraph = Data(x=x[subgraph_nodes], edge_index=subgraph_edge_index)

    print("saving the graph")
    torch.save(subgraph, os.path.join(out_dir, f"subgraph_data_{sample_id}.pt"))
    return subgraph
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# Step 2: Create subgraphs
|
|
139
|
+
def create_subgraph(data, num_root_nodes=5000, num_neighbors=(20, 10, 10)):
    """Sample a k-hop subgraph around random root nodes of ``data``.

    Up to ``num_root_nodes`` nodes are drawn with ``random.sample`` and expanded
    by ``len(num_neighbors)`` hops using ``k_hop_subgraph``. Only the length of
    ``num_neighbors`` is used; ``k_hop_subgraph`` keeps every neighbour, so the
    per-hop counts are not applied. Small connected components are not removed
    here; ``filter_component`` does that.

    Args:
        data: Graph with ``x`` (node features) and ``edge_index``.
        num_root_nodes: Maximum number of seed nodes.
        num_neighbors: Sequence whose length sets the number of hops.

    Returns:
        ``Data(x, edge_index)`` of the relabelled subgraph.
    """
    # The previous body also filtered components of an undefined ``G`` (an
    # UnboundLocalError on every call) and discarded the result; it is removed.
    num_nodes = data.x.size(0)
    root_nodes = torch.tensor(
        random.sample(range(num_nodes), min(num_root_nodes, num_nodes)),
        dtype=torch.long,
    )

    subgraph_nodes, subgraph_edge_index, _, _ = k_hop_subgraph(
        node_idx=root_nodes,
        num_hops=len(num_neighbors),
        edge_index=data.edge_index,
        relabel_nodes=True,
        num_nodes=num_nodes,  # isolated high-index roots must stay in range
    )
    return Data(x=data.x[subgraph_nodes], edge_index=subgraph_edge_index)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# Step 4: Define and train the 2-hop GraphSAGE Model
|
|
160
|
+
class TwoHopGraphSAGE(nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU in between.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the intermediate representation.
        out_channels: Size of each output node embedding.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return node embeddings after two rounds of neighbourhood aggregation."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
|
|
171
|
+
def filter_component(data, min_component_size=10):
    """Drop connected components smaller than ``min_component_size`` nodes.

    The graph is converted to an undirected networkx graph (keeping node
    attribute ``x``). Only the nodes of components with at least
    ``min_component_size`` nodes are kept, and the result is converted back
    with ``from_networkx`` (nodes relabelled to ``0..n-1``).

    Args:
        data: Graph with ``x`` and ``edge_index``.
        min_component_size: Smallest component size to keep (default 10).

    Returns:
        Filtered ``Data``. If no component is large enough, the graph has zero
        nodes, ``x`` keeps the original feature width, and ``edge_index`` has
        shape ``(2, 0)``.
    """
    print("subgraph to networkx...")
    sub_G = to_networkx(data, to_undirected=True, node_attrs=['x'])
    print("filtering by components...")
    # set().union(...) also works when no component qualifies, where
    # set.union(*[]) raised TypeError.
    kept_nodes = set().union(
        *(c for c in nx.connected_components(sub_G) if len(c) >= min_component_size)
    )
    if not kept_nodes:
        return Data(
            x=data.x.new_empty((0, *data.x.shape[1:])),
            edge_index=torch.empty((2, 0), dtype=torch.long),
        )

    sub_G_f = sub_G.subgraph(kept_nodes)
    return from_networkx(sub_G_f, group_node_attrs=['x'])
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# Define your LightningModule
|
|
185
|
+
class GraphSAGEModel(pl.LightningModule):
    """Two-layer GraphSAGE encoder trained with a link-prediction objective.

    In every mini-batch the existing edges are positives. The same number of
    node pairs drawn by ``negative_sampling`` are negatives. An edge is scored
    by the dot product of its two node embeddings, and the scores are trained
    with ``binary_cross_entropy_with_logits``.

    Args:
        input_dim: Input node feature width.
        hidden_dim: Width after the first ``SAGEConv``.
        output_dim: Embedding width.
        lr: AdamW learning rate.
        train_dataset: Graph sampled by ``NeighborLoader`` in ``train_dataloader``.
        batch_size: Number of seed nodes per mini-batch.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, lr, train_dataset, batch_size):
        super().__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, output_dim)
        self.accuracy = BinaryAccuracy()
        self.lr = lr
        self.train_dataset = train_dataset
        self.batch_size = batch_size

    def forward(self, x, edge_index):
        """Embed every node of the (sub)graph."""
        return self.conv2(F.relu(self.conv1(x, edge_index)), edge_index)

    @staticmethod
    def _edge_scores(z, edge_index):
        """Dot-product score of each (source, target) column of ``edge_index``."""
        src, dst = edge_index
        return (z[src] * z[dst]).sum(dim=-1)

    def training_step(self, batch, batch_idx):
        n_nodes = batch.x.size(0)
        z = self(batch.x, batch.edge_index)

        positives = batch.edge_index
        negatives = negative_sampling(
            edge_index=positives,
            num_nodes=n_nodes,
            num_neg_samples=positives.size(1),
        )

        pos_scores = self._edge_scores(z, positives)
        neg_scores = self._edge_scores(z, negatives)
        logits = torch.cat([pos_scores, neg_scores])
        targets = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])

        loss = F.binary_cross_entropy_with_logits(logits, targets)
        acc = self.accuracy(torch.sigmoid(logits) > 0.5, targets.int())

        for metric_name, metric_value in (("train_loss", loss), ("train_acc", acc)):
            self.log(metric_name, metric_value, sync_dist=True, reduce_fx='mean',
                     prog_bar=True, batch_size=n_nodes)
        return loss

    def configure_optimizers(self):
        """AdamW with the configured learning rate and weight decay 0.1."""
        return optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.1)

    def train_dataloader(self):
        """Shuffled seed-node mini-batches with 20 first-hop / 10 second-hop neighbours."""
        return NeighborLoader(
            data=self.train_dataset,
            num_neighbors=[20, 10],
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,  # tune to the available CPU cores
        )
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class PatternGraphSAGEModel(pl.LightningModule):
    """Graph-level "random" vs "non-random" pattern classifier on a GraphSAGE encoder.

    Node embeddings from ``model`` go through a ReLU. They are mean-pooled per
    graph (``global_mean_pool``) and mapped to class logits by ``fc``.

    Args:
        model: Node encoder called as ``model(x, edge_index)``; its output width
            must equal ``output_dim``.
        output_dim: Width of the encoder output (input of the classifier).
        nn_hidden_layer_dim: Width of ``hidden_layer``, which ``forward`` does not
            currently apply.
        lr: AdamW learning rate.
    """

    def __init__(self, model, output_dim, nn_hidden_layer_dim, lr):
        super().__init__()
        self.model = model
        # Binary task: "random" -> 0, every other pattern name -> 1.
        self.label_str = ["random", "non-random"]
        # The classifier reads the pooled encoder output directly, so its input
        # width is output_dim (it used to be hard-coded to 512).
        self.fc = nn.Linear(output_dim, len(self.label_str))
        # NOTE(review): not applied in forward(); kept so existing checkpoints,
        # which contain its weights, still load with strict key matching.
        self.hidden_layer = nn.Linear(output_dim, nn_hidden_layer_dim)
        self.accuracy = MulticlassAccuracy(num_classes=len(self.label_str))
        self.lr = lr

    def _binary_labels(self, patterns, device=None):
        """Map pattern names to class ids: ``"random"`` -> 0, anything else -> 1."""
        return torch.tensor(
            [0 if pattern == "random" else 1 for pattern in patterns],
            dtype=torch.long,
            device=device,
        )

    def forward(self, x, edge_index, batch_indice):
        """Return raw per-graph class logits (for ``F.cross_entropy``)."""
        x = F.relu(self.model(x, edge_index))
        x = global_mean_pool(x, batch_indice)
        return self.fc(x)

    def _shared_step(self, batch, stage):
        """Loss/accuracy for a ``(graph_batch, pattern_names)`` batch, logged as ``{stage}_*``."""
        graphs, patterns = batch[0], batch[1]
        preds = self(graphs.x, graphs.edge_index, graphs.batch)
        labels = self._binary_labels(patterns, preds.device)

        loss = F.cross_entropy(preds, labels)
        acc = self.accuracy(preds.argmax(dim=1), labels)

        # Explicit batch_size: Lightning cannot infer it from (Batch, list[str]).
        n_graphs = labels.size(0)
        self.log(f"{stage}_loss", loss, sync_dist=True, prog_bar=True, batch_size=n_graphs)
        self.log(f"{stage}_acc", acc, sync_dist=True, prog_bar=True, batch_size=n_graphs)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """AdamW with the configured learning rate and weight decay 0.1."""
        return optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.1)

    def evaluation(self, ckp_path, test_dataloader):
        """Load a Lightning checkpoint and compute test metrics.

        Args:
            ckp_path: Checkpoint written by Lightning for this module.
            test_dataloader: Yields ``(graph_batch, pattern_names)`` pairs.

        Returns:
            ``(accuracy, f1, mcc, average_auc)`` as Python floats. Accuracy and F1
            use torchmetrics' default (macro) averaging. The AUC is the mean of
            the one-vs-rest AUROCs.
        """
        from torchmetrics.functional.classification import (
            binary_auroc,
            multiclass_accuracy,
            multiclass_f1_score,
            multiclass_matthews_corrcoef,
        )

        ckp = torch.load(ckp_path, map_location="cpu")
        # The checkpoint holds the whole module ("model.*", "fc.*", ...), so
        # load it into self rather than into the wrapped encoder.
        self.load_state_dict(ckp["state_dict"])
        self.eval()

        num_classes = len(self.label_str)
        prob_chunks, label_chunks = [], []
        with torch.no_grad():
            for batch in test_dataloader:
                graphs = batch[0].to(self.device)
                logits = self(graphs.x, graphs.edge_index, graphs.batch)
                prob_chunks.append(logits.softmax(dim=1).float().cpu())
                # Same mapping as training, so non-"random" names do not raise.
                label_chunks.append(self._binary_labels(batch[1]))
        probs = torch.cat(prob_chunks)    # [num_graphs, num_classes]
        labels = torch.cat(label_chunks)  # [num_graphs]

        acc = multiclass_accuracy(probs.argmax(dim=1), labels, num_classes=num_classes)
        f1 = multiclass_f1_score(probs, labels, num_classes=num_classes)
        mcc = multiclass_matthews_corrcoef(probs, labels, num_classes=num_classes)
        aucs = [binary_auroc(probs[:, c], (labels == c).long()) for c in range(num_classes)]
        average_auc = torch.stack(aucs).mean()

        # self.log() is intentionally not used: outside an active Lightning
        # loop it only warns (no trainer) or fails (trainer attached).
        print(f"Test Accuracy: {acc.item()}")
        print(f"Test F1 Score: {f1.item()}")
        print(f"Test MCC: {mcc.item()}")
        print(f"Average Test AUC: {average_auc.item()}")
        return acc.item(), f1.item(), mcc.item(), average_auc.item()
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
class MyTrainer:
    """Train ``GraphSAGEModel`` on one joint graph and export per-gene embeddings.

    ``config`` keys read here: ``hidden_dim``, ``output_dim``, ``lr``,
    ``batch_size``, ``total_step``, ``strategy``. ``resume_train`` also reads
    ``accumulate_grad_batches``.

    Args:
        config: Training configuration mapping.
        input_dim: Node feature width (one-hot vocabulary size).
        train_dataset: Joint ``Data`` graph used for training and embedding.
        output_dir: Root for logs/checkpoints; defaults to ``DEFAULT_OUTPUT_DIR``.
    """

    # Location previously hard-coded in __init__; used when output_dir is None.
    DEFAULT_OUTPUT_DIR = "/home/sxr280/Spatialformer/output/GraphSAGE_model"

    def __init__(self, config, input_dim, train_dataset, output_dir=None):
        self.config = config
        self.plmodel = GraphSAGEModel(
            input_dim=input_dim,
            hidden_dim=config["hidden_dim"],
            output_dim=config["output_dim"],
            lr=config["lr"],
            train_dataset=train_dataset,
            batch_size=config["batch_size"],
        )
        self.output_dir = self.DEFAULT_OUTPUT_DIR if output_dir is None else output_dir
        self.train_dataset = train_dataset
        self.gpus = torch.cuda.device_count()
        self.trainer = None

    def make_callback(self):
        """Checkpoint every 10000 training steps (all kept) and log the learning rate."""
        return [
            ModelCheckpoint(
                dirpath=os.path.join(self.output_dir, "GraphSAGE_model", "checkpoints"),
                # NOTE: GraphSAGEModel never logs val_loss.
                filename=f"{{step:07d}}-{{train_loss:.4f}}-{{val_loss:.4f}}-{{train_acc:.4f}}",
                every_n_train_steps=10000,
                save_top_k=-1,
                monitor='train_loss',
                save_on_train_epoch_end=False,
            ),
            LearningRateMonitor(logging_interval="step"),
        ]

    def _make_logger(self):
        """Weights & Biases logger shared by ``set_trainer`` and ``resume_train``."""
        return WandbLogger(project="Spaformer",
                           name="GraphSAGE",
                           log_model="all",
                           save_dir=self.output_dir)

    def set_trainer(self):
        """Create ``self.logger`` and a fresh ``pl.Trainer`` in ``self.trainer``."""
        self.logger = self._make_logger()
        self.trainer = pl.Trainer(
            accelerator="auto",
            devices=self.gpus,
            max_steps=self.config["total_step"],
            val_check_interval=0.1,
            default_root_dir=self.output_dir,
            callbacks=self.make_callback(),
            log_every_n_steps=50,
            logger=self.logger,
            precision='bf16',
            strategy=self.config['strategy'],
            num_nodes=1,
        )

    def resume_train(self, ckp, train_loader, val_loader):
        """Resume fitting from checkpoint ``ckp`` on ``train_loader``.

        ``val_loader`` is accepted for interface compatibility but not used,
        because ``GraphSAGEModel`` defines no ``validation_step``.
        """
        self.logger = self._make_logger()
        logging.info("resuming the training ...")
        self.trainer = pl.Trainer(
            accelerator="auto",
            devices=self.gpus,
            strategy=self.config['strategy'],
            num_nodes=1,
            val_check_interval=0.1,
            gradient_clip_val=1,
            logger=self.logger,
            default_root_dir=self.output_dir,
            log_every_n_steps=50,
            check_val_every_n_epoch=1,
            precision='bf16',
            callbacks=self.make_callback(),
            max_steps=self.config["total_step"],
            accumulate_grad_batches=self.config['accumulate_grad_batches'],
        )
        # Trainer(resume_from_checkpoint=...) was removed in Lightning 2.0;
        # fit(ckpt_path=...) (available since 1.5) restores the full state.
        self.trainer.fit(self.plmodel, train_loader, ckpt_path=ckp)

    def train(self):
        """Create the trainer and fit using ``GraphSAGEModel.train_dataloader``."""
        self.set_trainer()
        self.trainer.fit(self.plmodel)

    def get_embedding(self, ckp_path, batch_size, token_path, output_dim):
        """Average node embeddings per gene into a token-indexed embedding table.

        Steps:
        1. Load ``ckp_path`` into ``self.plmodel``.
        2. Embed every node of ``self.train_dataset`` in one pass.
        3. Group nodes by the argmax of their one-hot feature, which indexes the
           module-level ``vocab_list``.
        4. Write each gene's mean embedding to row ``token_config[gene]`` of a
           ``(max_token_id + 1, output_dim)`` tensor.

        Rows of tokens whose gene never occurs keep their ``torch.rand`` value,
        and row 0 (padding) is zeroed. ``batch_size`` is unused and is kept for
        interface compatibility.
        """
        model = self.plmodel
        ckp = torch.load(ckp_path, map_location="cpu")
        model.load_state_dict(ckp["state_dict"])
        model.eval()

        with open(token_path, 'r') as json_file:
            token_config = json.load(json_file)
        token_num = max(token_config.values()) + 1
        pretrained_embeddings = torch.rand(token_num, output_dim)

        with torch.no_grad():
            x = self.train_dataset.x
            edge_index = self.train_dataset.edge_index
            gene_ids = torch.argmax(x, dim=1).cpu()
            embeddings = model(x.to(model.device), edge_index.to(model.device)).float().cpu()

            # Vectorised per-gene mean instead of a Python loop over every node.
            num_genes = len(vocab_list)
            sums = torch.zeros(num_genes, embeddings.size(1)).index_add_(0, gene_ids, embeddings)
            counts = torch.bincount(gene_ids, minlength=num_genes)
            for gene_id in torch.nonzero(counts).flatten().tolist():
                gene = vocab_list[gene_id]
                pretrained_embeddings[token_config[gene]] = sums[gene_id] / counts[gene_id]

        # Setting the padding token embedding to 0.
        pretrained_embeddings[0] = 0
        return pretrained_embeddings
|
|
506
|
+
|
|
507
|
+
class PatternTrainer:
    """Train, resume and evaluate a ``PatternGraphSAGEModel`` pattern classifier.

    Args:
        model: GraphSAGE node encoder wrapped by ``PatternGraphSAGEModel``.
        lr: AdamW learning rate.
        strategy: Lightning strategy passed to ``pl.Trainer``.
        output_dim: Width of the encoder output (classifier input).
        output_dir: Root for checkpoints and the Lightning default directory.
        train_dataloader, val_dataloader, test_dataloader: Loaders that yield
            ``(graph_batch, pattern_names)`` pairs.
    """

    def __init__(self, model, lr, strategy, output_dim, output_dir, train_dataloader, val_dataloader, test_dataloader):
        self.plmodel = PatternGraphSAGEModel(
            model,
            output_dim=output_dim,
            nn_hidden_layer_dim=8,
            lr=lr,
        )
        self.output_dir = output_dir
        self.gpus = torch.cuda.device_count()
        self.trainer = None
        self.strategy = strategy
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

    def make_callback(self):
        """Checkpoint every 10000 training steps (all kept) and log the learning rate."""
        return [
            ModelCheckpoint(
                dirpath=os.path.join(self.output_dir, "GraphSAGE_model", "checkpoints"),
                filename=f"{{step:07d}}-{{train_loss:.4f}}-{{val_loss:.4f}}-{{train_acc:.4f}}",
                every_n_train_steps=10000,
                save_top_k=-1,
                monitor='train_loss',
                save_on_train_epoch_end=False,
            ),
            LearningRateMonitor(logging_interval="step"),
        ]

    def set_trainer(self):
        """Create ``self.logger`` (CSV) and a fresh 10000-step ``pl.Trainer``."""
        # NOTE(review): the CSV log directory is a hard-coded cluster path,
        # independent of self.output_dir.
        self.logger = CSVLogger("/scratch/project_465001027/Spatialformer/output/GraphSAGE_model", name="my_experiment")
        self.trainer = pl.Trainer(
            accelerator="auto",
            devices=self.gpus,
            max_steps=10000,
            val_check_interval=0.1,
            default_root_dir=self.output_dir,
            callbacks=self.make_callback(),
            log_every_n_steps=50,
            logger=self.logger,
            precision='bf16',
            strategy=self.strategy,
            num_nodes=1,
        )

    def resume_train(self, ckp, train_loader, val_loader, max_steps=10000, accumulate_grad_batches=1):
        """Resume fitting from checkpoint ``ckp``.

        PatternTrainer has no ``config``, which the previous body read. The step
        budget and gradient accumulation are therefore parameters here;
        ``max_steps`` defaults to the value used by ``set_trainer``.
        """
        self.logger = WandbLogger(project="Spaformer",
                                  name="GraphSAGE",
                                  log_model="all",
                                  save_dir=self.output_dir)
        logging.info("resuming the training ...")
        self.trainer = pl.Trainer(
            accelerator="auto",
            devices=self.gpus,
            strategy=self.strategy,
            num_nodes=1,
            val_check_interval=0.1,
            gradient_clip_val=1,
            logger=self.logger,
            default_root_dir=self.output_dir,
            log_every_n_steps=50,
            check_val_every_n_epoch=1,
            precision='bf16',
            callbacks=self.make_callback(),
            max_steps=max_steps,
            accumulate_grad_batches=accumulate_grad_batches,
        )
        # fit(ckpt_path=...) replaces the Trainer(resume_from_checkpoint=...)
        # argument removed in Lightning 2.0. The model has a validation_step,
        # so val_loader is passed as in train().
        self.trainer.fit(self.plmodel, train_loader, val_loader, ckpt_path=ckp)

    def train(self):
        """Create the trainer and fit on the stored train/validation loaders."""
        self.set_trainer()
        self.trainer.fit(self.plmodel, self.train_dataloader, self.val_dataloader)

    def evaluation(self, ckp_path):
        """Evaluate ``ckp_path`` on the test loader; returns ``(acc, f1, mcc, average_auc)``."""
        acc, f1, mcc, average_auc = self.plmodel.evaluation(ckp_path, self.test_dataloader)
        return acc, f1, mcc, average_auc

    def get_embedding(self, ckp_path, batch_size, token_path, output_dim, train_dataset=None):
        """Average per-gene node embeddings from the wrapped encoder into a token table.

        The graph comes from ``train_dataset`` or, if omitted, a
        ``self.train_dataset`` attribute. PatternTrainer does not set that
        attribute, so the previous body always raised AttributeError.
        Embeddings come from the node encoder ``self.plmodel.model``, not from
        the classifier, which returns pooled logits. Nodes are grouped by the
        argmax of their one-hot feature (an index into the module-level
        ``vocab_list``). Each gene's mean embedding goes to row
        ``token_config[gene]``; rows of unseen genes keep their ``torch.rand``
        value and row 0 (padding) is zeroed. ``batch_size`` is unused.

        Raises:
            ValueError: If no graph is available.
        """
        dataset = train_dataset if train_dataset is not None else getattr(self, "train_dataset", None)
        if dataset is None:
            raise ValueError("get_embedding needs a graph: pass train_dataset=<torch_geometric Data>")

        ckp = torch.load(ckp_path, map_location="cpu")
        self.plmodel.load_state_dict(ckp["state_dict"])
        self.plmodel.eval()
        encoder = self.plmodel.model
        device = self.plmodel.device

        with open(token_path, 'r') as json_file:
            token_config = json.load(json_file)
        token_num = max(token_config.values()) + 1
        pretrained_embeddings = torch.rand(token_num, output_dim)

        with torch.no_grad():
            gene_ids = torch.argmax(dataset.x, dim=1).cpu()
            embeddings = encoder(dataset.x.to(device), dataset.edge_index.to(device)).float().cpu()

            # Vectorised per-gene mean instead of a Python loop over every node.
            num_genes = len(vocab_list)
            sums = torch.zeros(num_genes, embeddings.size(1)).index_add_(0, gene_ids, embeddings)
            counts = torch.bincount(gene_ids, minlength=num_genes)
            for gene_id in torch.nonzero(counts).flatten().tolist():
                gene = vocab_list[gene_id]
                pretrained_embeddings[token_config[gene]] = sums[gene_id] / counts[gene_id]

        # Setting the padding token embedding to 0.
        pretrained_embeddings[0] = 0
        return pretrained_embeddings
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='calculate the gene graph')
    parser.add_argument('--save_graph', action='store_true', help='only save the graph for each sample')
    parser.add_argument('--data_dir', type=str, default="/tmp/erda/Spatialformer/downloaded_data/raw/", help='the parent path of the data')
    parser.add_argument('--mode', type=str, default="train", choices=["train", "test"], help='the mode to run the code')
    args = parser.parse_args()

    # Kept at module level on purpose: build_graph_for_sample reads data_dir,
    # and build_graph_for_sample / get_embedding read vocab_list as globals.
    data_dir = args.data_dir
    mode = args.mode

    mouse_names = ["Xenium_V1_FFPE_TgCRND8_17_9_months_outs",
                   "Xenium_V1_FFPE_TgCRND8_2_5_months_outs",
                   "Xenium_V1_FFPE_TgCRND8_5_7_months_outs",
                   "Xenium_V1_FFPE_wildtype_13_4_months_outs",
                   "Xenium_V1_FFPE_wildtype_2_5_months_outs",
                   "Xenium_V1_FFPE_wildtype_5_7_months_outs",
                   "Xenium_V1_mouse_pup_outs",
                   "Xenium_V1_mouse_Colon_FF_outs",
                   "Xenium_V1_FF_Mouse_Brain_Coronal_outs",
                   "Xenium_V1_FF_Mouse_Brain_Coronal_Subset_CTX_HP_outs",
                   "Xenium_Prime_Mouse_Brain_Coronal_FF_outs",
                   "Xenium_V1_mFemur_formic_acid_24hrdecal_section_outs",
                   "Xenium_V1_mFemur_EDTA_3daydecal_section_outs",
                   "Xenium_V1_mFemur_EDTA_PFA_3daydecal_section_outs",
                   "Xenium_V1_FF_Mouse_Brain_MultiSection_1_outs",
                   "Xenium_V1_FF_Mouse_Brain_MultiSection_2_outs",
                   "Xenium_V1_FF_Mouse_Brain_MultiSection_3_outs"]
    large_human_file = [
        "Xenium_Prime_Human_Ovary_FF_outs",
        "Xenium_Prime_Ovarian_Cancer_FFPE_outs",
        "Xenium_Prime_Cervical_Cancer_FFPE_outs",
        "Xenium_Prime_Human_Skin_FFPE_outs",
        "Xenium_Prime_Human_Prostate_FFPE_outs",
        "Xenium_Prime_Human_Lymph_Node_Reactive_FFPE_outs",
        "Xenium_V1_hBoneMarrow_nondiseased_section_outs",
        "Xenium_V1_hBone_nondiseased_section_outs"
    ]
    failed_human_file = [
        "Xenium_V1_hBoneMarrow_nondiseased_section_outs",
        "Xenium_V1_hBone_nondiseased_section_outs"
    ]
    excluded_samples = set(mouse_names) | set(large_human_file) | set(failed_human_file)

    # Only select the human Xenium datasets that contain a transcripts table
    # (plain or gzipped); load_and_preprocess_data handles the .gz fallback.
    root_path = "/tmp/erda/Spatialformer/downloaded_data/raw"
    sample_files = [
        os.path.join(root_path, file, "transcripts.csv")
        for file in os.listdir(root_path)
        if ".zip" not in file
        and file not in excluded_samples
        and (
            os.path.exists(os.path.join(root_path, file, "transcripts.csv"))
            or os.path.exists(os.path.join(root_path, file, "transcripts.csv.gz"))
        )
    ]
    vocab_list = index2gene("/home/sxr280/Spatialformer/tokenizer/tokenv3.json")

    if args.save_graph:
        # Saving all the intermediate per-sample subgraphs (no training here).
        print("getting the subgraph")
        file = "/tmp/erda/Spatialformer/downloaded_data/raw/Xenium_V1_hHeart_nondiseased_section_FFPE_outs/transcripts.csv"
        build_graph_for_sample(load_and_preprocess_data(file), sample_id=file.split("/")[-2])
        # List the output directory once: each sample only writes its own
        # file, so this snapshot is enough to skip already-saved samples.
        already_saved = set(os.listdir(data_dir))
        all_samples = [
            build_graph_for_sample(load_and_preprocess_data(sample_file), sample_id=sample_file.split("/")[-2])
            for sample_file in sample_files
            if f'subgraph_data_{sample_file.split("/")[-2]}.pt' not in already_saved
        ]

    else:
        print("WARNNING: please make sure you have already save the graph for each sample")
        print(f"loading the {str(len(sample_files))} saved subgraphs")
        # weights_only=False: the files hold pickled PyG Data objects written by
        # build_graph_for_sample (trusted, locally produced). Newer torch
        # defaults to weights_only=True, which rejects them.
        subgraphs = [
            torch.load(
                os.path.join(data_dir, f"subgraph_data_{os.path.basename(os.path.dirname(sample_file))}.pt"),
                weights_only=False,
            )
            for sample_file in sample_files
        ]

        print("building the full graph")
        # Combine the subgraphs into one disjoint graph by shifting each
        # subgraph's node ids by the number of nodes placed before it.
        joint_x = torch.cat([subgraph.x for subgraph in subgraphs], dim=0)
        offset = 0
        edge_lists = []
        for subgraph in subgraphs:
            edge_lists.append(subgraph.edge_index + offset)
            offset += subgraph.x.size(0)
        joint_edge_index = torch.cat(edge_lists, dim=1)
        joint_graph = Data(x=joint_x, edge_index=joint_edge_index)
        print("building the dataloader")

        print("training the model")
        with open("/home/sxr280/Spatialformer/config/_config_graphsave.json", 'r') as json_file:
            config = json.load(json_file)
        trainer = MyTrainer(config, len(vocab_list), joint_graph)
        # The previous check compared against the misspelled "trian", so the
        # default mode never trained.
        if mode == "train":
            trainer.train()
        elif mode == "test":
            # Getting the per-gene embeddings from a trained checkpoint.
            embeddings = trainer.get_embedding("/home/sxr280/Spatialformer/output/GraphSAGE_model/GraphSAGE_model/checkpoints/step=0010000-train_loss=0.3983-val_loss=0.0000-train_acc=0.8256.ckpt", 32,
                                               "/home/sxr280/Spatialformer/tokenizer/tokenv3.json", config["output_dim"])
            with open("/home/sxr280/Spatialformer/data/gene_embeddings_GraphSAGE_pandavid.pkl", "wb") as pkl_file:
                pickle.dump(embeddings, pkl_file)
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
|