grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,778 @@
+ from __future__ import annotations
+
+ import os
+
+ # =========================================================================
+ # IMPORTANT: set thread-related env vars before importing numpy/torch.
+ # This avoids OpenBLAS/OMP errors such as "too many memory regions".
+ # =========================================================================
+ os.environ["OPENBLAS_NUM_THREADS"] = "8"  # Tune for your CPU (e.g., 1/8/16)
+ os.environ["MKL_NUM_THREADS"] = "8"
+ os.environ["OMP_NUM_THREADS"] = "8"
+ os.environ["VECLIB_MAXIMUM_THREADS"] = "8"
+ os.environ["NUMEXPR_NUM_THREADS"] = "8"
+
+ # NOTE: `grasp_tool.gnn.plot_refined` imports `umap`, which may import TensorFlow
+ # and emit noisy INFO/WARN logs even for `--help`. Suppress them by default.
+ os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
+
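+ # The thread caps above are assigned unconditionally and therefore override any
+ # values exported in the shell; edit the "8" to retune. TF_CPP_MIN_LOG_LEVEL
+ # uses setdefault, so an already-exported value is respected.
+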
+ import argparse
+ import importlib
+ import json
+ import pickle
+ import random
+ import shutil
+ import time
+ import traceback
+ import warnings
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import pandas as pd
+
+ warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+ def _lazy_import_training_deps() -> None:
+     """Import torch/pyg (and friends) only when training is actually executed.
+
+     This keeps `grasp-tool train-moco --help` working in a base install where
+     torch/pyg are intentionally NOT declared as PyPI dependencies.
+     """
+
+     global torch, ReduceLROnPlateau, gcl, vis
+
+     try:
+         torch = importlib.import_module("torch")
+         torch_geometric_loader = importlib.import_module("torch_geometric.loader")
+         _ = getattr(torch_geometric_loader, "DataLoader")
+         lr_scheduler = importlib.import_module("torch.optim.lr_scheduler")
+         ReduceLROnPlateau = getattr(lr_scheduler, "ReduceLROnPlateau")
+     except ModuleNotFoundError as e:
+         missing = getattr(e, "name", "")
+         if missing == "torch" or missing.startswith("torch_geometric"):
+             raise ModuleNotFoundError(
+                 "Missing training dependencies: torch and torch-geometric.\n"
+                 "Install them first (recommended: conda), then re-run: grasp-tool train-moco ..."
+             ) from e
+         raise
+
+     gcl = importlib.import_module("grasp_tool.gnn.gat_moco_final")
+     vis = importlib.import_module("grasp_tool.gnn.plot_refined")
+
+
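+ # A successful call binds `torch`, `ReduceLROnPlateau`, `gcl` (model code), and
+ # `vis` (plotting) as module-level globals for the functions below; the
+ # `DataLoader` getattr is just a fail-fast check that torch-geometric's loader
+ # package is importable.
+
+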
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(
+         description="MoCo Training for Graph Neural Networks"
+     )
+
+     # Dataset inputs
+     parser.add_argument(
+         "--dataset",
+         type=str,
+         required=True,
+         help="Dataset name (e.g., data1_simulated1)",
+     )
+     parser.add_argument("--pkl", type=str, required=True, help="Path to training PKL")
+     parser.add_argument(
+         "--js",
+         type=int,
+         default=0,
+         choices=[0, 1],
+         help="Use JS distances: 0=no, 1=yes (default: 0)",
+     )
+     parser.add_argument(
+         "--js_file",
+         type=str,
+         default=None,
+         help="Path to JS distances CSV (required when --js=1)",
+     )
+     parser.add_argument(
+         "--n", type=int, default=20, help="Number of sectors (n_sectors) (default: 20)"
+     )
+     parser.add_argument(
+         "--m", type=int, default=10, help="Number of rings (m_rings) (default: 10)"
+     )
+     parser.add_argument(
+         "--a", type=float, default=0.5, help="Reconstruction loss weight (default: 0.5)"
+     )
+     parser.add_argument(
+         "--b", type=float, default=0.5, help="Contrastive loss weight (default: 0.5)"
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=0.07,
+         help="Contrastive temperature (default: 0.07)",
+     )
+     parser.add_argument(
+         "--batch_size", type=int, default=64, help="Batch size (default: 64)"
+     )
+     parser.add_argument(
+         "--lrs",
+         type=float,
+         nargs="+",
+         default=None,
+         help=(
+             "Learning rate list (one or more values, e.g. --lrs 0.001 0.002). "
+             "If omitted, uses the built-in default list."
+         ),
+     )
+     parser.add_argument(
+         "--num_positive", type=int, default=4, help="Number of positives (default: 4)"
+     )
+     parser.add_argument(
+         "--num_epoch", type=int, default=300, help="Number of epochs (default: 300)"
+     )
+     parser.add_argument(
+         "--num_clusters",
+         type=int,
+         default=None,
+         help="Enable clustering eval with this number of clusters (e.g., 5, 8)",
+     )
+     parser.add_argument(
+         "--cuda_device", type=int, default=0, help="CUDA device index (default: 0)"
+     )
+     parser.add_argument(
+         "--seed", type=int, default=2025, help="Random seed (default: 2025)"
+     )
+     parser.add_argument(
+         "--use_gradient_clipping",
+         type=int,
+         default=1,
+         choices=[0, 1],
+         help="Use gradient clipping: 0=no, 1=yes (default: 1)",
+     )
+     parser.add_argument(
+         "--gradient_clip_norm",
+         type=float,
+         default=3.0,
+         help="Gradient clipping max_norm (default: 3.0)",
+     )
+     parser.add_argument("--k", type=int, default=512, help="Queue size (default: 512)")
+     parser.add_argument(
+         "--label_file",
+         type=str,
+         default=None,
+         help="Optional label file path (auto-discover if omitted)",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default=None,
+         help="Output root directory (default: ./outputs/<dataset>/step5_embedding)",
+     )
+
+     args = parser.parse_args()
+
+     if args.js == 1 and not args.js_file:
+         parser.error("--js_file must be specified when using --js=1")
+
+     # Aliases and fixed settings consumed by the training code below.
+     args.n_sectors = args.n
+     args.m_rings = args.m
+     args.positive_sample_method = "js" if args.js == 1 else "random_window"
+     args.js_distances_file = args.js_file
+     args.pkl_file = args.pkl
+     args.model = "gat"
+     args.layer = "layer2"
+     args.dist_type = "uniform"
+     if args.lrs is None:
+         args.lrs = [0.001, 0.002, 0.005, 0.01]
+     args.c = 0.0
+     args.use_clustering = False
+     args.visualize = True
+
+     if args.num_clusters is not None:
+         args.clustering = True
+         print(f"Clustering evaluation enabled. num_clusters={args.num_clusters}")
+     else:
+         args.clustering = False
+         args.num_clusters = 8
+
+     args.reduce_dims = True
+     args.forward_method = "default"
+     # use_gradient_clipping / gradient_clip_norm come from the CLI
+     args.print_freq = 10
+     args.checkpoint_freq = 20
+     args.save_best_only = False
+     args.early_stopping = 0
+     args.weighted = False
+     args.window_size = 5
+     args.optimizer = "adam"
+     args.weight_decay = 1e-5
+     args.lr_scheduler = "plateau"
+     args.lr_patience = 10
+     args.clustering_methods = [
+         "KMeans",
+         "Agglomerative",
+         "SpectralClustering",
+         "GaussianMixture",
+     ]
+     args.spectral_loss = False
+     args.tsne_perplexity = 30.0
+     args.umap_n_neighbors = 15
+     args.umap_min_dist = 0.2
+     args.size = 20
+     args.graphs_number = None
+     args.cell_numbers = None
+     args.gene_numbers = None
+     args.tissue = None
+     args.experiment_id = None
+     args.no_timestamp = False
+     args.vis_methods = None
+
+     if args.label_file is not None and os.path.exists(args.label_file):
+         print(f"Using label file: {args.label_file}")
+
+     return args
+
+
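+ # Example invocation (paths are hypothetical; flags as declared above):
+ #   grasp-tool train-moco --dataset data1_simulated1 --pkl graphs.pkl \
+ #       --js 1 --js_file js_distances.csv --num_clusters 8 --lrs 0.001 0.002
+
+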
+ def load_data(
+     args: argparse.Namespace,
+ ) -> Tuple[List, List, List, List, pd.DataFrame, Optional[pd.DataFrame]]:
+     save_file = args.pkl
+     if not os.path.exists(save_file):
+         raise FileNotFoundError(f"PKL file not found: {save_file}")
+
+     print(f"Loading data from: {save_file}")
+     with open(save_file, "rb") as f:
+         data = pickle.load(f)
+
+     original_graphs = data["original_graphs"]
+     augmented_graphs = data["augmented_graphs"]
+     gene_labels = data["gene_labels"]
+     cell_labels = data["cell_labels"]
+
+     args.graphs_number = len(original_graphs)
+     args.cell_numbers = len(set(cell_labels)) if cell_labels else 0
+     args.gene_numbers = len(set(gene_labels)) if gene_labels else 0
+
+     # Placeholder (empty) GW table; the "gw" sampling path expects this schema.
+     gw_distances_df = pd.DataFrame(
+         columns=pd.Index(
+             [
+                 "target_cell",
+                 "target_gene",
+                 "cell",
+                 "gene",
+                 "num_real_nodes",
+                 "gw_distance",
+             ]
+         )
+     )
+
+     js_distances_df = None
+     if args.js == 1:
+         js_file = args.js_file
+         if js_file and os.path.exists(js_file):
+             js_distances_df = pd.read_csv(js_file)
+             print(f"Loaded JS distances from: {js_file}")
+         else:
+             print(f"ERROR: JS distance file not found: {js_file}")
+             print("Falling back to random_window method")
+             args.positive_sample_method = "random_window"
+
+     return (
+         original_graphs,
+         augmented_graphs,
+         gene_labels,
+         cell_labels,
+         gw_distances_df,
+         js_distances_df,
+     )
+
+
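+ # Sketch of a compatible PKL payload (keys as read above; the graph objects are
+ # assumed to be torch_geometric `Data` instances carrying .x, .edge_index, and
+ # .cell/.gene attributes, as used by the training and evaluation code below):
+ #   pickle.dump(
+ #       {"original_graphs": [...], "augmented_graphs": [...],
+ #        "gene_labels": [...], "cell_labels": [...]},
+ #       open("graphs.pkl", "wb"),
+ #   )
+
+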
+ def set_seed(seed: int) -> None:
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+
+
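+ # Note: cudnn.deterministic alone does not guarantee bit-identical runs; for
+ # stricter reproducibility one would also set torch.backends.cudnn.benchmark =
+ # False and call torch.use_deterministic_algorithms(True), at some speed cost.
+
+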
+ def setup_training(args):
+     device = torch.device(
+         f"cuda:{args.cuda_device}" if torch.cuda.is_available() else "cpu"
+     )
+     print("Using device:", device)
+     timestamp = time.strftime("%m%d_%H%M")
+     js_flag = "js" if args.js == 1 else "nojs"
+     folder_name = (
+         f"n{args.n}_m{args.m}_{js_flag}_"
+         f"a{args.a}_b{args.b}_t{args.temperature}_"
+         f"bs{args.batch_size}_pos{args.num_positive}_{timestamp}"
+     )
+     base_output_dir = args.output_dir or f"./outputs/{args.dataset}/step5_embedding"
+     save_path = os.path.join(base_output_dir, folder_name)
+     os.makedirs(save_path, exist_ok=True)
+     print(f"Results will be saved to: {save_path}")
+     return save_path, device
+
+
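+ # With the CLI defaults, setup_training creates a run folder such as
+ # ./outputs/<dataset>/step5_embedding/n20_m10_nojs_a0.5_b0.5_t0.07_bs64_pos4_<MMDD_HHMM>.
+
+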
+ def train_epoch(
+     model: Any,
+     original_graphs: List,
+     augmented_graphs: List,
+     positive_samples: List,
+     optimizer: Any,
+     device: Any,
+     args: argparse.Namespace,
+     epoch: int,
+ ) -> Dict[str, float]:
+     model.train()
+     (
+         total_loss,
+         total_reconstruction_loss,
+         total_contrastive_loss,
+         total_clustering_loss,
+     ) = 0.0, 0.0, 0.0, 0.0
+     batch_count = 0
+
+     use_clustering = getattr(args, "use_clustering", True)
+     spectral_loss = getattr(args, "spectral_loss", False)
+     dist_type = "spectral" if spectral_loss else args.dist_type
+     forward_method = getattr(args, "forward_method", "default")
+
+     batch_generator = gcl.MoCoMultiPositive.prepare_multi_positive_batch(
+         original_graphs, augmented_graphs, positive_samples, args.batch_size
+     )
+
+     for query_batch, positive_batches in batch_generator:
+         batch_count += 1
+         query_batch = query_batch.to(device)
+         positive_batches = [batch.to(device) for batch in positive_batches]
+         im_q, edge_index_q, batch = (
+             query_batch.x,
+             query_batch.edge_index,
+             query_batch.batch,
+         )
+         im_k_list = [pos_batch.x for pos_batch in positive_batches]
+         edge_index_k_list = [pos_batch.edge_index for pos_batch in positive_batches]
+
+         if forward_method == "supcon":
+             loss, reconstruction_loss, contrastive_loss, clustering_loss, _, _, _ = (
+                 model.forward_supcon(
+                     im_q,
+                     im_k_list,
+                     edge_index_q,
+                     edge_index_k_list,
+                     batch,
+                     args.num_clusters,
+                     dist_type,
+                     args.a,
+                     args.b,
+                     args.c,
+                     use_clustering,
+                 )
+             )
+         elif forward_method == "avg":
+             loss, reconstruction_loss, contrastive_loss, clustering_loss, _, _, _ = (
+                 model.forward_avg(
+                     im_q,
+                     im_k_list,
+                     edge_index_q,
+                     edge_index_k_list,
+                     batch,
+                     args.num_clusters,
+                     dist_type,
+                     args.a,
+                     args.b,
+                     args.c,
+                     use_clustering,
+                 )
+             )
+         else:
+             loss, reconstruction_loss, contrastive_loss, clustering_loss, _, _, _ = (
+                 model(
+                     im_q,
+                     im_k_list,
+                     edge_index_q,
+                     edge_index_k_list,
+                     batch,
+                     args.num_clusters,
+                     dist_type,
+                     args.a,
+                     args.b,
+                     args.c,
+                     use_clustering,
+                 )
+             )
+
+         optimizer.zero_grad()
+         loss.backward()
+
+         if hasattr(args, "use_gradient_clipping") and args.use_gradient_clipping:
+             torch.nn.utils.clip_grad_norm_(
+                 model.parameters(), max_norm=args.gradient_clip_norm
+             )
+
+         optimizer.step()
+         total_loss += loss.item()
+         total_reconstruction_loss += reconstruction_loss.item()
+         total_contrastive_loss += contrastive_loss.item()
+         total_clustering_loss += clustering_loss.item()
+
+     if batch_count > 0:
+         total_loss /= batch_count
+         total_reconstruction_loss /= batch_count
+         total_contrastive_loss /= batch_count
+         total_clustering_loss /= batch_count
+
+     return {
+         "total_loss": total_loss,
+         "reconstruction_loss": total_reconstruction_loss,
+         "contrastive_loss": total_contrastive_loss,
+         "clustering_loss": total_clustering_loss,
+     }
+
+
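+ # Each forward returns (loss, reconstruction, contrastive, clustering, ...);
+ # presumably loss = a*reconstruction + b*contrastive + c*clustering, and this
+ # CLI fixes c = 0.0, so the clustering term is tracked but carries no weight.
+
+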
+ def train_model(args: argparse.Namespace) -> None:
+     save_path, device = None, None
+     try:
+         (
+             original_graphs,
+             augmented_graphs,
+             gene_labels,
+             cell_labels,
+             gw_distances_df,
+             js_distances_df,
+         ) = load_data(args)
+         save_path, device = setup_training(args)
+         positive_sample_method = getattr(args, "positive_sample_method", "gw")
+
+         if positive_sample_method == "gw":
+             positive_samples = gcl.MoCoMultiPositive.generate_samples_gw(
+                 original_graphs,
+                 augmented_graphs,
+                 gene_labels,
+                 cell_labels,
+                 args.num_positive,
+                 gw_distances_df,
+             )
+         elif positive_sample_method == "js":
+             positive_samples = gcl.MoCoMultiPositive.generate_samples_js(
+                 original_graphs,
+                 augmented_graphs,
+                 gene_labels,
+                 cell_labels,
+                 args.num_positive,
+                 js_distances_df,
+             )
+         else:
+             positive_samples = gcl.MoCoMultiPositive.generate_samples_random_window(
+                 original_graphs,
+                 augmented_graphs,
+                 gene_labels,
+                 cell_labels,
+                 args.num_positive,
+                 args.window_size,
+             )
+
+         config_path = f"{save_path}/1_training_config.json"
+         config_dict = {
+             k: str(v) if isinstance(v, (np.ndarray, torch.Tensor)) else v
+             for k, v in vars(args).items()
+         }
+         with open(config_path, "w") as f:
+             json.dump(config_dict, f, indent=4)
+
+         visualize, clustering = (
+             getattr(args, "visualize", True),
+             getattr(args, "clustering", True),
+         )
+         print_freq, checkpoint_freq = (
+             getattr(args, "print_freq", 10),
+             getattr(args, "checkpoint_freq", 20),
+         )
+         save_best_only, early_stopping = (
+             getattr(args, "save_best_only", False),
+             getattr(args, "early_stopping", 0),
+         )
+
+         for lr in args.lrs:
+             print(f"Starting training with lr: {lr}")
+             experiment_base_name = os.path.basename(save_path)
+
+             # Infer the node-feature width from the first graph; fall back to 16.
+             feature_dim = 16
+             try:
+                 feature_dim = original_graphs[0].x.shape[1]
+                 print(f"Detected feature_dim: {feature_dim}")
+             except (IndexError, AttributeError):
+                 pass
+
+             base_encoder = gcl.GATEncoder(
+                 in_channels=feature_dim, hidden_channels=64, out_channels=128
+             ).to(device)
+             model = gcl.MoCoMultiPositive(
+                 base_encoder, dim=128, K=args.k, m=0.999, T=args.temperature
+             ).to(device)
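+             # MoCo hyperparameters: dim = embedding size, K = negative-queue
+             # length (--k), m = key-encoder momentum, T = softmax temperature,
+             # following the standard MoCo recipe.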
+
+             if getattr(args, "spectral_loss", False):
+                 # NOTE: this CLI never defines args.k_neighbors / args.sigma
+                 # (spectral_loss is hard-wired to False above); callers enabling
+                 # it programmatically must supply both attributes.
+                 model.k_neighbors, model.sigma = args.k_neighbors, args.sigma
+                 torch.autograd.set_detect_anomaly(True)
+
+             model.weighted_recon_loss = getattr(args, "weighted", False)
+             optimizer = torch.optim.Adam(
+                 list(model.encoder_q.parameters()),
+                 lr=lr,
+                 weight_decay=getattr(args, "weight_decay", 1e-5),
+             )
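+             # Only encoder_q's parameters are optimized; the key encoder inside
+             # MoCoMultiPositive is presumably maintained by the MoCo momentum
+             # (EMA) update with m above, not by gradient descent.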
+
+             lr_scheduler_type = getattr(args, "lr_scheduler", "plateau").lower()
+             if lr_scheduler_type == "plateau":
+                 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                     optimizer,
+                     mode="min",
+                     factor=0.5,
+                     patience=args.lr_patience,
+                     min_lr=1e-6,
+                 )
+             else:
+                 scheduler = None
+
+             best_metrics = {}
+             if clustering:
+                 methods = [
+                     "KMeans",
+                     "Agglomerative",
+                     "SpectralClustering",
+                     "GaussianMixture",
+                 ]
+                 best_metrics = {
+                     k: {m: {"best_epoch": 0, "metrics": None} for m in methods}
+                     for k in ["basic", "scaler", "pca", "select"]
+                 }
+
+             early_stop_counter, best_loss = 0, float("inf")
+
+             # Epoch 0 evaluation (untrained baseline)
+             model.eval()
+             with torch.no_grad():
+                 evaluate_and_visualize(
+                     model,
+                     original_graphs,
+                     device,
+                     save_path,
+                     0,
+                     lr,
+                     args,
+                     visualize=visualize,
+                     clustering=clustering,
+                     specific_label_file=args.label_file,
+                 )
+
+             for epoch in range(1, args.num_epoch + 1):
+                 losses = train_epoch(
+                     model,
+                     original_graphs,
+                     augmented_graphs,
+                     positive_samples,
+                     optimizer,
+                     device,
+                     args,
+                     epoch,
+                 )
+                 if scheduler:
+                     if isinstance(scheduler, ReduceLROnPlateau):
+                         scheduler.step(losses["total_loss"])
+                     else:
+                         scheduler.step()
+
+                 if epoch % print_freq == 0:
+                     print(
+                         f"Epoch [{epoch}/{args.num_epoch}], "
+                         f"Loss: {losses['total_loss']:.4f}, "
+                         f"LR: {optimizer.param_groups[0]['lr']:.6f}"
+                     )
+
+                 should_save_checkpoint = (epoch % checkpoint_freq == 0) or (
+                     epoch == args.num_epoch
+                 )
+                 if early_stopping > 0:
+                     if losses["total_loss"] < best_loss:
+                         best_loss, early_stop_counter = losses["total_loss"], 0
+                     else:
+                         early_stop_counter += 1
+                         if early_stop_counter >= early_stopping:
+                             break
+
+                 if should_save_checkpoint:
+                     if not save_best_only:
+                         torch.save(
+                             {"epoch": epoch, "model_state_dict": model.state_dict()},
+                             os.path.join(
+                                 save_path, f"epoch_{epoch}_lr_{lr}_checkpoint.pth"
+                             ),
+                         )
+
+                     model.eval()
+                     with torch.no_grad():
+                         current_metrics, current_figs = evaluate_and_visualize(
+                             model,
+                             original_graphs,
+                             device,
+                             save_path,
+                             epoch,
+                             lr,
+                             args,
+                             visualize=visualize,
+                             clustering=clustering,
+                             specific_label_file=args.label_file,
+                         )
+
+                     best_model_found = False
+                     if clustering:
+                         for vis_method, vis_results in current_metrics.items():
+                             for cluster_method, cluster_results in vis_results.items():
+                                 if (
+                                     best_metrics[vis_method][cluster_method]["metrics"]
+                                     is None
+                                     or cluster_results["F1-Score"]
+                                     > best_metrics[vis_method][cluster_method][
+                                         "metrics"
+                                     ]["F1-Score"]
+                                 ):
+                                     best_metrics[vis_method][cluster_method].update(
+                                         {
+                                             "best_epoch": epoch,
+                                             "metrics": cluster_results,
+                                         }
+                                     )
+                                     best_model_found = True
+                                     if (
+                                         visualize
+                                         and vis_method in current_figs
+                                         and current_figs[vis_method]
+                                     ):
+                                         current_figs[vis_method].savefig(
+                                             f"{save_path}/best_{vis_method}_{cluster_method}_lr{lr}.png",
+                                             bbox_inches="tight",
+                                         )
+                     elif losses["total_loss"] < best_loss:
+                         best_loss, best_model_found = losses["total_loss"], True
+
+                     if save_best_only and best_model_found:
+                         torch.save(
+                             {"model_state_dict": model.state_dict()},
+                             os.path.join(
+                                 save_path, f"best_model_epoch_{epoch}_lr_{lr}.pth"
+                             ),
+                         )
+
+             if clustering:
+                 with open(f"{save_path}/best_metrics_lr{lr}.json", "w") as f:
+                     json.dump(convert_to_serializable(best_metrics), f, indent=4)
+             print(f"Completed training for lr: {lr}")
+
+         with open(f"{save_path}/ALL_COMPLETED.txt", "w") as f:
+             f.write(f"Completed at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
+
+     except Exception as e:
+         print(f"Error: {e}")
+         traceback.print_exc()
+         # Remove the (partial) run directory so failed runs do not pollute outputs.
+         if save_path and os.path.isdir(save_path):
+             shutil.rmtree(save_path)
+
+
+ def evaluate_and_visualize(
+     model,
+     original_graphs,
+     device,
+     save_path,
+     epoch,
+     lr,
+     args,
+     visualize=True,
+     clustering=True,
+     specific_label_file=None,
+     tsne_perplexity=None,
+     umap_n_neighbors=None,
+     umap_min_dist=None,
+ ):
+     if tsne_perplexity is None:
+         tsne_perplexity = getattr(args, "tsne_perplexity", 30.0)
+     if umap_n_neighbors is None:
+         umap_n_neighbors = getattr(args, "umap_n_neighbors", 15)
+     if umap_min_dist is None:
+         umap_min_dist = getattr(args, "umap_min_dist", 0.1)
+
+     model.eval()
+     graph_representations = []
+     with torch.no_grad():
+         for graph in original_graphs:
+             try:
+                 graph = graph.to(device)
+                 _, rep = model.encoder_q(graph.x, graph.edge_index, batch=None)
+                 graph_representations.append(
+                     rep.cpu().numpy().tolist() + [graph.cell, graph.gene]
+                 )
+             except Exception:
+                 # Skip graphs that fail on-device or in the encoder.
+                 continue
+
+     if not graph_representations:
+         return {}, {}
+
+     df = pd.DataFrame(
+         graph_representations,
+         columns=pd.Index(
+             [f"feature_{i + 1}" for i in range(len(graph_representations[0]) - 2)]
+             + ["cell", "gene"]
+         ),
+     )
+     df.to_csv(f"{save_path}/epoch{epoch}_lr{lr}_embedding.csv", index=False)
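+     # One row per graph: feature_1..feature_d (the graph-level embedding from
+     # encoder_q; presumably d = 128 given the encoder config above), then the
+     # graph's cell and gene labels.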
+
+     if not clustering:
+         if visualize:
+             vis.plot_embeddings_only(df, save_path, epoch, lr, visualize=True)
+         return {}, {}
+
+     all_metrics, figures_dict = vis.evaluate_and_visualize(
+         dataset=args.dataset,
+         df=df,
+         save_path=save_path,
+         num_epochs=epoch,
+         lr=lr,
+         n_clusters=args.num_clusters,
+         visualize=visualize,
+         clustering_methods=args.clustering_methods,
+         specific_label_file=specific_label_file,
+         tsne_perplexity=tsne_perplexity,
+         umap_n_neighbors=umap_n_neighbors,
+         umap_min_dist=umap_min_dist,
+     )
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     return all_metrics, figures_dict
+
+
+ def convert_to_serializable(obj):
+     if isinstance(obj, dict):
+         return {str(k): convert_to_serializable(v) for k, v in obj.items()}
+     if isinstance(obj, (list, tuple)):
+         return [convert_to_serializable(x) for x in obj]
+     if isinstance(obj, (np.integer, np.floating)):
+         return obj.item()
+     if isinstance(obj, np.ndarray):
+         return obj.tolist()
+     return obj
+
+
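+ # convert_to_serializable recursively rewrites numpy scalars/arrays (and
+ # stringifies dict keys) into plain Python types so json.dump accepts the
+ # best_metrics structure, e.g. {np.float32(0.97): np.int64(3)} -> {"0.97": 3}.
+
+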
+ def main():
+     try:
+         args = parse_args()
+     except SystemExit as e:
+         # argparse uses SystemExit for --help/-h and parse errors.
+         code = getattr(e, "code", 0)
+         return int(code) if isinstance(code, int) else 0
+
+     try:
+         _lazy_import_training_deps()
+     except ModuleNotFoundError as e:
+         print(str(e))
+         return 1
+     set_seed(args.seed)
+     train_model(args)
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())