spatialformer 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,765 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ import torch.nn.functional as F
7
+ # import pdb; pdb.set_trace()
8
+ from torch_geometric.loader import NeighborLoader, DataLoader
9
+ from torch_geometric.nn import SAGEConv
10
+ from torch_geometric.data import Data
11
+ from torch_geometric.utils import k_hop_subgraph, negative_sampling, add_self_loops
12
+ from sklearn.preprocessing import OneHotEncoder
13
+ from torch_geometric.utils import to_networkx, from_networkx
14
+ from scipy.spatial import KDTree
15
+ from torch_geometric.nn import global_mean_pool
16
+ from torchmetrics.classification import BinaryAccuracy
17
+ from torchmetrics.classification import MulticlassAccuracy
18
+ from pytorch_lightning.loggers import WandbLogger, CSVLogger
19
+ import networkx as nx
20
+ import logging
21
+ from tqdm import tqdm
22
+ import json
23
+ import random
24
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
25
+ import argparse
26
+ from pathlib import Path
27
+ from datetime import datetime
28
+ import pytorch_lightning as pl
29
+ import pickle
30
+ import os
31
# Timestamp used to tag output artifacts, e.g. "20240101_120000".
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")

# Resolve the current working directory and its parent; outputs live
# alongside (not inside) the source checkout.
path = Path(os.getcwd())
parent_dir = path.parent

# Default location for saved GraphSAGE checkpoints.
model_path = os.path.join(parent_dir, "output", "GraphSAGE_model")

# Prefer the GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
40
+
41
+
42
+ # Step 1: Load and preprocess data
43
def load_and_preprocess_data(filepath, max_cells=20000, random_state=42):
    """Load a Xenium transcripts CSV and return a quality-filtered, subsampled frame.

    Steps:
      1. Read ``filepath``; if it does not exist, fall back to ``filepath + ".gz"``
         (pandas infers gzip compression from the suffix).
      2. Drop low-quality calls (``qv < 20``) when a ``qv`` column is present.
      3. Drop control probes whose ``feature_name`` starts with ``Neg``,
         ``BLANK`` or ``Unassigned``.
      4. Keep the rows belonging to at most ``max_cells`` randomly sampled
         unique ``cell_id`` values (deterministic via ``random_state``).

    Args:
        filepath: Path to a ``transcripts.csv`` file (``.gz`` sibling used as fallback).
        max_cells: Maximum number of distinct cells to retain (default 20000).
        random_state: Seed for the cell-id subsample (default 42).

    Returns:
        A ``pd.DataFrame`` restricted to the sampled cells.

    Raises:
        FileNotFoundError: If neither the plain nor the ``.gz`` file exists.
    """
    print(f"loading {filepath}")
    try:
        dataset = pd.read_csv(filepath)
    except FileNotFoundError:
        # Some samples ship gzip-compressed only; was a bare `except:` that
        # would also have swallowed parse errors.
        dataset = pd.read_csv(filepath + ".gz")
    print("after loading the data")
    # Quality filter: keep transcripts with Phred-style quality >= 20.
    if "qv" in dataset.columns:
        dataset = dataset[dataset['qv'] >= 20]
    # Remove control probes (negative controls, blanks, unassigned codewords).
    dataset = dataset[~dataset['feature_name'].str.startswith(('Neg', 'BLANK', 'Unassigned'))]
    # Subsample at the cell level so whole cells are kept or dropped together.
    unique_cell_ids = dataset['cell_id'].unique()
    if len(unique_cell_ids) >= max_cells:
        sampled_cell_ids = pd.Series(unique_cell_ids).sample(n=max_cells, random_state=random_state)
    else:
        sampled_cell_ids = pd.Series(unique_cell_ids)
    sampled_dataset = dataset[dataset['cell_id'].isin(sampled_cell_ids)]
    return sampled_dataset
67
+
68
def index2gene(filepath, skip=27):
    """Return the gene vocabulary from a tokenizer JSON, dropping special tokens.

    The tokenizer JSON maps token string -> id; the first ``skip`` entries
    (default 27, matching the original hard-coded slice) are special /
    non-gene tokens and are excluded.

    Args:
        filepath: Path to the tokenizer JSON file.
        skip: Number of leading (special) tokens to drop.

    Returns:
        List of gene-token strings in the JSON's insertion order.
    """
    with open(filepath, "r") as f:
        vocab = json.load(f)
    # json.load preserves key order (dicts are ordered in Python 3.7+).
    return list(vocab.keys())[skip:]
74
+
75
+
76
def build_graph_for_sample(data, threshold=3.0, batch_size=100, sample_id=None):
    """Build a spatial proximity graph for one sample and save a k-hop subgraph.

    Nodes are individual transcripts; node features are one-hot gene
    identities over the module-level ``vocab_list``. An undirected edge is
    added between two transcripts whose 3-D coordinates are within
    ``threshold`` of each other (KD-tree radius query). A 3-hop subgraph
    around up to 5000 random root nodes is then extracted and written to
    ``{data_dir}/subgraph_data_{sample_id}.pt``.

    Args:
        data: DataFrame with 'x_location', 'y_location', 'z_location' and
            'feature_name' columns (output of ``load_and_preprocess_data``).
        threshold: Euclidean radius (same units as the coordinates) for
            connecting two transcripts.
        batch_size: Number of query points per KD-tree batch — controls
            memory/progress granularity only, not the resulting graph.
        sample_id: Identifier embedded in the saved file name.

    Returns:
        None. Side effect: writes the subgraph ``Data`` object to disk.

    NOTE(review): depends on the module-level globals ``vocab_list`` and
    ``data_dir`` being set before this is called (done in ``__main__``).
    """
    # Transcript coordinates, shape (n_transcripts, 3).
    r_c = np.array(data[['x_location', 'y_location', 'z_location']])
    gene_labels = data['feature_name'].values
    # Fixed category list keeps the one-hot dimension aligned with the
    # tokenizer vocabulary across samples.
    encoder = OneHotEncoder(categories=[vocab_list], sparse_output=False)
    one_hot_labels = encoder.fit_transform(gene_labels.reshape(-1, 1))

    kdtree = KDTree(r_c)
    G = nx.Graph()
    # Add every transcript as a node first, so isolated transcripts are kept.
    for i in range(len(r_c)):
        G.add_node(i, feature=one_hot_labels[i])
    num_nodes = len(r_c)

    print("building graph in batch")
    for start_idx in tqdm(range(0, num_nodes, batch_size)):
        end_idx = min(start_idx + batch_size, num_nodes)
        batch_r_c = r_c[start_idx:end_idx]
        edges_to_add = []
        for i, x in enumerate(batch_r_c, start=start_idx):
            # All transcripts within `threshold` of transcript i (includes i).
            neighbors_idx = kdtree.query_ball_point(x, threshold)
            for j in neighbors_idx:
                # i < j skips self-loops and avoids adding each pair twice.
                if i < j:
                    edges_to_add.append((i, j))
        G.add_edges_from(edges_to_add)

    # NOTE(review): nodes with no edges do not appear in edge_index, so
    # k_hop_subgraph infers node count from the largest edge endpoint.
    edge_index = torch.tensor(list(G.edges)).t().contiguous()
    x = torch.tensor(one_hot_labels, dtype=torch.float)
    num_nodes = x.size(0)

    # Up to 5000 random roots, sampled without replacement.
    root_nodes = torch.tensor(random.sample(range(num_nodes), min(5000, num_nodes)))

    print("creating the subgraph...")
    subgraph_nodes, subgraph_edge_index, _, _ = k_hop_subgraph(
        node_idx=root_nodes,
        num_hops=3,
        edge_index=edge_index,
        relabel_nodes=True
    )
    x_subgraph = x[subgraph_nodes]
    data = Data(x=x_subgraph, edge_index=subgraph_edge_index)

    print("saving the graph")
    torch.save(data, f'{data_dir}/subgraph_data_{sample_id}.pt')
135
+
136
+
137
+
138
+ # Step 2: Create subgraphs
139
def create_subgraph(data, num_root_nodes=5000, num_neighbors=(20, 10, 10)):
    """Sample root nodes and return their k-hop neighborhood as a new graph.

    Args:
        data: A PyG ``Data`` object with ``x`` and ``edge_index``.
        num_root_nodes: Number of random roots to sample (capped at the
            node count, sampled without replacement).
        num_neighbors: Per-hop fan-out specification; only its length is
            used here (number of hops). Default changed from a mutable
            list literal to an equivalent tuple.

    Returns:
        A ``Data`` object containing the (relabelled) k-hop subgraph.
    """
    num_nodes = data.x.size(0)
    root_nodes = torch.tensor(random.sample(range(num_nodes), min(num_root_nodes, num_nodes)))

    subgraph_nodes, subgraph_edge_index, _, _ = k_hop_subgraph(
        node_idx=root_nodes,
        num_hops=len(num_neighbors),
        edge_index=data.edge_index,
        relabel_nodes=True
    )

    x_subgraph = data.x[subgraph_nodes]

    # BUGFIX: the original body referenced an undefined global ``G`` here
    # (connected-component filtering copied from elsewhere), which raised
    # NameError on every call. The dead lines were removed; use
    # ``filter_component`` on the returned Data if component filtering
    # is needed.
    return Data(x=x_subgraph, edge_index=subgraph_edge_index)
157
+
158
+
159
+ # Step 4: Define and train the 2-hop GraphSAGE Model
160
class TwoHopGraphSAGE(nn.Module):
    """Two-layer GraphSAGE encoder: in_channels -> hidden -> out_channels.

    A single ReLU sits between the two SAGE convolutions; the final layer
    output is returned un-activated (raw embeddings).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(TwoHopGraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return one embedding row per node after two aggregation hops."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
171
def filter_component(data):
    """Drop connected components with fewer than 10 nodes from a PyG graph.

    Converts the graph to networkx, keeps only nodes belonging to
    components of size >= 10, and converts back to a ``Data`` object
    (node feature attribute 'x' is preserved via ``group_node_attrs``).

    Args:
        data: PyG ``Data`` object with node attribute ``x``.

    Returns:
        A filtered ``Data`` object.
    """
    print("subgraph to networkx...")
    sub_G = to_networkx(data, to_undirected=True, node_attrs=['x'])
    print("filtering by components...")
    # nx.connected_components already yields sets, so no per-component
    # set() conversion is needed.
    components = [c for c in nx.connected_components(sub_G) if len(c) >= 10]
    # BUGFIX: set().union(*components) tolerates an empty component list,
    # whereas the original set.union(*map(set, components)) raised
    # TypeError when every component had < 10 nodes.
    kept_nodes = set().union(*components)
    sub_G_f = sub_G.subgraph(kept_nodes)

    # NOTE(review): from_networkx on a fully-empty graph may still fail
    # downstream — confirm callers handle the all-small-components case.
    data_filtered = from_networkx(sub_G_f, group_node_attrs=['x'])
    return data_filtered
181
+
182
+
183
+
184
+ # Define your LightningModule
185
class GraphSAGEModel(pl.LightningModule):
    """Self-supervised GraphSAGE trained with a link-prediction objective.

    Positive examples are the graph's existing edges; an equal number of
    negative node pairs is sampled per step. The dot product of the two
    endpoint embeddings is scored with binary cross-entropy on logits.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, lr, train_dataset, batch_size):
        """Args:
            input_dim: One-hot gene-vocabulary size (node feature width).
            hidden_dim: Hidden channel width of the first SAGE layer.
            output_dim: Embedding width produced by the second SAGE layer.
            lr: AdamW learning rate.
            train_dataset: Full PyG ``Data`` graph; ``train_dataloader``
                samples neighborhoods from it.
            batch_size: Number of root nodes per NeighborLoader batch.
        """
        super(GraphSAGEModel, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, output_dim)
        self.accuracy = BinaryAccuracy()
        self.lr = lr
        self.train_dataset = train_dataset
        self.batch_size = batch_size

    def forward(self, x, edge_index):
        """Two-hop encoder: returns one (output_dim) embedding row per node."""
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

    def training_step(self, batch, batch_idx):
        """One link-prediction step: existing edges vs. sampled non-edges."""
        z = self(batch.x, batch.edge_index)

        # Positive samples: the edges actually present in the batch graph.
        pos_edge_index = batch.edge_index

        # Negative samples: one random non-edge per positive edge.
        neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,
            num_nodes=batch.x.size(0),
            num_neg_samples=pos_edge_index.size(1)
        )

        # Edge score = dot product of the two endpoint embeddings (logits).
        pos_out = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=-1)
        neg_out = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=-1)

        # Stack positives (label 1) and negatives (label 0).
        all_out = torch.cat([pos_out, neg_out])
        all_labels = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)])

        loss = F.binary_cross_entropy_with_logits(all_out, all_labels)

        # Accuracy at a fixed 0.5 probability threshold.
        preds = torch.sigmoid(all_out) > 0.5
        acc = self.accuracy(preds, all_labels.int())

        # batch_size is passed explicitly because the batch is a graph,
        # not a plain tensor, so Lightning cannot infer it.
        self.log('train_loss', loss, sync_dist=True, reduce_fx='mean', prog_bar=True, batch_size=batch.x.size(0))
        self.log('train_acc', acc, sync_dist=True, reduce_fx='mean', prog_bar=True, batch_size=batch.x.size(0))

        return loss

    def configure_optimizers(self):
        """AdamW with a strong weight decay of 0.1; lr from the config."""
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.1)
        return optimizer

    def train_dataloader(self):
        """Neighborhood sampler over the full graph held in train_dataset."""
        loader = NeighborLoader(
            data=self.train_dataset,
            num_neighbors=[20, 10],  # 20 first-hop, 10 second-hop neighbors
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4  # tune to available CPU cores
        )
        return loader
252
+
253
+
254
+
255
class PatternGraphSAGEModel(pl.LightningModule):
    """Binary subcellular-pattern classifier on top of a GraphSAGE backbone.

    A pretrained node-embedding model (``model``) encodes the graph; node
    embeddings are mean-pooled per graph and classified as "random" vs.
    "non-random" by a linear head.
    """

    def __init__(self, model, output_dim, nn_hidden_layer_dim, lr):
        """Args:
            model: Backbone callable ``model(x, edge_index) -> node embeddings``.
            output_dim: Backbone embedding width (used only by the — currently
                unused — hidden layer).
            nn_hidden_layer_dim: Width of the unused hidden layer.
            lr: AdamW learning rate.
        """
        super(PatternGraphSAGEModel, self).__init__()
        self.model = model
        # Binary head. NOTE(review): input width hard-coded to 512 — must
        # equal the backbone's embedding width or forward() will fail.
        self.fc = nn.Linear(512, 2)
        # NOTE(review): defined but never used in forward().
        self.hidden_layer = nn.Linear(output_dim, nn_hidden_layer_dim)

        # Collapsed from the original 5-way pattern labels to binary.
        self.label_str = ["random", "non-random"]
        self.accuracy = MulticlassAccuracy(num_classes=len(self.label_str))

        self.lr = lr

    def forward(self, x, edge_index, batch_indice):
        """Return per-graph logits: encode nodes, mean-pool, classify."""
        x = self.model(x, edge_index)
        x = F.relu(x)
        # Pool node embeddings to one vector per graph in the batch.
        x = global_mean_pool(x, batch_indice)
        x = self.fc(x)
        # Raw logits — CrossEntropyLoss applies softmax internally.
        return x

    def training_step(self, batch, batch_idx):
        """One supervised step; batch is (Data batch, list of label strings)."""
        preds = self(batch[0].x, batch[0].edge_index, batch[0].batch)
        # Any non-"random" pattern string maps to class 1.
        labels = torch.tensor([0 if label_str == "random" else 1 for label_str in batch[1]], dtype=torch.long).to(preds.device)

        loss = F.cross_entropy(preds, labels)

        acc = self.accuracy(preds.argmax(dim=1), labels)
        self.log('train_loss', loss, sync_dist=True, prog_bar=True)
        self.log('train_acc', acc, sync_dist=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Mirror of training_step, logged under val_* keys."""
        preds = self(batch[0].x, batch[0].edge_index, batch[0].batch)
        # Convert pattern strings to binary labels (non-"random" -> 1).
        labels = torch.tensor([0 if label_str == "random" else 1 for label_str in batch[1]], dtype=torch.long).to(preds.device)

        loss = F.cross_entropy(preds, labels)

        acc = self.accuracy(preds.argmax(dim=1), labels)

        self.log('val_loss', loss, sync_dist=True, prog_bar=True)
        self.log('val_acc', acc, sync_dist=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """AdamW with weight decay 0.1; lr from the constructor."""
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=0.1)
        return optimizer

    def evaluation(self, ckp_path, test_dataloader):
        """Load a checkpoint and report test accuracy / F1 / MCC / AUC.

        NOTE(review): this method has several latent defects —
          * ``self(...)`` below is called without the required
            ``batch_indice`` argument, so forward() raises TypeError;
          * ``self.f1``, ``self.mcc`` and ``self.auc`` are never defined
            anywhere in the class (AttributeError when reached);
          * the full Lightning ``state_dict`` is loaded into ``self.model``
            (the backbone), whose keys are presumably prefixed — confirm
            the checkpoint layout before use.
        """
        ckp = torch.load(ckp_path)
        params = ckp["state_dict"]
        self.model.load_state_dict(params)
        self.model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in test_dataloader:
                # NOTE(review): missing batch_indice argument (see docstring).
                preds = self(batch[0].x, batch[0].edge_index)
                labels = torch.tensor([self.label_str.index(label_str) for label_str in batch[1]], dtype=torch.long)

                all_preds.append(preds.softmax(dim=1))
                all_labels.append(labels)

        # Shape: [num_samples, num_classes] / [num_samples].
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)

        acc = self.accuracy(all_preds.argmax(dim=1), all_labels)
        f1 = self.f1(all_preds, all_labels)      # NOTE(review): undefined attribute
        mcc = self.mcc(all_preds, all_labels)    # NOTE(review): undefined attribute

        # One-vs-rest AUC averaged over classes.
        AUCs = []
        for i in range(len(self.label_str)):
            AUCs.append(self.auc(all_preds[:, i], (all_labels == i).float()))  # NOTE(review): undefined attribute

        average_auc = sum(AUCs) / len(AUCs)

        self.log('test_acc', acc, sync_dist=True)
        self.log('test_f1', f1, sync_dist=True)
        self.log('test_mcc', mcc, sync_dist=True)
        self.log('test_auc', average_auc, sync_dist=True)

        print(f"Test Accuracy: {acc.item()}")
        print(f"Test F1 Score: {f1.item()}")
        print(f"Test MCC: {mcc.item()}")
        print(f"Average Test AUC: {average_auc.item()}")
        return acc.item(), f1.item(), mcc.item(), average_auc.item()
373
+
374
+
375
+
376
+
377
+
378
+
379
class MyTrainer:
    """Thin wrapper around pl.Trainer for the self-supervised GraphSAGE stage."""

    def __init__(self, config, input_dim, train_dataset):
        """Args:
            config: Dict with keys "hidden_dim", "output_dim", "lr",
                "batch_size", "total_step", "strategy",
                "accumulate_grad_batches".
            input_dim: Node-feature width (gene vocabulary size).
            train_dataset: Full PyG ``Data`` graph to train on.
        """
        self.config = config
        self.plmodel = GraphSAGEModel(
            input_dim = input_dim,
            hidden_dim=config["hidden_dim"],
            output_dim=config["output_dim"],
            lr = config["lr"],
            train_dataset = train_dataset,
            batch_size = config["batch_size"]
        )
        # NOTE(review): hard-coded machine-specific path.
        self.output_dir = "/home/sxr280/Spatialformer/output/GraphSAGE_model"
        self.train_dataset = train_dataset
        self.gpus = torch.cuda.device_count()
        self.trainer = None

    def make_callback(self):
        """Checkpoint every 10k steps (keep all) plus per-step LR logging."""
        callbacks = [
            ModelCheckpoint(
                dirpath=os.path.join(self.output_dir, "GraphSAGE_model", "checkpoints"),
                filename=f"{{step:07d}}-{{train_loss:.4f}}-{{val_loss:.4f}}-{{train_acc:.4f}}",
                every_n_train_steps=10000,
                save_top_k=-1,  # keep every checkpoint
                monitor='train_loss',
                save_on_train_epoch_end=False
            ), LearningRateMonitor(logging_interval="step"),
        ]

        return callbacks

    def set_trainer(self):
        """Build the Lightning trainer with W&B logging and bf16 precision."""
        self.logger = WandbLogger(project = "Spaformer",
                                  name = "GraphSAGE",
                                  log_model = "all",
                                  save_dir = self.output_dir)

        self.trainer = pl.Trainer(
            accelerator="auto",
            devices=self.gpus,
            max_steps=self.config["total_step"],
            val_check_interval = 0.1,
            default_root_dir=self.output_dir,
            callbacks=self.make_callback(),
            log_every_n_steps=50,
            logger=self.logger,
            precision='bf16',
            strategy = self.config['strategy'],
            num_nodes = 1
        )

    def resume_train(self, ckp, train_loader, val_loader):
        """Resume training from checkpoint ``ckp``.

        NOTE(review): ``resume_from_checkpoint`` was deprecated and later
        removed from pl.Trainer (use ``trainer.fit(..., ckpt_path=...)``) —
        confirm the installed Lightning version supports it. ``val_loader``
        is accepted but never passed to ``fit``.
        """
        self.logger = WandbLogger(project = "Spaformer",
                                  name = "GraphSAGE",
                                  log_model = "all",
                                  save_dir = self.output_dir)
        logging.info("resuming the training ...")
        self.trainer = pl.Trainer(
            accelerator="auto",
            devices=self.gpus,
            strategy = self.config['strategy'],
            num_nodes = 1,
            val_check_interval = 0.1,
            gradient_clip_val = 1,
            logger=self.logger,
            default_root_dir=self.output_dir,
            log_every_n_steps=50,
            check_val_every_n_epoch=1,
            precision='bf16',
            callbacks=self.make_callback(),
            max_steps=self.config["total_step"],
            resume_from_checkpoint=ckp,
            accumulate_grad_batches = self.config['accumulate_grad_batches'])
        self.trainer.fit(self.plmodel, train_loader)

    def train(self):
        """Train from scratch; the model supplies its own train_dataloader."""
        self.set_trainer()
        self.trainer.fit(self.plmodel)

    def get_embedding(self, ckp_path, batch_size, token_path, output_dim):
        '''
        Build per-gene embeddings by averaging the embeddings of all
        transcripts (nodes) of each gene.

        Args:
            ckp_path: Lightning checkpoint to load into the model.
            batch_size: Unused here (kept for interface compatibility).
            token_path: Tokenizer JSON mapping token string -> id.
            output_dim: Embedding width of the pretrained model.

        Returns:
            Tensor of shape (num_tokens, output_dim); rows for genes not in
            the graph stay randomly initialized, row 0 (padding) is zeroed.

        NOTE(review): relies on the module-level global ``vocab_list``.
        '''
        model = self.plmodel
        ckp = torch.load(ckp_path)
        params = ckp["state_dict"]
        model.load_state_dict(params)

        model.eval()

        gene_embeds = {}

        # Load the token -> id map; token ids index the embedding table.
        with open(os.path.join(token_path), 'r') as json_file:
            token_config = json.load(json_file)
        token_num = np.max([j for i,j in token_config.items()]) + 1
        # Random init so tokens absent from the graph still get a vector.
        pretrained_embeddings = torch.rand(token_num, output_dim)

        with torch.no_grad():
            # Recover each node's gene from its one-hot feature row.
            indices = torch.argmax(self.train_dataset.x, axis=1)
            genes = [vocab_list[indice.item()] for indice in indices]
            # Embed every node in the full training graph at once.
            embeddings = model(self.train_dataset.x, self.train_dataset.edge_index)
            # Group node embeddings by gene.
            for i, gene in enumerate(genes):
                if gene not in gene_embeds:
                    gene_embeds[gene] = []
                gene_embeds[gene].append(embeddings[i])

        # Mean over all transcripts of a gene -> that gene's embedding row.
        for gene, embeds in gene_embeds.items():
            pretrained_embeddings[token_config[gene]] = torch.mean(torch.stack(embeds), dim=0)

        # Zero the padding token's row.
        pretrained_embeddings[0] = 0
        return pretrained_embeddings
506
+
507
class PatternTrainer:
    """Trainer wrapper for the pattern-classification fine-tuning stage."""

    def __init__(self, model, lr, strategy, output_dim, output_dir, train_dataloader, val_dataloader, test_dataloader):
        """Args:
            model: Pretrained GraphSAGE backbone for PatternGraphSAGEModel.
            lr: Learning rate for the classification head training.
            strategy: Lightning distributed strategy string.
            output_dim: Backbone embedding width.
            output_dir: Directory for checkpoints and logs.
            train_dataloader / val_dataloader / test_dataloader: DataLoaders
                yielding (Data batch, label-string list) pairs.
        """
        self.plmodel = PatternGraphSAGEModel(
            model,
            output_dim=output_dim,
            nn_hidden_layer_dim = 8,
            lr = lr
        )
        self.output_dir = output_dir
        self.train_dataloader = train_dataloader
        self.gpus = torch.cuda.device_count()
        self.trainer = None
        self.strategy = strategy
        # NOTE(review): duplicate assignment — train_dataloader was already
        # stored above.
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

    def make_callback(self):
        """Checkpoint every 10k steps (keep all) plus per-step LR logging."""
        callbacks = [
            ModelCheckpoint(
                dirpath=os.path.join(self.output_dir, "GraphSAGE_model", "checkpoints"),
                filename=f"{{step:07d}}-{{train_loss:.4f}}-{{val_loss:.4f}}-{{train_acc:.4f}}",
                every_n_train_steps=10000,
                save_top_k=-1,  # keep every checkpoint
                monitor='train_loss',
                save_on_train_epoch_end=False
            ), LearningRateMonitor(logging_interval="step"),
        ]

        return callbacks

    def set_trainer(self):
        """Build the Lightning trainer with CSV logging and bf16 precision."""
        # NOTE(review): hard-coded cluster-specific log path.
        self.logger = CSVLogger("/scratch/project_465001027/Spatialformer/output/GraphSAGE_model", name="my_experiment")

        self.trainer = pl.Trainer(
            accelerator="auto",
            devices=self.gpus,
            max_steps=10000,
            val_check_interval = 0.1,
            default_root_dir=self.output_dir,
            callbacks=self.make_callback(),
            log_every_n_steps=50,
            logger=self.logger,
            precision='bf16',
            strategy = self.strategy,
            num_nodes = 1,
        )

    def resume_train(self, ckp, train_loader, val_loader):
        """Resume training from checkpoint ``ckp``.

        NOTE(review): this method references ``self.config``, which
        PatternTrainer never sets (AttributeError when called) — it appears
        to have been copied from MyTrainer. ``resume_from_checkpoint`` is
        also deprecated/removed in newer Lightning versions, and
        ``val_loader`` is accepted but never passed to ``fit``.
        """
        self.logger = WandbLogger(project = "Spaformer",
                                  name = "GraphSAGE",
                                  log_model = "all",
                                  save_dir = self.output_dir)
        logging.info("resuming the training ...")
        self.trainer = pl.Trainer(
            accelerator="auto",
            devices=self.gpus,
            strategy = self.config['strategy'],
            num_nodes = 1,
            val_check_interval = 0.1,
            gradient_clip_val = 1,
            logger=self.logger,
            default_root_dir=self.output_dir,
            log_every_n_steps=50,
            check_val_every_n_epoch=1,
            precision='bf16',
            callbacks=self.make_callback(),
            max_steps=self.config["total_step"],
            resume_from_checkpoint=ckp,
            accumulate_grad_batches = self.config['accumulate_grad_batches'])
        self.trainer.fit(self.plmodel, train_loader)

    def train(self):
        """Fit the classifier with the stored train/val dataloaders."""
        self.set_trainer()
        self.trainer.fit(self.plmodel, self.train_dataloader, self.val_dataloader)

    def evaluation(self, ckp_path):
        """Delegate test-set evaluation to the underlying LightningModule."""
        acc, f1, mcc, average_auc = self.plmodel.evaluation(ckp_path, self.test_dataloader)
        return acc, f1, mcc, average_auc

    def get_embedding(self, ckp_path, batch_size, token_path, output_dim):
        '''
        Build per-gene embeddings by averaging node embeddings per gene.

        NOTE(review): this method references ``self.train_dataset``, which
        PatternTrainer never defines (only dataloaders are stored) — it was
        copied from MyTrainer and raises AttributeError when called. It also
        relies on the module-level global ``vocab_list``.
        '''
        model = self.plmodel
        ckp = torch.load(ckp_path)
        params = ckp["state_dict"]
        model.load_state_dict(params)

        model.eval()

        gene_embeds = {}

        # Load the token -> id map; token ids index the embedding table.
        with open(os.path.join(token_path), 'r') as json_file:
            token_config = json.load(json_file)
        token_num = np.max([j for i,j in token_config.items()]) + 1
        # Random init so tokens absent from the graph still get a vector.
        pretrained_embeddings = torch.rand(token_num, output_dim)

        with torch.no_grad():
            indices = torch.argmax(self.train_dataset.x, axis=1)
            genes = [vocab_list[indice.item()] for indice in indices]
            embeddings = model(self.train_dataset.x, self.train_dataset.edge_index)
            # Group node embeddings by gene.
            for i, gene in enumerate(genes):
                if gene not in gene_embeds:
                    gene_embeds[gene] = []
                gene_embeds[gene].append(embeddings[i])

        # Mean over all transcripts of a gene -> that gene's embedding row.
        for gene, embeds in gene_embeds.items():
            pretrained_embeddings[token_config[gene]] = torch.mean(torch.stack(embeds), dim=0)

        # Zero the padding token's row.
        pretrained_embeddings[0] = 0
        return pretrained_embeddings
641
+
642
+
643
+
644
+
645
+
646
+
647
+
648
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='calculate the gene graph')
    parser.add_argument('--save_graph', action='store_true', help='only save the graph for each sample')
    parser.add_argument('--data_dir', type=str, default="/tmp/erda/Spatialformer/downloaded_data/raw/", help='the parent path of the data')
    parser.add_argument('--mode', type=str, default="train", help='the mode to run the code')

    args = parser.parse_args()

    data_dir = args.data_dir
    mode = args.mode

    # Sample directories to exclude: mouse tissues, over-large human runs,
    # and runs that previously failed preprocessing.
    mouse_names = ["Xenium_V1_FFPE_TgCRND8_17_9_months_outs",
                   "Xenium_V1_FFPE_TgCRND8_2_5_months_outs",
                   "Xenium_V1_FFPE_TgCRND8_5_7_months_outs",
                   "Xenium_V1_FFPE_wildtype_13_4_months_outs",
                   "Xenium_V1_FFPE_wildtype_2_5_months_outs",
                   "Xenium_V1_FFPE_wildtype_5_7_months_outs",
                   "Xenium_V1_mouse_pup_outs",
                   "Xenium_V1_mouse_Colon_FF_outs",
                   "Xenium_V1_FF_Mouse_Brain_Coronal_outs",
                   "Xenium_V1_FF_Mouse_Brain_Coronal_Subset_CTX_HP_outs",
                   "Xenium_Prime_Mouse_Brain_Coronal_FF_outs",
                   "Xenium_V1_mFemur_formic_acid_24hrdecal_section_outs",
                   "Xenium_V1_mFemur_EDTA_3daydecal_section_outs",
                   "Xenium_V1_mFemur_EDTA_PFA_3daydecal_section_outs",
                   "Xenium_V1_FF_Mouse_Brain_MultiSection_1_outs",
                   "Xenium_V1_FF_Mouse_Brain_MultiSection_2_outs",
                   "Xenium_V1_FF_Mouse_Brain_MultiSection_3_outs"]
    large_human_file = [
        "Xenium_Prime_Human_Ovary_FF_outs",
        "Xenium_Prime_Ovarian_Cancer_FFPE_outs",
        "Xenium_Prime_Cervical_Cancer_FFPE_outs",
        "Xenium_Prime_Human_Skin_FFPE_outs",
        "Xenium_Prime_Human_Prostate_FFPE_outs",
        "Xenium_Prime_Human_Lymph_Node_Reactive_FFPE_outs",
        "Xenium_V1_hBoneMarrow_nondiseased_section_outs",
        "Xenium_V1_hBone_nondiseased_section_outs"
    ]
    failed_human_file = [
        "Xenium_V1_hBoneMarrow_nondiseased_section_outs",
        "Xenium_V1_hBone_nondiseased_section_outs"
    ]

    # Only select the human Xenium datasets that actually have transcripts.
    root_path = "/tmp/erda/Spatialformer/downloaded_data/raw"

    sample_files = [
        os.path.join(root_path, file, "transcripts.csv")
        for file in os.listdir("/tmp/erda/Spatialformer/downloaded_data/raw")
        if (
            ".zip" not in file and
            file not in mouse_names and
            file not in large_human_file and
            file not in failed_human_file and (
                os.path.exists(os.path.join(root_path, file, "transcripts.csv")) or
                os.path.exists(os.path.join(root_path, file, "transcripts.csv.gz"))
            )
        )
    ]
    # Module-level assignment: build_graph_for_sample and get_embedding read
    # this as a global. (The original `global` statement was a no-op at
    # module scope and has been dropped.)
    vocab_list = index2gene("/home/sxr280/Spatialformer/tokenizer/tokenv3.json")

    if args.save_graph:
        # Stage 1: build and save a k-hop subgraph per sample.
        print("getting the subgraph")
        file = "/tmp/erda/Spatialformer/downloaded_data/raw/Xenium_V1_hHeart_nondiseased_section_FFPE_outs/transcripts.csv"
        build_graph_for_sample(load_and_preprocess_data(file), sample_id=file.split("/")[-2])
        # Skip samples whose subgraph file already exists in data_dir.
        all_samples = [build_graph_for_sample(load_and_preprocess_data(sample_file), sample_id=sample_file.split("/")[-2]) for sample_file in sample_files if f'subgraph_data_{sample_file.split("/")[-2]}.pt' not in os.listdir(data_dir)]

    else:
        # Stage 2: load the saved subgraphs and train / extract embeddings.
        # BUGFIX: message typos corrected ("WARNNING", "already save").
        print("WARNING: please make sure you have already saved the graph for each sample")
        print(f"loading the {str(len(sample_files))} saved subgraphs")
        subgraphs = [
            torch.load(os.path.join(data_dir, f"subgraph_data_{os.path.basename(os.path.dirname(sample_file))}.pt"))
            for sample_file in sample_files
        ]

        print("building the full graph")
        # Combine subgraphs into one disjoint union: concatenate features and
        # shift each subgraph's edge indices by the running node offset.
        joint_x = torch.cat([subgraph.x for subgraph in subgraphs], dim=0)
        offset = 0
        edge_lists = []
        for subgraph in subgraphs:
            edge_lists.append(subgraph.edge_index + offset)
            offset += subgraph.x.size(0)
        joint_edge_index = torch.cat(edge_lists, dim=1)
        joint_graph = Data(x=joint_x, edge_index=joint_edge_index)
        print("building the dataloader")

        print("training the model")
        with open(os.path.join("/home/sxr280/Spatialformer/config/_config_graphsave.json"), 'r') as json_file:
            config = json.load(json_file)
        trainer = MyTrainer(config, len(vocab_list), joint_graph)
        # BUGFIX: was `if mode == "trian":` — the typo meant the default
        # mode "train" never actually started training.
        if mode == "train":
            trainer.train()
        elif mode == "test":
            # Extract per-gene embeddings from the trained checkpoint.
            embeddings = trainer.get_embedding("/home/sxr280/Spatialformer/output/GraphSAGE_model/GraphSAGE_model/checkpoints/step=0010000-train_loss=0.3983-val_loss=0.0000-train_acc=0.8256.ckpt", 32,
                                               "/home/sxr280/Spatialformer/tokenizer/tokenv3.json", config["output_dim"])
            # BUGFIX: use a context manager so the file handle is closed
            # (was an unclosed `open(...)` inside pickle.dump).
            with open("/home/sxr280/Spatialformer/data/gene_embeddings_GraphSAGE_pandavid.pkl", "wb") as f:
                pickle.dump(embeddings, f)
758
+
759
+
760
+
761
+
762
+
763
+
764
+
765
+